简体   繁体   English

PyTorch:比较三个张量?

[英]PyTorch: compare three tensors?

I have three boolean mask tensors that I want to create a boolean mask that if the value matches in three tensors then it is 1 , else 0 .我有三个 boolean 掩码张量,我想创建一个 boolean 掩码,如果值在三个张量中匹配,则为1 ,否则0

I tried torch.where(A == B == C, 1, 0) , but it doesn't seem to support such.我试过torch.where(A == B == C, 1, 0) ,但它似乎不支持这样。

The torch.eq operator only supports binary tensor comparisons , hence you need to perform two comparisons: torch.eq运算符仅支持二进制张量比较,因此您需要执行两个比较:

(A==B) & (B==C)

You can use:您可以使用:

((A == B) & (B == C))

If required, you can always convert the boolean tensor to an appropriate type:如果需要,您始终可以将 boolean 张量转换为适当的类型:

((A == B) & (B == C)).to(float)

AFAIK, the tensor is basically a NumPy array bound to the device. AFAIK,张量基本上是绑定到设备的 NumPy 数组。 If not too expensive for your application and you can afford to do it on CPU, you can simply convert it to NumPy and do what you need with the comparison.如果您的应用程序不太昂贵并且您可以负担得起在 CPU 上执行此操作,您可以简单地将其转换为 NumPy 并通过比较做您需要的事情。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM