哪位大佬可以免费赐教一下吗?如果方便的话,麻烦给下面的代码加个注释,多谢了!


def topk(para, k, group_size=7):
    """Build a 0/1 mask keeping the ``k`` largest-magnitude entries per column group.

    Each row of ``para`` is split into consecutive, non-overlapping groups of
    ``group_size`` columns (columns ``0..group_size-1``, then
    ``group_size..2*group_size-1``, ...). Within every group the ``k`` entries
    with the largest absolute value are marked with 1; everything else is 0.

    NOTE(review): this looks like an N:M-style structured-sparsity mask
    (keep k of every 7 weights) -- confirm the intent with the caller.
    NOTE: this file defines ``topk`` a second time further down; the later
    definition is the one bound to the name ``topk``.

    Args:
        para: 2-D tensor (rows x columns). Only absolute values are compared.
        k: number of entries to keep in each group. Must not exceed the group
            width, otherwise ``torch.topk`` raises.
        group_size: width of each column group. Defaults to 7 (previously
            hard-coded).

    Returns:
        A ``torch.int`` tensor of shape ``(rows, columns)`` on ``para``'s
        device, holding 1 at the kept positions and 0 elsewhere.
        Trailing columns that do not fill a complete group are never marked.
        Exception, kept for backward compatibility: if a row has fewer than
        ``group_size`` columns, the whole row is treated as a single group.

    Raises:
        ValueError: if ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be a positive integer")
    rows, cols = para.size()[0], para.size()[1]
    # Allocate on para's device: the indices returned by torch.topk live
    # there too, and indexing a CPU tensor with CUDA indices fails.
    mask = torch.zeros(rows, cols, dtype=torch.int, device=para.device)
    magnitude = torch.abs(para)
    # Number of complete groups; exact integer division, no float round-trip.
    n_groups = cols // group_size

    if n_groups == 0:
        # Fewer columns than one group: select within the whole (short) row,
        # which is what the previous implementation did in this case.
        _, idx = torch.topk(magnitude[:, :group_size], k, dim=1, largest=True)
    else:
        # View the complete groups as (rows, n_groups, group_size) so a single
        # topk call handles every group, instead of a Python loop + torch.cat.
        grouped = magnitude[:, : n_groups * group_size].reshape(
            rows, n_groups, group_size
        )
        _, idx = torch.topk(grouped, k, dim=2, largest=True)
        # Shift group-local indices (0..group_size-1) to full-row column indices.
        offsets = torch.arange(n_groups, device=para.device).unsqueeze(1) * group_size
        idx = (idx + offsets).flatten(start_dim=1)

    # Vectorised form of "for each row j: mask[j, idx[j, :]] = 1".
    mask.scatter_(1, idx, 1)
    return mask



def topk(para, k, group_size=7):
    """Build a 0/1 mask keeping the ``k`` largest-magnitude entries per column group.

    Each row of ``para`` is split into consecutive, non-overlapping groups of
    ``group_size`` columns (columns ``0..group_size-1``, then
    ``group_size..2*group_size-1``, ...). Within every group the ``k`` entries
    with the largest absolute value are marked with 1; everything else is 0.

    NOTE(review): this looks like an N:M-style structured-sparsity mask
    (keep k of every 7 weights) -- confirm the intent with the caller.
    NOTE: another definition of ``topk`` appears earlier in this file; this
    later one shadows it and is the one callers actually get.

    Args:
        para: 2-D tensor (rows x columns). Only absolute values are compared.
        k: number of entries to keep in each group. Must not exceed the group
            width, otherwise ``torch.topk`` raises.
        group_size: width of each column group. Defaults to 7 (previously
            hard-coded).

    Returns:
        A ``torch.int`` tensor of shape ``(rows, columns)`` on ``para``'s
        device, holding 1 at the kept positions and 0 elsewhere.
        Trailing columns that do not fill a complete group are never marked.
        Exception, kept for backward compatibility: if a row has fewer than
        ``group_size`` columns, the whole row is treated as a single group.

    Raises:
        ValueError: if ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be a positive integer")
    rows, cols = para.size()[0], para.size()[1]
    # Allocate on para's device: the indices returned by torch.topk live
    # there too, and indexing a CPU tensor with CUDA indices fails.
    mask = torch.zeros(rows, cols, dtype=torch.int, device=para.device)
    magnitude = torch.abs(para)
    # Number of complete groups; exact integer division, no float round-trip.
    n_groups = cols // group_size

    if n_groups == 0:
        # Fewer columns than one group: select within the whole (short) row,
        # which is what the previous implementation did in this case.
        _, idx = torch.topk(magnitude[:, :group_size], k, dim=1, largest=True)
    else:
        # View the complete groups as (rows, n_groups, group_size) so a single
        # topk call handles every group, instead of a Python loop + torch.cat.
        grouped = magnitude[:, : n_groups * group_size].reshape(
            rows, n_groups, group_size
        )
        _, idx = torch.topk(grouped, k, dim=2, largest=True)
        # Shift group-local indices (0..group_size-1) to full-row column indices.
        offsets = torch.arange(n_groups, device=para.device).unsqueeze(1) * group_size
        idx = (idx + offsets).flatten(start_dim=1)

    # Vectorised form of "for each row j: mask[j, idx[j, :]] = 1".
    mask.scatter_(1, idx, 1)
    return mask