[英]How can I fold a Tensor that I unfolded with PyTorch that has overlap?
I have a Tensor of size: torch.Size([1, 63840])
which I then unrolled:我有一个大小为:
torch.Size([1, 63840])
的张量,然后展开:
inp_unfolded = inp_seq.unfold(1, 160, 80)
that gives me a shape of: torch.Size([1, 797, 160])
这给了我一个形状:
torch.Size([1, 797, 160])
How can I re- fold
that to get a Tensor of torch.Size([1, 63840])
?我怎样才能重新
fold
它以获得torch.Size([1, 63840])
的张量?
For that specific configuration, since 63840
is divisible by 160
and the step size is a multiple of the slice size, you can simply select every second element along that dimension and then flatten
the resulting tensor:对于该特定配置,由于
63840
可以被160
整除,并且步长是切片大小的倍数,因此您可以简单地 select 沿该维度的每个第二个元素,然后flatten
生成的张量:
inp_unfolded[:, ::2, :].flatten(1, 2)
More generally, for t.unfold(i, n, s)
, if t.shape[i] % n == 0 and n % s == 0
holds, then you can restore the original tensor via:更一般地说,对于
t.unfold(i, n, s)
,如果t.shape[i] % n == 0 and n % s == 0
成立,那么您可以通过以下方式恢复原始张量:
index = [slice(None) for __ in t.shape]
index[i] = slice(None, None, n // s)
original = t.unfold(i, n, s)[tuple(index)].flatten(i, i+1)
Of course you can also use slice notation, if the dimension i
is known beforehand.当然,如果事先知道维度
i
,您也可以使用切片表示法。 For example i == 1
as in your example:例如
i == 1
如您的示例所示:
original = t.unfold(1, n, s)[:, ::n//s, ...].flatten(1, 2)
Well, actually the conditions, given t.unfold(i, n, s)
are:好吧,实际上给定
t.unfold(i, n, s)
的条件是:
n >= s
(otherwise step is skipping some original data and we cannot restore it) n >= s
(否则步骤会跳过一些原始数据,我们无法恢复它)n + s <= t.shape[i]
Then we can do it via:然后我们可以通过:
def roll(x, n, s, axis=1):
return torch.cat((p[0], p[1:][:, n-s:].flatten()), axis)
explanation:解释:
p[0]
is the starting chunk that is always unique at start p[0]
是起始块,在开始时总是唯一的
p[1:][:, ns:]
- then, we take rest of rolls and ns
depict how many elements will overlap between rolls so we want to ignore them and take only those from ns
p[1:][:, ns:]
- 然后,我们取卷的 rest 和ns
描述卷之间有多少元素重叠,所以我们想忽略它们,只取来自ns
的元素
ilustration:插图:
x.unfold(0, 5, 2)
tensor([[ 1., 2., 3., 4., 5.],
[ 3., 4., 5., 6., 7.], # 3, 4, 5 are repeated
[ 5., 6., 7., 8., 9.], # 5, 6, 7 are repeated...
[ 7., 8., 9., 10., 11.],
[ 9., 10., 11., 12., 13.],
[11., 12., 13., 14., 15.],
[13., 14., 15., 16., 17.]])
example:例子:
>> x = torch.arange(1., 18)
>> p = x.unfold(0, 5, 2)
>> roll(p, 5, 2, 0)
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16., 17.])
you can also try it with你也可以试试
x = torch.arange(1., 18).reshape(1, 17)
and axis 1和轴 1
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.