简体   繁体   English

PyTorch 中的批处理余弦相似度

[英]Batched Cosine Similarity in PyTorch

Inputs:输入:

  • Tensor a of shape [batch_size, n, d]形状为[batch_size, n, d]的张量a
  • Tensor b of shape [batch_size, m, d]形状为[batch_size, m, d]的张量b

Output: Output:

  • Tensor c of shape [batch_size, n, m] where c[i, j, k] is the cosine similarity between a[i, j] and b[i, k]形状为[batch_size, n, m]的张量c其中c[i, j, k]a[i, j]b[i, k]之间的余弦相似度

How to implement this efficiently in PyTorch (preferably without for loops)?如何在 PyTorch 中有效地实现这一点(最好没有for循环)?

try this:尝试这个:

c = torch.cosine_similarity(a.unsqueeze(2), b.unsqueeze(1), dim=-1)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM