繁体   English   中英

删除 torch.tensor 中的重复行

[英]Delete duplicated rows in torch.tensor

我有一个形状为(n,m)torch.tensor ,我想删除重复的行(或至少找到它们)。 例如:

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
t2 = remove_duplicates(t1)

t2现在应该等于tensor([[1, 2, 3], [4, 5, 6]]) ,即删除1行和3行。 您知道执行此操作的方法吗?

我想用torch.unique做点什么,但我不知道该怎么做。

您可以简单地利用 torch.unique 的参数 dim。

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7,8,9]])
torch.unique(t1, dim=0)

这样你就得到了你想要的结果:

tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

在这里您可以阅读该参数的含义。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM