![](/img/trans.png)
[英]python super :TypeError: __init__() takes 2 positional arguments but 3 were given
[英]Python : TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given
我遇到了一个关于 class 和 function 的编程问题。看来我不能正确地使用 class。 你能指出这个问题吗? 谢谢你 !
class NTXentLoss(nn.Module):
def __init__(self, temp=0.5):
super(NTXentLoss, self).__init__()
self.temp = temp
def forward(self, zi, zj):
batch_size = zi.shape[0]
z_proj = torch.cat((zi, zj), dim=0)
cos_sim = torch.nn.CosineSimilarity(dim=-1)
sim_mat = cos_sim(z_proj.unsqueeze(1), z_proj.unsqueeze(0))
sim_mat_scaled = torch.exp(sim_mat/self.temp)
r_diag = torch.diag(sim_mat_scaled, batch_size)
l_diag = torch.diag(sim_mat_scaled, -batch_size)
pos = torch.cat([r_diag, l_diag])
diag_mat = torch.exp(torch.ones(batch_size * 2)/self.temp).cuda()
logit = -torch.log(pos/(sim_mat_scaled.sum(1) - diag_mat))
loss = logit.mean()
return loss
sent_A = l2norm(recov_A, dim=1)
sent_emb_A = l2norm(imgs_A, dim=1)
sent_B = l2norm(recov_B, dim=1)
sent_emb_B = l2norm(imgs_B, dim=1)
G_cons = NTXentLoss(sent_A,sent_emb_A) + NTXentLoss(sent_B,sent_emb_B)
这有什么问题,我只是给了两个位置arguments? 或者
G_cons = NTXentLoss.forward(sent_A,sent_emb_A) + NTXentLoss.forward(sent_B,sent_emb_B)
需要先发起一个NTXentLoss object 才能调用。 例如:
ntx = NTXentLoss()
G_cons = ntx(sent_A,sent_emb_A) + ntx(sent_B,sent_emb_B)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.