
PyTorch torch.max over multiple dimensions

I have a tensor like this, with x.shape = [3, 2, 2]:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

I need to take .max() over the 2nd and 3rd dimensions. I expect something like [-0.2632, -0.1453, -0.0274] as output. I tried to use x.max(dim=(1, 2)), but this causes an error.
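
For reference, here is the failing call (the exact error text varies between PyTorch versions, but it is rejected at argument parsing because dim must be an int):

x.max(dim=(1, 2))
# TypeError: max() received an invalid combination of arguments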

Now, you can do this. The PR was merged (Aug 28) and it is now available in the nightly release.

Simply use torch.amax():

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(torch.amax(x, dim=(1, 2)))

# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])
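
The min counterpart, torch.amin(), works the same way, and both accept keepdim if you need the reduced dimensions preserved; a small illustration on the same x:

print(torch.amin(x, dim=(1, 2)))
# >>> tensor([-0.3000, -0.1821, -0.0642])

print(torch.amax(x, dim=(1, 2), keepdim=True).shape)
# >>> torch.Size([3, 1, 1])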

Original Answer

As of today (April 11, 2020), there is no way to do .min() or .max() over multiple dimensions in PyTorch. There is an open issue about it that you can follow and see if it ever gets implemented. A workaround in your case would be:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(x.view(x.size(0), -1).max(dim=-1))

# output:
# >>> torch.return_types.max(
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))

So, if you need only the values: x.view(x.size(0), -1).max(dim=-1).values.

If x is not a contiguous tensor, then .view() will fail. In this case, you should use .reshape() instead.
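
If you also need the indices expressed in the original (dim 1, dim 2) coordinates, the flat indices can be unraveled manually; a minimal sketch, using x.size(2) as the length of the last dimension:

flat = x.view(x.size(0), -1).max(dim=-1)
rows = flat.indices // x.size(2)   # index along dim 1
cols = flat.indices % x.size(2)    # index along dim 2

print(rows, cols)
# >>> tensor([1, 1, 1]) tensor([1, 1, 1])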


Update August 26, 2020

This feature is being implemented in PR#43092 and the functions will be called amin and amax. They will return only the values. This will probably be merged soon, so you might be able to access these functions on the nightly build by the time you're reading this :) Have fun.

Although Berriel's solution solves this specific question, I thought adding some explanation might help everyone shed some light on the trick employed here, so that it can be adapted for (m)any other dimensions.

Let's start by inspecting the shape of the input tensor x:

In [58]: x.shape   
Out[58]: torch.Size([3, 2, 2])

So, we have a 3D tensor of shape (3, 2, 2). Now, as per the OP's question, we need to compute the maximum of the values in the tensor along both the 1st and 2nd dimensions. As of this writing, torch.max()'s dim argument supports only int, so we can't use a tuple. Hence, we will use the following trick, which I will call:

The Flatten & Max Trick: since we want to compute max over both the 1st and 2nd dimensions, we will flatten both of these dimensions to a single dimension and leave the 0th dimension untouched. This is exactly what happens by doing:

In [61]: x.flatten().reshape(x.shape[0], -1).shape   
Out[61]: torch.Size([3, 4])   # 2*2 = 4

So, now we have shrunk the 3D tensor to a 2D tensor (i.e. a matrix):

In [62]: x.flatten().reshape(x.shape[0], -1) 
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
        [-0.1821, -0.1747, -0.1526, -0.1453],
        [-0.0642, -0.0568, -0.0347, -0.0274]])

Now, we can simply apply max over the 1st dimension (in this case, the first dimension is also the last dimension), since the flattened dimensions reside in that dimension:

In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1)    # or: `dim = -1`
Out[65]: 
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))

We got 3 values in the resultant tensor since we had 3 rows in the matrix.


Now, on the other hand, if you want to compute max over the 0th and 1st dimensions, you'd do:

In [80]: x.flatten().reshape(-1, x.shape[-1]).shape 
Out[80]: torch.Size([6, 2])    # 3*2 = 6

In [79]: x.flatten().reshape(-1, x.shape[-1]) 
Out[79]: 
tensor([[-0.3000, -0.2926],
        [-0.2705, -0.2632],
        [-0.1821, -0.1747],
        [-0.1526, -0.1453],
        [-0.0642, -0.0568],
        [-0.0347, -0.0274]])

Now, we can simply apply max over the 0th dimension since that is the result of our flattening. (Also, from our original shape of (3, 2, 2), after taking max over the first 2 dimensions, we should get two values as the result.)

In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0) 
Out[82]: 
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))

In a similar vein, you can adapt this approach to multiple dimensions and other reduction functions such as min.
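
For instance, the same flatten-and-reduce pattern computes min over the 1st and 2nd dimensions (a quick sketch):

In [83]: x.flatten().reshape(x.shape[0], -1).min(dim=-1).values
Out[83]: tensor([-0.3000, -0.1821, -0.0642])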


Note: I'm following the terminology of 0-based dimensions (0, 1, 2, 3, ...) just to be consistent with PyTorch usage and the code.

If you only want to use the torch.max() function to get the indices of the max entry in a 2D tensor, you can do:

max_i_vals, max_i_indices = torch.max(x, 0)  # per-column max values and their row indices
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]    # column of the global max
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]  # [row, column] of the global max
print('max_index: ', max_index)

In testing, the above printed out for me:

max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=<MaxBackward0>)
max_i_indices: tensor([ 5,  8, 10,  6, 13, 14,  5,  6,  6,  6, 13,  4, 13, 13, 11])  
max_j_index:  tensor(6)  
max_index:  [tensor(5), tensor(6)]

This approach can be extended to 3 dimensions. While not as visually pleasing as other answers in this post, this answer shows that the problem can be solved using only the torch.max() function (though I do agree built-in support for torch.max() over multiple dimensions would be a boon).
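
A sketch of that 3D extension, using the tensor x from the question (the argmax chain is walked backwards to recover the full index):

vals0, idx0 = torch.max(x, 0)      # reduce dim 0 -> shape (2, 2)
vals1, idx1 = torch.max(vals0, 0)  # reduce the next dim -> shape (2,)
max_val, k = torch.max(vals1, 0)   # reduce the last dim -> scalar
j = idx1[k]                        # recover the dim-1 index
i = idx0[j, k]                     # recover the dim-0 index
print(max_val, [i, j, k])
# >>> tensor(-0.0274) [tensor(2), tensor(1), tensor(1)]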

FOLLOW UP
I stumbled upon a similar question in the PyTorch forums, where the poster ptrblck offered this line of code as a solution for getting the indices of the maximal entry in the tensor x:

x = (x==torch.max(x)).nonzero()

Not only does this one-liner work with N-dimensional tensors without needing adjustments to the code, but according to my benchmarks it is also much faster than the approach I wrote above (at least a 2:1 ratio) and faster than the accepted answer (about a 3:2 ratio).
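
For illustration, applied to the 3D tensor x from the question, nonzero() returns one row of coordinates per maximal entry (so ties would yield multiple rows):

print((x == torch.max(x)).nonzero())
# >>> tensor([[2, 1, 1]])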

If you use torch <= 1.10, please use this custom max function:

def torch_max(x, dim):
    # Move the dimensions to be reduced to the end...
    s1 = [i for i in range(len(x.shape)) if i not in dim]  # dims to keep
    s2 = [i for i in range(len(x.shape)) if i in dim]      # dims to reduce
    x2 = x.permute(tuple(s1 + s2))
    # ...flatten them into a single trailing dimension...
    s = [d for (i, d) in enumerate(x.shape) if i not in dim] + [-1]
    x2 = torch.reshape(x2, tuple(s))
    # ...and reduce over that trailing dimension.
    values, _ = x2.max(-1)
    return values

Then run it like:

print(torch_max(x, dim=(1, 2)))
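
With the x from the question, this should print the same values as torch.amax(x, dim=(1, 2)):

# >>> tensor([-0.2632, -0.1453, -0.0274])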
