[英]Delete duplicated rows in torch.tensor
I have a torch.tensor
of shape (n,m)
and I want to remove the duplicated rows (or at least find them).我有一个形状为
(n,m)
的torch.tensor
,我想删除重复的行(或至少找到它们)。 For example:例如:
t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
t2 = remove_duplicates(t1)
t2
should be now equal to tensor([[1, 2, 3], [4, 5, 6]])
, that is rows 1
and 3
are removed. t2
现在应该等于tensor([[1, 2, 3], [4, 5, 6]])
,即删除1
行和3
行。 Do you know a way to perform this operation?您知道执行此操作的方法吗?
I was thinking to do something with torch.unique
but I cannot figure out what to do.我想用
torch.unique
做点什么,但我不知道该怎么做。
You can simply exploit the parameter dim of torch.unique.您可以简单地利用 torch.unique 的参数 dim。
t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7,8,9]])
torch.unique(t1, dim=0)
In this way you obtain the result you want:这样你就得到了你想要的结果:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
Here you can read the meaning of that parameter. 在这里您可以阅读该参数的含义。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.