
Extracting 3D patches from 3D images (both overlapping and non-overlapping) and recovering the image back

I am working with 3D images of shape 172x220x156. To feed an image into the network, I need to extract patches of size 32x32x32 from it and then add them back together to reassemble the image. Since my image dimensions are not multiples of the patch size, I have to use overlapping patches. I want to know how to do that.

I am working in PyTorch. There are some options like unfold and fold, but I am not sure how they work.

Is all your data exactly 172x220x156? If so, it seems like you could just use a for loop and index into the tensor to get 32x32x32 blocks, correct? (Possibly hardcoding a few things.)

However, I'm not able to fully answer the question because it's not clear how you want to combine the results. To be clear, is this your goal?

1) get a 32x32x32 patch from an image
2) do some arbitrary processing on it
3) save that patch to some result at the correct index
4) repeat

If so, how do you plan on combining the overlapping patches? Sum them? Average them?

In any case, here is the indexing:

out_tensor = torch.zeros_like(image)               # image is the 172x220x156 input tensor
for i_idx in [0, 32, 64, 96, 128, 140]:            # start indices along dim 0 (size 172)
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:   # start indices along dim 1 (size 220)
        for k_idx in [0, 32, 64, 96, 124]:         # start indices along dim 2 (size 156)
            patch = image[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32]
            output = your_model(patch)
            out_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] = output

this isn't optimized at all, but I imagine the bulk of the computation will be the actual neural network, and there's no way around that, so optimization might be pointless.
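
If the plan is to average the overlapping regions rather than let later patches overwrite earlier ones, here is a minimal sketch along the same lines (image and your_model are placeholders for your input tensor and network; count tracks how many patches touch each voxel):

out_tensor = torch.zeros_like(image)
count = torch.zeros_like(image)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            sl = (slice(i_idx, i_idx + 32), slice(j_idx, j_idx + 32), slice(k_idx, k_idx + 32))
            out_tensor[sl] += your_model(image[sl])
            count[sl] += 1
out_tensor /= count   # average wherever patches overlap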

You can use unfold (see the PyTorch docs):

import torch

# here the three spatial dimensions (172, 220, 156) occupy dims 1-3 of a 4D tensor
batch_size, n_channels, n_rows, n_cols = 1, 172, 220, 156
x = torch.arange(batch_size * n_channels * n_rows * n_cols).view(batch_size, n_channels, n_rows, n_cols)

kernel_c, kernel_h, kernel_w = 32, 32, 32
step = 32

# Tensor.unfold(dimension, size, step)
windows_unpacked = x.unfold(1, kernel_c, step).unfold(2, kernel_h, step).unfold(3, kernel_w, step)
print(windows_unpacked.shape)
# result: torch.Size([1, 5, 6, 4, 32, 32, 32])

windows = windows_unpacked.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([120, 32, 32, 32])
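
A hedged sketch of the inverse for this non-overlapping case (assuming step equals the kernel size): since Tensor.unfold drops the remainder along each dimension, only the cropped 160x192x128 region can be rebuilt this way, not the full 172x220x156 volume.

recon = (windows
         .view(5, 6, 4, batch_size, kernel_c, kernel_h, kernel_w)  # undo the flatten: (blocks_d, blocks_h, blocks_w, B, 32, 32, 32)
         .permute(3, 0, 4, 1, 5, 2, 6)                             # interleave block indices with in-block offsets
         .reshape(batch_size, 5 * kernel_c, 6 * kernel_h, 4 * kernel_w))
print(torch.equal(recon, x[:, :160, :192, :128]))
# result: True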

To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only process 4D tensors (batches of 2D images), but you can apply them one dimension at a time.
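
As a minimal 2D illustration of that pairing (a small sketch on a 4x4 example, not taken from the answer below): fold sums the overlapping elements, which is why the divide-by-ones trick in note 7 is needed to recover the input.

import torch

x = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)                            # (B, C, H, W)
cols = torch.nn.functional.unfold(x, kernel_size=2, stride=1)                       # (1, 1*2*2, 9) sliding 2x2 blocks
back = torch.nn.functional.fold(cols, output_size=(4, 4), kernel_size=2, stride=1)  # overlaps are summed
ones = torch.nn.functional.fold(torch.ones_like(cols), output_size=(4, 4), kernel_size=2, stride=1)
print(torch.equal(back / ones, x))
# result: True, dividing by the counts undoes the summation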

A few notes:

  1. This approach requires the fold/unfold methods from PyTorch; unfortunately, I have yet to find a similar method in the TF API.

  2. We can extract patches in two ways with the same output. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions (3 here). The latter uses torch.Tensor.unfold() and has fewer lines of code. (The output is the same, except that it cannot use dilation.)

  3. The methods extract_patches_Xd and combine_patches_Xd are inverses: the combiner reverses the extractor's steps one by one.

  4. Each line of code is followed by a comment stating the dimensionality, such as (B, C, D, H, W). The following symbols are used:

    1. B : Batch size
    2. C : Channels
    3. D : Depth Dimension
    4. H : Height Dimension
    5. W : Width Dimension
    6. x_dim_in : In the extraction method, this is the number of input pixels in dimension x. In the combining method, this is the number of sliding windows in dimension x.
    7. x_dim_out : In the extraction method, this is the number of sliding windows in dimension x. In the combining method, this is the number of output pixels in dimension x.
  5. I have a public notebook to try out the code

  6. The get_dim_blocks() method is the function given in the PyTorch docs to compute the output shape of a convolutional layer (a small worked example follows this list).

  7. Note that if you have overlapping patches and you combine them, the overlapping elements will be summed. If you would like to get the initial input back, there is a way:

    1. Create a tensor of ones with the same shape as the patches using torch.ones_like(patches_tensor).
    2. Combine the ones into a full image with the same output shape (this creates a counter for the overlapping elements).
    3. Divide the combined image by the combined ones; this reverses any double summation of elements.

(3D): We need two passes of fold/unfold. In the extraction we first unfold the D dimension and leave H and W untouched by setting that kernel component to 1, padding to 0, stride to 1 and dilation to 1. We then reshape the tensor and unfold over the H and W dimensions. The folding in the combiner happens in reverse, starting with H and W, then D.
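
Before the code, a quick illustration of what get_dim_blocks() computes for the dimensions in the question (kernel 32 with stride 16 and padding 2 is just one choice that fits all three axes; it is an assumption, not something prescribed by the question or this answer):

def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    # standard convolution output-size formula from the PyTorch docs
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1

print(get_dim_blocks(172, 32, dim_padding=2, dim_stride=16))
# result: 10 sliding windows along D
print(get_dim_blocks(220, 32, dim_padding=2, dim_stride=16))
# result: 13 sliding windows along H
print(get_dim_blocks(156, 32, dim_padding=2, dim_stride=16))
# result: 9 sliding windows along W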
import torch


# shorter version built on torch.Tensor.unfold (no dilation support)
def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x

# general version built on torch.nn.functional.unfold, applied one dimension at a time (supports dilation)
def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]

    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
    
    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)                                                     
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))                   
    # (B, C * kernel_size[0], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)                                   
    # (B, C * kernel_size[0] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))        
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out, w_dim_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  

    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    return x



# inverse of extract_patches_3d: folds the patches back into the full volume, summing overlapping elements
def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)

    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)

    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)

    return x

a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)  # same patches via the shorter Tensor.unfold version
print(b.shape)
print(b)
c = combine_patches_3d(b, kernel_size=2, output_shape=(2,2,2,4,4), padding=1, stride=1)
print(c.shape)
print(c)
ones = torch.ones_like(b)
ones = combine_patches_3d(ones, kernel_size=2, output_shape=(2,2,2,4,4), padding=1, stride=1)
print(torch.all(a==c))
print(c.shape, ones.shape)
d = c / ones
print(d)
print(torch.all(a==d))

Output (3D)

torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   1.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  1.,   2.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  2.,   3.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  3.,   4.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  4.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   1.],
           [  0.,   5.]]]],



        ...,



        [[[[124.,   0.],
           [128.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0., 125.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[125., 126.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[126., 127.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[127., 128.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[128.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[   8.,   16.,   24.,   32.],
           [  40.,   48.,   56.,   64.],
           [  72.,   80.,   88.,   96.],
           [ 104.,  112.,  120.,  128.]],

          [[ 136.,  144.,  152.,  160.],
           [ 168.,  176.,  184.,  192.],
           [ 200.,  208.,  216.,  224.],
           [ 232.,  240.,  248.,  256.]]],


         [[[ 264.,  272.,  280.,  288.],
           [ 296.,  304.,  312.,  320.],
           [ 328.,  336.,  344.,  352.],
           [ 360.,  368.,  376.,  384.]],

          [[ 392.,  400.,  408.,  416.],
           [ 424.,  432.,  440.,  448.],
           [ 456.,  464.,  472.,  480.],
           [ 488.,  496.,  504.,  512.]]]],



        [[[[ 520.,  528.,  536.,  544.],
           [ 552.,  560.,  568.,  576.],
           [ 584.,  592.,  600.,  608.],
           [ 616.,  624.,  632.,  640.]],

          [[ 648.,  656.,  664.,  672.],
           [ 680.,  688.,  696.,  704.],
           [ 712.,  720.,  728.,  736.],
           [ 744.,  752.,  760.,  768.]]],


         [[[ 776.,  784.,  792.,  800.],
           [ 808.,  816.,  824.,  832.],
           [ 840.,  848.,  856.,  864.],
           [ 872.,  880.,  888.,  896.]],

          [[ 904.,  912.,  920.,  928.],
           [ 936.,  944.,  952.,  960.],
           [ 968.,  976.,  984.,  992.],
           [1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
tensor(True)
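
As one way to apply these functions to the 172x220x156 volumes from the question (a sketch, not part of the answer above: kernel 32 with stride 16 and padding 2 is just one combination that satisfies the output-size formula on all three axes, and your_model stands in for the actual network):

image = torch.randn(1, 1, 172, 220, 156)                     # (B, C, D, H, W)

patches = extract_patches_3d(image, kernel_size=32, padding=2, stride=16)
print(patches.shape)
# result: torch.Size([1170, 1, 32, 32, 32])  -> 10 * 13 * 9 windows

processed = patches                                          # run your_model over mini-batches of patches here

summed = combine_patches_3d(processed, kernel_size=32, output_shape=(1, 1, 172, 220, 156), padding=2, stride=16)
counts = combine_patches_3d(torch.ones_like(processed), kernel_size=32, output_shape=(1, 1, 172, 220, 156), padding=2, stride=16)
recovered = summed / counts                                  # average the overlapping contributions
print(torch.allclose(recovered, image))
# result: True (when processed == patches)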
