简体   繁体   English

如何使用滑动窗口调整 PyTorch 张量的大小?

[英]How can I resize a PyTorch tensor with a sliding window?

I have a tensor with size: torch.Size([118160, 1]) .我有一个大小为: torch.Size([118160, 1])的张量。 What I want to do is split it up into n tensors with 100 elements each, sliding by 50 elements at a time.我想要做的是将它分成 n 个张量,每个张量有 100 个元素,一次滑动 50 个元素。 What's the best way to achieve this with PyTorch?使用 PyTorch 实现这一目标的最佳方法是什么?

Your number of elements needs to be divisble by 100. If this is not the case you can adjust with padding.您的元素数量需要被 100 整除。如果不是这种情况,您可以使用填充进行调整。

You can first do a split on the original list.您可以先对原始列表进行拆分。 Then do a split on the list where the first 50 elements are removed from the original list.然后对列表进行拆分,从原始列表中删除前 50 个元素。 You can then sample alternating order from A and B if you want to preserve original order.如果您想保留原始顺序,则可以从 A 和 B 中采样交替顺序。

A = yourtensor
B = yourtensor[50:] + torch.zeros(50,1)
A_ = A.view(100,-1)
B_ = B.view(100,-1)

A possible solution is:一个可能的解决方案是:

window_size = 100
stride = 50
splits = [x[i:min(x.size(0),i+window_size)] for i in range(0,x.size(0),stride)]

However, the last few elements will be shorter than window_size .但是,最后几个元素将比window_size短。 If this is undesired, you can do:如果这是不希望的,您可以执行以下操作:

splits = [x[i:i+window_size] for i in range(0,x.size(0)-window_size+1,stride)]

EDIT:编辑:

A more readable solution:一个更易读的解决方案:

# if keep_short_tails is set to True, the slices shorter than window_size at the end of the result will be kept 
def window_split(x, window_size=100, stride=50, keep_short_tails=True):
  length = x.size(0)
  splits = []

  if keep_short_tails:
    for slice_start in range(0, length, stride):
      slice_end = min(length, slice_start + window_size)
      splits.append(x[slice_start:slice_end])
  else:
    for slice_start in range(0, length - window_size + 1, stride):
      slice_end = slice_start + window_size
      splits.append(x[slice_start:slice_end])

  return splits

You can use Pytorch's unfold API.您可以使用 Pytorch 的展开 API。 Refer this https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html请参阅此https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html

Example:例子:

x = torch.arange(1., 20)
x.unfold(0,4,2)

tensor([[ 1.,  2.,  3.,  4.],  
        [ 3.,  4.,  5.,  6.],  
        [ 5.,  6.,  7.,  8.],  
        [ 7.,  8.,  9., 10.],  
        [ 9., 10., 11., 12.],  
        [11., 12., 13., 14.],  
        [13., 14., 15., 16.],  
        [15., 16., 17., 18.]])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM