
What is the fastest way to use Pytorch Conv2d to apply multiple filters to every layer in multiple CT scans?

Assume an input data set contains CT scans of 100 patients, each scan containing 16 layers and each layer containing 512 x 512 pixels. I want to apply eight 3x3 convolution filters to each layer of every CT scan. So, the input array has shape [100, 16, 512, 512] and the kernels array has shape [8, 3, 3]. After the convolutions are applied, the goal is an output array with shape [100, 16, 8, 512, 512]. The following code uses the PyTorch conv2d function to achieve this; however, I want to know whether the groups parameter (and/or other means) can somehow eliminate the need for the loop.

for layer_index in range(0, number_of_layers):
    # Getting the current CT scan layer for all patients
    # ct_scans dimensions are: [patient, scan layer, pixel row, pixel column]
    # ct_scans shape: [100, 16, 512, 512]
    image_stack = ct_scans[:, layer_index, :, :]
    # Converting from numpy to tensor format, adding a channel dimension
    image_stack_t = torch.from_numpy(image_stack[:, None, :, :])
    # Applying convolution to create 8 filtered versions of the current scan layer across all patients
    # conv2d expects a 4D weight tensor, so kernels must be shaped [8, 1, 3, 3]
    filtered_image_stack_t = conv2d(image_stack_t, kernels, padding=1, groups=1)
    # Converting from tensor format back to numpy format
    filtered_image_stack = filtered_image_stack_t.numpy()
    # Amassing filtered CT scans for all patients back into one array
    # filtered_ct_scans dimensions are: [patient, scan layer, filter number, pixel row, pixel column]
    # filtered_ct_scans shape is: [100, 16, 8, 512, 512]
    filtered_ct_scans[:, layer_index, :, :, :] = filtered_image_stack

So far, my attempts to use anything other than groups=1 have led to errors. I also found the following similar posts; however, they don't address my specific question.

How to use groups parameter in PyTorch conv2d function with batch?

How to use groups parameter in PyTorch conv2d function

You do not need to use grouped convolutions. Reshaping your input appropriately is all that is needed.

import torch
import torch.nn.functional as F

ct_scans = torch.randn((100, 16, 512, 512))
kernels = torch.randn((8, 1, 3, 3))

B, L, H, W = ct_scans.shape  # (batch, layers, height, width)
ct_scans = ct_scans.view(-1, H, W)  # fold the layers into the batch dimension
ct_scans.unsqueeze_(1)              # add a channel dimension: [B*L, 1, H, W]
out = F.conv2d(ct_scans, kernels, padding=1)  # padding=1 keeps the 512x512 size
out = out.view(B, L, *out.shape[1:])          # [100, 16, 8, 512, 512]
print(out.shape)
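To see that the reshape approach computes exactly what the original loop does, here is a minimal sketch on small hypothetical dimensions (2 patients, 3 layers, 8x8 pixels stand in for the real sizes); it runs both versions and checks that the results match:

```python
import torch
import torch.nn.functional as F

# Small stand-in sizes for a quick check (2 patients, 3 layers, 8x8 pixels).
ct_scans = torch.randn((2, 3, 8, 8))
kernels = torch.randn((8, 1, 3, 3))

B, L, H, W = ct_scans.shape

# Batched approach: fold the layers into the batch dimension, convolve once.
batched = F.conv2d(ct_scans.reshape(B * L, 1, H, W), kernels, padding=1)
batched = batched.view(B, L, 8, H, W)

# Loop approach: convolve one layer at a time, as in the question.
looped = torch.empty(B, L, 8, H, W)
for layer_index in range(L):
    image_stack = ct_scans[:, layer_index, :, :].unsqueeze(1)  # [B, 1, H, W]
    looped[:, layer_index] = F.conv2d(image_stack, kernels, padding=1)

print(torch.allclose(batched, looped, atol=1e-5))
```

Because conv2d treats every item in the batch dimension independently, folding the 16 layers into the batch changes nothing about the per-layer result; it only lets a single kernel launch process all patient-layer images at once.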
