
How to do argmax in group in pytorch?

Is there any way to implement max-pooling according to the norm of sub-vectors in a group in Pytorch? Specifically, this is what I want to implement:

Input:

x: a 2-D float tensor, shape #Nodes * dim

cluster: a 1-D long tensor, shape #Nodes

Output:

y: a 2-D float tensor, and:

y[i] = x[k] where k = argmax_{cluster[k]=i}(torch.norm(x[k], p=2))
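To make the specification concrete, here is a tiny worked example (my own illustration, not part of the original question), using a brute-force loop directly off the definition:

```python
import torch

# Toy input: three nodes in two clusters.
x = torch.tensor([[1., 0.],
                  [3., 4.],
                  [0., 2.]])          # L2 norms: 1, 5, 2
cluster = torch.tensor([0, 0, 1])

# Brute-force version of the spec: for each cluster i, keep the row of x
# with the largest L2 norm among rows k where cluster[k] == i.
y = torch.stack([
    x[cluster == i][torch.norm(x[cluster == i], dim=1).argmax()]
    for i in range(int(cluster.max()) + 1)
])
# y == [[3., 4.], [0., 2.]]
```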

I tried torch.scatter with reduce="max", but this only works for dim=1 and x[i]>0.

Can someone help me to solve the problem?

I don't think there's any built-in function to do what you want. Basically this would be some form of scatter_reduce on the norm of x, but instead of selecting the max norm you want to select the row corresponding to the max norm.

A straightforward implementation may look something like this:

"""
input
    x: float tensor of size [NODES, DIMS]
    cluster: long tensor of size [NODES]
output
    float tensor of size [cluster.max()+1, DIMS]
"""

num_clusters = cluster.max().item() + 1
y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)  # x.shape[1] == DIMS
for cluster_id in torch.unique(cluster):
    x_cluster = x[cluster == cluster_id]
    y[cluster_id] = x_cluster[torch.argmax(torch.norm(x_cluster, dim=1), dim=0)]
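As a quick sanity check (my own addition; group_argmax is just an illustrative name), the loop above can be wrapped in a self-contained function, with DIMS read off x itself:

```python
import torch

def group_argmax(x, cluster):
    """Per cluster, return the row of x with the largest L2 norm."""
    num_clusters = cluster.max().item() + 1
    y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)
    for cluster_id in torch.unique(cluster):
        x_cluster = x[cluster == cluster_id]
        y[cluster_id] = x_cluster[torch.argmax(torch.norm(x_cluster, dim=1), dim=0)]
    return y

x = torch.tensor([[1., 0.], [3., 4.], [0., 2.]])   # norms 1, 5, 2
cluster = torch.tensor([0, 0, 1])
y = group_argmax(x, cluster)                        # [[3., 4.], [0., 2.]]
```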

Which should work just fine if cluster.max() is relatively small. If there are many clusters, though, this approach has to unnecessarily create a mask over cluster for every unique cluster id. To avoid this you can make use of argsort. The best I could come up with in pure python was the following.

num_clusters = cluster.max().item() + 1
x_norm = torch.norm(x, dim=1)

cluster_sortidx = torch.argsort(cluster)
cluster_ids, cluster_counts = torch.unique_consecutive(cluster[cluster_sortidx], return_counts=True)

end_indices = torch.cumsum(cluster_counts, dim=0).cpu().tolist()
start_indices = [0] + end_indices[:-1]

y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)  # x.shape[1] == DIMS
for cluster_id, a, b in zip(cluster_ids, start_indices, end_indices):
    indices = cluster_sortidx[a:b]
    y[cluster_id] = x[indices[torch.argmax(x_norm[indices], dim=0)]]
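As a self-contained check of the argsort variant (my own addition, not part of the original answer), it can be compared against a brute-force reference on deterministic data where every cluster is non-empty:

```python
import torch

torch.manual_seed(0)
x = torch.randn(100, 8)
cluster = torch.arange(100) % 10     # every cluster id 0..9 is non-empty

num_clusters = cluster.max().item() + 1
x_norm = torch.norm(x, dim=1)

# Sort node indices by cluster id, then find each cluster's contiguous range.
cluster_sortidx = torch.argsort(cluster)
cluster_ids, cluster_counts = torch.unique_consecutive(
    cluster[cluster_sortidx], return_counts=True)
end_indices = torch.cumsum(cluster_counts, dim=0).cpu().tolist()
start_indices = [0] + end_indices[:-1]

y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)
for cluster_id, a, b in zip(cluster_ids, start_indices, end_indices):
    indices = cluster_sortidx[a:b]
    y[cluster_id] = x[indices[torch.argmax(x_norm[indices], dim=0)]]

# Brute-force reference: per-cluster row with the largest norm.
ref = torch.stack([x[cluster == i][x_norm[cluster == i].argmax()]
                   for i in range(num_clusters)])
assert torch.equal(y, ref)
```

With random float norms, ties are practically impossible, so both versions pick the same rows.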

For example, in random tests with NODES = 60000, DIMS = 512, cluster.max() = 6000, the first version takes about 620 ms while the second version takes about 78 ms.
