
Resize PyTorch Tensor

I am currently using the tensor.resize() function to resize a tensor to a new shape, t = t.resize(1, 2, 3).

This gives me a deprecation warning:

non-inplace resize is deprecated

Hence, I wanted to switch over to the tensor.resize_() function, which seems to be the appropriate in-place replacement. However, this leaves me with an

cannot resize variables that require grad

error. I can fall back to

from torch.autograd._functions import Resize
Resize.apply(t, (1, 2, 3))

which is what tensor.resize() does internally in order to avoid the deprecation warning. This seems like a hack rather than an appropriate solution to me. How do I correctly make use of tensor.resize_() in this case?

You can instead choose to go with tensor.reshape(new_shape) or torch.reshape(tensor, new_shape), as in:

# a `Variable` tensor
In [15]: ten = torch.randn(6, requires_grad=True)

# this would throw RuntimeError error
In [16]: ten.resize_(2, 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-094491c46baa> in <module>()
----> 1 ten.resize_(2, 3)

RuntimeError: cannot resize variables that require grad

The above RuntimeError can be resolved or avoided by using tensor.reshape(new_shape):

In [17]: ten.reshape(2, 3)
Out[17]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])

# yet another way of changing tensor shape
In [18]: torch.reshape(ten, (2, 3))
Out[18]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])

You can try something like:

import torch
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(":::",x.resize_(2, 2))
print("::::",x.resize_(3, 3))
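To make explicit what resize_ does here, a small sketch (note that when a tensor is grown past its old size, the new slots contain whatever happens to be in the newly allocated memory, so their values are arbitrary):

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])

y = x.resize_(2, 2)  # in place: y and x are the same object
print(y is x)        # True
print(x)             # keeps the first four elements: [[1, 2], [3, 4]]

x.resize_(3, 3)      # growing past the old size leaves the new slots uninitialized
print(x.shape)       # torch.Size([3, 3])
```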

Simply use t = t.contiguous().view(1, 2, 3) if you don't really want to change its data.

If that is not the case, the in-place resize_ operation will break the grad computation graph of t.
If that doesn't matter to you, just use t = t.data.resize_(1, 2, 3).
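A sketch of the first option: view shares storage with the original tensor (no data is copied) and stays connected to the autograd graph, which is exactly why it is the safe choice when gradients are needed.

```python
import torch

t = torch.randn(6, requires_grad=True)
v = t.contiguous().view(1, 2, 3)

# view shares the underlying storage with t: no data is copied,
# and the result remains part of the autograd graph.
print(v.data_ptr() == t.data_ptr())  # True
print(v.grad_fn is not None)         # True: gradients can flow through v
```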
