Compute maxima and minima of a 4D tensor in PyTorch
Suppose that we have a 4-dimensional tensor, for instance:
import torch
X = torch.randn(2, 3, 4, 4)  # randn rather than rand: the values printed below include negatives
tensor([[[[-0.9951, 1.6668, 1.3140, 1.4274],
[ 0.2614, 2.6442, -0.3041, 0.7337],
[-1.2690, 0.0125, -0.3885, 0.0535],
[ 1.5270, -0.1186, -0.4458, 0.1389]],
[[ 0.9125, -1.2998, -0.4277, -0.2688],
[-1.6917, -0.8855, -0.2784, -0.6717],
[ 1.1417, 0.4574, 0.4803, -1.6637],
[ 0.7322, 0.2654, -0.1525, 1.7285]],
[[ 1.8310, -1.5765, 0.1392, 1.3431],
[-0.6641, -1.5090, -0.4893, -1.4110],
[ 0.5875, 0.7528, -0.6482, -0.2547],
[-2.3133, 0.3888, 2.1428, 0.2331]]]])
I want to compute the maximum and the minimum values of X over the dimensions 2 and 3, that is, to compute two tensors of size (2,3,1,1), one holding the maximum and one holding the minimum value of each 4x4 block.
I started by trying to do that with torch.max() and torch.min(), but I had no luck. I would expect the dim argument of these functions to accept a tuple, but it accepts only a single integer, so I don't know how to proceed.
However, specifically for the maximum values, I decided to use torch.nn.MaxPool2d() with kernel_size=4 and stride=4. This indeed did the job:
max_pool = torch.nn.MaxPool2d(kernel_size=4, stride=4)
X_max = max_pool(X)
tensor([[[[2.6442]],
[[1.7285]],
[[2.1428]]]])
But, afaik, there's no similar layer for "min"-pooling. Could you please help me compute the minima in the same way as the maxima?
Thank you.
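One way to reuse the pooling approach for the minima is the identity min(X) = -max(-X). The following is a minimal sketch of that idea, not something taken from the thread; the input tensor is illustrative and the variable name X_min is an assumption.

```python
import torch

# Illustrative input with the shape stated in the question.
X = torch.randn(2, 3, 4, 4)

max_pool = torch.nn.MaxPool2d(kernel_size=4, stride=4)

# min(X) == -max(-X): max-pool the negated tensor, then negate the result.
X_min = -max_pool(-X)

print(X_min.shape)  # one minimum per 4x4 block
```

Because the kernel covers the whole 4x4 block, each output element is the minimum of one block, which mirrors how X_max is obtained above.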
Just calculate the max over both dimensions sequentially; it gives the same result:
tup = (2, 3)
for dim in tup:
    X = torch.max(X, dim=dim, keepdim=True)[0]
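The same sequential reduction should work for the minima by swapping torch.max for torch.min. Here is a small sketch under that assumption; it writes into separate variables (X_max and X_min, names chosen for illustration) so that the original X is not overwritten.

```python
import torch

X = torch.randn(2, 3, 4, 4)  # illustrative input

X_max = X
X_min = X
for dim in (2, 3):
    # With an integer dim, torch.max / torch.min return (values, indices);
    # index [0] keeps only the values. keepdim=True leaves a size-1 axis.
    X_max = torch.max(X_max, dim=dim, keepdim=True)[0]
    X_min = torch.min(X_min, dim=dim, keepdim=True)[0]

print(X_max.shape, X_min.shape)
```

Reducing one dimension at a time is valid here because the maximum (or minimum) of a block equals the maximum (or minimum) of its row-wise maxima (or minima).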
If you use torch>=1.11, please use the torch.amax function:
dim = (2,3)
x = torch.rand(2,3,4,4)
x_max = torch.amax(x,dim=dim)
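For the minima asked about in the question, torch.amin is the counterpart of torch.amax and also accepts a tuple for dim. The sketch below additionally passes keepdim=True, which is an addition to the snippet above, to obtain the (2,3,1,1) shape the question asks for; without it the result has shape (2,3).

```python
import torch

dim = (2, 3)
x = torch.rand(2, 3, 4, 4)

# keepdim=True keeps the reduced dimensions as size-1 axes.
x_max = torch.amax(x, dim=dim, keepdim=True)
x_min = torch.amin(x, dim=dim, keepdim=True)

print(x_max.shape, x_min.shape)
```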
However, if you use an older version of PyTorch, then please use this custom max function:
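def torch_max(x, dim):
    # Dimensions to keep (s1) and dimensions to reduce (s2).
    s1 = [i for i in range(len(x.shape)) if i not in dim]
    s2 = [i for i in range(len(x.shape)) if i in dim]
    # Move the reduced dimensions to the end.
    x2 = x.permute(tuple(s1 + s2))
    # Flatten the reduced dimensions into a single trailing axis.
    s = [d for (i, d) in enumerate(x.shape) if i not in dim] + [-1]
    x2 = torch.reshape(x2, tuple(s))
    # Take the maximum over that axis; the indices are discarded.
    values, _ = x2.max(-1)
    return values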
Usage of this function is very similar to the original torch.max function.
dim = (2,3)
x = torch.rand(2,3,4,4)
x_max = torch_max(x,dim=dim)
If dim contains many dimensions, then this custom torch_max is slightly faster.
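The same permute-and-flatten idea can presumably be applied to the minima. The function below, torch_min, is a hypothetical counterpart written for illustration and is not part of the original answer; its keepdim parameter is likewise an assumption, added so that the (2,3,1,1) shape from the question can be produced.

```python
import torch

def torch_min(x, dim, keepdim=False):
    """Minimum over several dimensions (hypothetical counterpart of torch_max)."""
    dim = tuple(d % x.dim() for d in dim)  # normalise negative indices
    kept = [i for i in range(x.dim()) if i not in dim]
    # Move the reduced dimensions to the end and flatten them into one axis.
    flat = x.permute(*kept, *dim).reshape(*[x.shape[i] for i in kept], -1)
    values, _ = flat.min(-1)
    if keepdim:
        # Reinsert the reduced dimensions as size-1 axes.
        values = values.reshape([1 if i in dim else s for i, s in enumerate(x.shape)])
    return values

x = torch.rand(2, 3, 4, 4)
x_min = torch_min(x, dim=(2, 3), keepdim=True)
print(x_min.shape)
```

Flattening the reduced dimensions first means only a single min call is issued regardless of how many dimensions are reduced.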