繁体   English   中英

Python: TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given

[英]Python : TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given

我遇到了一个关于 class 和 function 的编程问题。看来我不能正确地使用 class。 你能指出这个问题吗? 谢谢你 !

class NTXentLoss(nn.Module):
    def __init__(self, temp=0.5):
        super(NTXentLoss, self).__init__()
        self.temp = temp
    
    def forward(self, zi, zj):
        batch_size = zi.shape[0]
        z_proj = torch.cat((zi, zj), dim=0)
        cos_sim = torch.nn.CosineSimilarity(dim=-1)
        sim_mat = cos_sim(z_proj.unsqueeze(1), z_proj.unsqueeze(0))
        sim_mat_scaled = torch.exp(sim_mat/self.temp)
        r_diag = torch.diag(sim_mat_scaled, batch_size)
        l_diag = torch.diag(sim_mat_scaled, -batch_size)
        pos = torch.cat([r_diag, l_diag])
        diag_mat = torch.exp(torch.ones(batch_size * 2)/self.temp).cuda()
        logit = -torch.log(pos/(sim_mat_scaled.sum(1) - diag_mat))
        loss = logit.mean()
        return loss

        sent_A = l2norm(recov_A, dim=1)
        sent_emb_A = l2norm(imgs_A, dim=1)
        sent_B = l2norm(recov_B, dim=1)
        sent_emb_B = l2norm(imgs_B, dim=1)

G_cons = NTXentLoss(sent_A,sent_emb_A) + NTXentLoss(sent_B,sent_emb_B)

这有什么问题,我只是给了两个位置arguments? 或者

G_cons = NTXentLoss.forward(sent_A,sent_emb_A) + NTXentLoss.forward(sent_B,sent_emb_B)

需要先发起一个NTXentLoss object 才能调用。 例如:

ntx = NTXentLoss()
G_cons = ntx(sent_A,sent_emb_A) + ntx(sent_B,sent_emb_B)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM