简体   繁体   English

将 PyTorch 张量拆分为重叠的块

[英]Split PyTorch tensor into overlapping chunks

Given a batch of images of shape (batch, c, h, w), I want to reshape it into (-1, depth, c, h, w) such that the i-th "chunk" of size d contains frames i -> i+d.给定一批形状为 (batch, c, h, w) 的图像,我想将其重塑为 (-1, depth, c, h, w) 以使大小为 d 的第 i 个“块”包含帧 i -> i+d。 Basically, using .view(-1, d, c, h, w) would reshape the tensor into d-size chunks where the index of the first image would be a multiple of d, which isnt what I want.基本上,使用 .view(-1, d, c, h, w) 会将张量重塑为 d 大小的块,其中第一张图像的索引将是 d 的倍数,这不是我想要的。

Scalar example:标量示例:

if the original tensor is something like:如果原始张量类似于:

[1,2,3,4,5,6,7,8,9,10,11,12] and d is 2; 

view() would return : [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]]; view()将返回: [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]];

however, I want to get:但是,我想得到:

[[1,2],[2,3],[3,4],[4,5],[5,6],[6,7],[7,8],[8,9],[9,10],[10,11],[11,12]]

I wrote this function to do so:我写了这个函数来做到这一点:

def chunk_slicing(data, depth):
    output = []
    for i in range(data.shape[0] - depth+1):
        temp = data[i:i+depth]
        output.append(temp)
    return torch.Tensor(np.array([t.numpy() for t in output]))

However I need a function that is useable as part of a PyTorch model as this function causes this error :但是,我需要一个可用作 PyTorch 模型一部分的函数,因为此函数会导致此错误:

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

IIUC, You need torch.Tensor.unfold . IIUC,你需要torch.Tensor.unfold

import torch
x = torch.arange(1, 13)
x.unfold(dimension = 0,size = 2, step = 1)

tensor([[ 1,  2],
        [ 2,  3],
        [ 3,  4],
        [ 4,  5],
        [ 5,  6],
        [ 6,  7],
        [ 7,  8],
        [ 8,  9],
        [ 9, 10],
        [10, 11],
        [11, 12]])

Another example with size = 3 and step = 2 .另一个size = 3step = 2的例子。

>>> torch.arange(1, 10).unfold(dimension = 0,size = 3, step = 2)

tensor([[1, 2, 3],  # window with size = 3
# step : ---1--2---
        [3, 4, 5],  # 'step = 2' so start from 3
        [5, 6, 7],
        [7, 8, 9]])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM