
How can I resize a PyTorch tensor with a sliding window?

I have a tensor of size torch.Size([118160, 1]). I want to split it into n tensors of 100 elements each, sliding the window by 50 elements at a time. What's the best way to do this in PyTorch?

For this approach, the number of elements needs to be divisible by 100 (118160 is not). If it isn't, pad the tensor first.

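First split the original tensor into consecutive chunks of 100 elements (A). Then split a copy with the first 50 elements removed, padded at the end, into chunks of 100 (B). The chunks of A start at elements 0, 100, 200, ... and those of B at 50, 150, 250, ..., so if you want the windows in their original order, interleave the chunks of A and B.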

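import torch

A = yourtensor  # shape (N, 1), N divisible by 100
B = torch.cat([yourtensor[50:], yourtensor.new_zeros(50, 1)])  # drop first 50, pad the end back to N
A_ = A.view(-1, 100)  # rows start at 0, 100, 200, ...
B_ = B.view(-1, 100)  # rows start at 50, 150, 250, ...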
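A runnable sketch of this two-split idea, assuming zero padding is acceptable at the tail and that the stride is exactly half of an even window size. The helper name split_interleave is hypothetical: it pads the flattened tensor to a multiple of the window, builds both splits, and interleaves their rows so the windows come out in order.

```python
import torch

def split_interleave(x: torch.Tensor, window: int = 100) -> torch.Tensor:
    # Hypothetical helper: assumes stride == window // 2 (window even) and zero padding at the tail.
    half = window // 2
    flat = x.reshape(-1)
    # Pad with zeros so the length is a multiple of `window`.
    pad = (-flat.numel()) % window
    a = torch.cat([flat, flat.new_zeros(pad)])
    # Drop the first `half` elements, then pad back to the same length.
    b = torch.cat([a[half:], a.new_zeros(half)])
    a_rows = a.view(-1, window)  # windows starting at 0, window, 2 * window, ...
    b_rows = b.view(-1, window)  # windows starting at half, half + window, ...
    # Interleave rows: a_rows[0], b_rows[0], a_rows[1], b_rows[1], ...
    return torch.stack([a_rows, b_rows], dim=1).reshape(-1, window)
```

For the question's 118160 elements this pads with 40 zeros and returns 2364 rows of 100. The last row starts at element 118150, so it holds only the final 10 real values followed by zeros; drop it with out[:-1] if that tail is unwanted.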

A possible solution is:

import torch

x = torch.randn(118160, 1)  # stand-in for your tensor
window_size = 100
stride = 50
splits = [x[i:min(x.size(0), i + window_size)] for i in range(0, x.size(0), stride)]
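To see what this comprehension produces for a tensor of the question's shape, here is a quick check on random placeholder data (x is a stand-in, not the asker's tensor). Python slicing already stops at the end of the tensor, so the min(...) guard is optional.

```python
import torch

x = torch.randn(118160, 1)  # placeholder with the question's shape
window_size, stride = 100, 50

# Slicing past the end is clamped automatically, so no min() is needed here.
splits = [x[i:i + window_size] for i in range(0, x.size(0), stride)]

print(len(splits))                        # 2364
print([s.size(0) for s in splits[-3:]])   # [100, 60, 10]
```

Only the last two slices are short: the windows starting at 118100 and 118150 run past the end of the tensor.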

However, the last few slices will be shorter than window_size. If you don't want them, use:

splits = [x[i:i+window_size] for i in range(0,x.size(0)-window_size+1,stride)]
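When every slice has the full window_size, as with this comprehension, the list can be stacked into a single batch tensor. A sketch with a placeholder x of the question's shape; note that torch.stack copies the data rather than returning a view.

```python
import torch

x = torch.randn(118160, 1)  # placeholder with the question's shape
window_size, stride = 100, 50

splits = [x[i:i + window_size] for i in range(0, x.size(0) - window_size + 1, stride)]

# All slices are (window_size, 1), so they stack into (num_windows, window_size, 1).
batch = torch.stack(splits)
print(batch.shape)  # torch.Size([2362, 100, 1])
```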

EDIT:

A more readable solution:

# if keep_short_tails is set to True, the slices shorter than window_size at the end of the result will be kept 
def window_split(x, window_size=100, stride=50, keep_short_tails=True):
  length = x.size(0)
  splits = []

  if keep_short_tails:
    for slice_start in range(0, length, stride):
      slice_end = min(length, slice_start + window_size)
      splits.append(x[slice_start:slice_end])
  else:
    for slice_start in range(0, length - window_size + 1, stride):
      slice_end = slice_start + window_size
      splits.append(x[slice_start:slice_end])

  return splits
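If you want the keep_short_tails=True behaviour as a single tensor instead of a list, one option is to zero-pad the end so every window is full and then take all windows at once with Tensor.unfold. This is a sketch: the name padded_windows, the zero padding, and the requirement window_size >= stride are assumptions of this example, not part of the answer above.

```python
import torch

def padded_windows(x, window_size=100, stride=50):
    """Hypothetical helper: same window starts as keep_short_tails=True,
    but short tails are zero-padded to window_size. Returns (num_windows, window_size)."""
    assert window_size >= stride, "this sketch assumes overlapping or touching windows"
    flat = x.reshape(-1)
    length = flat.numel()
    # One window per start 0, stride, 2 * stride, ... below length (ceiling division).
    num_windows = (length + stride - 1) // stride
    # Pad so the last window, starting at (num_windows - 1) * stride, is complete.
    pad = (num_windows - 1) * stride + window_size - length
    flat = torch.cat([flat, flat.new_zeros(pad)])
    # unfold(dimension, size, step) returns every window as one strided view.
    return flat.unfold(0, window_size, stride)
```

For example, padded_windows(torch.arange(1., 8.), window_size=4, stride=2) gives the rows [1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 0] and [7, 0, 0, 0].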

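You can use PyTorch's Tensor.unfold(dimension, size, step), which returns every slice of length size along dimension, taken step elements apart. See https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html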

Example:

import torch

x = torch.arange(1., 20)
x.unfold(0, 4, 2)

tensor([[ 1.,  2.,  3.,  4.],  
        [ 3.,  4.,  5.,  6.],  
        [ 5.,  6.,  7.,  8.],  
        [ 7.,  8.,  9., 10.],  
        [ 9., 10., 11., 12.],  
        [11., 12., 13., 14.],  
        [13., 14., 15., 16.],  
        [15., 16., 17., 18.]])
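Applied to the question's tensor of shape (118160, 1), unfold along dimension 0 appends the window as a new last dimension, so the result is (num_windows, 1, 100) rather than (num_windows, 100). A short sketch with placeholder data showing how to drop or move that extra dimension:

```python
import torch

x = torch.randn(118160, 1)  # placeholder with the question's shape

windows = x.unfold(0, 100, 50)        # shape (2362, 1, 100); a view, no copy
print(windows.shape)

per_window = windows.squeeze(1)       # (2362, 100): one row per window
as_columns = windows.transpose(1, 2)  # (2362, 100, 1): keeps the original trailing dim
```

Because unfold returns a view that shares memory with x, call .clone() if you need to modify the windows independently. Like the second list comprehension above, it skips elements that do not fill a whole window: here the last 10 elements are not in any window.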
