
Getting Bad Images After Data Augmentation in PyTorch

I'm working on a nuclei segmentation problem where I'm trying to identify the positions of nuclei in images of stained tissue. The training dataset consists of images of stained tissue, each with a mask marking the nuclei positions. Since the dataset is small, I wanted to try data augmentation in PyTorch. After doing that, when I display a sample, the mask looks fine but the corresponding tissue image is wrong.

All my training images are in X_train with shape (128, 128, 3), the corresponding masks are in Y_train with shape (128, 128, 1), and the validation images and masks are in X_val and Y_val, respectively.

Y_train and Y_val have dtype np.bool, and X_train and X_val have dtype np.uint8.

Before data augmentation, I check my images like this:

fig, axis = plt.subplots(2, 2)
axis[0][0].imshow(X_train[0].astype(np.uint8))
axis[0][1].imshow(np.squeeze(Y_train[0]).astype(np.uint8))
axis[1][0].imshow(X_val[0].astype(np.uint8))
axis[1][1].imshow(np.squeeze(Y_val[0]).astype(np.uint8))

The output is as follows: [image: Before Data Augmentation]

For the data augmentation, I define a custom class as follows:

Here I have imported torchvision.transforms.functional as TF and torchvision.transforms as transforms. images_np and masks_np are the inputs, which are NumPy arrays.

import random

import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset


class Nuc_Seg(Dataset):
    def __init__(self, images_np, masks_np):
        self.images_np = images_np
        self.masks_np = masks_np

    def transform(self, image_np, mask_np):
        # Convert the NumPy arrays to PIL images so the torchvision functional ops accept them
        ToPILImage = transforms.ToPILImage()
        image = ToPILImage(image_np)
        mask = ToPILImage(mask_np.astype(np.int32))

        # Draw one set of random affine parameters and apply it to both image and mask
        angle = random.uniform(-10, 10)
        width, height = image.size
        max_dx = 0.2 * width
        max_dy = 0.2 * height
        translations = (np.round(random.uniform(-max_dx, max_dx)), np.round(random.uniform(-max_dy, max_dy)))
        scale = random.uniform(0.8, 1.2)
        shear = random.uniform(-0.5, 0.5)
        image = TF.affine(image, angle=angle, translate=translations, scale=scale, shear=shear)
        mask = TF.affine(mask, angle=angle, translate=translations, scale=scale, shear=shear)

        # to_tensor returns [C, H, W] tensors; uint8 images are rescaled to floats in [0, 1]
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask

    def __len__(self):
        return len(self.images_np)

    def __getitem__(self, idx):
        image_np = self.images_np[idx]
        mask_np = self.masks_np[idx]
        image, mask = self.transform(image_np, mask_np)
        return image, mask

This is followed by:

from torch.utils.data import DataLoader

train_dataset = Nuc_Seg(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataset = Nuc_Seg(X_val, Y_val)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)

After this step, I check the first training image and mask of a batch like this:

%matplotlib inline

for ex_img, ex_mask in train_loader:

    img = ex_img[0]
    img = img.reshape(128, 128, 3)
    mask = ex_mask[0]
    mask = mask.reshape(128, 128)

    img = img.numpy()
    mask = mask.numpy()

    fig, (axis_1, axis_2) = plt.subplots(1, 2)
    axis_1.imshow(img.astype(np.uint8))
    axis_2.imshow(mask.astype(np.uint8))

    break

I get this as my output: [image: After Data Augmentation 1]

When I change the line axis_1.imshow(img.astype(np.uint8)) to axis_1.imshow(img), I get this image: [image: After Data Augmentation 2]

The mask images are correct, but for some reason the tissue images are wrong. With .astype(np.uint8), the tissue image is completely black.

Without .astype(np.uint8), the positions of the nuclei are correct, but the colours are all wrong (I expect images like those before data augmentation, either greyish or pinkish), and 9 copies of the same image are displayed in a 3 × 3 grid. How can I get the correct output for the tissue images?

You are converting the images to PyTorch tensors, and PyTorch image tensors have size [C, H, W]. When visualising them, you convert the tensors back to NumPy arrays, and matplotlib expects images of size [H, W, C]. So you do need to rearrange the dimensions, but you are using reshape, which doesn't swap the dimensions; it only regroups the same sequence of elements into a different shape.

An example makes this clearer:

import torch

# Incrementing numbers with size 2 x 3 x 3
image = torch.arange(2 * 3 * 3).reshape(2, 3, 3)
# => tensor([[[ 0,  1,  2],
#             [ 3,  4,  5],
#             [ 6,  7,  8]],
#
#            [[ 9, 10, 11],
#             [12, 13, 14],
#             [15, 16, 17]]])

# Reshape keeps the same order of elements but for a different size
# The numbers are still incrementing from left to right
image.reshape(3, 3, 2)
# => tensor([[[ 0,  1],
#             [ 2,  3],
#             [ 4,  5]],
#
#            [[ 6,  7],
#             [ 8,  9],
#             [10, 11]],
#
#            [[12, 13],
#             [14, 15],
#             [16, 17]]])

To reorder the dimensions, you can use permute:

# Dimensions are swapped
# Now the numbers increment from top to bottom
image.permute(1, 2, 0)
# => tensor([[[ 0,  9],
#             [ 1, 10],
#             [ 2, 11]],
#
#            [[ 3, 12],
#             [ 4, 13],
#             [ 5, 14]],
#
#            [[ 6, 15],
#             [ 7, 16],
#             [ 8, 17]]])
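Applied to a whole image, this is where the grid of copies comes from. A minimal sketch on a small synthetic [C, H, W] tensor (3 × 6 × 6, chosen only to keep the numbers readable): with reshape, each "pixel" holds three neighbouring values of a single channel, which scrambles the colours; each output row contains three consecutive rows of the original squeezed into a third of the width; and the three channels end up stacked vertically. Together that gives the 3 × 3 tiling seen in the question.

```python
import torch

C, H, W = 3, 6, 6
# Incrementing values: channel 0 holds 0..35, channel 1 holds 36..71, channel 2 holds 72..107
chw = torch.arange(C * H * W).reshape(C, H, W)

wrong = chw.reshape(H, W, C)   # same element order, just regrouped
right = chw.permute(1, 2, 0)   # channels actually moved to the last axis

# permute: pixel (0, 0) holds the three channel values of that pixel
print(right[0, 0].tolist())    # [0, 36, 72]
# reshape: "pixel" (0, 0) holds three neighbouring values of channel 0
print(wrong[0, 0].tolist())    # [0, 1, 2]
# reshape: channel each output row comes from, i.e. channels stacked vertically
print((wrong[:, 0, 0] // (H * W)).tolist())  # [0, 0, 1, 1, 2, 2]
```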

As for the completely black image when using .astype(np.uint8):

TF.to_tensor turns a uint8 image with values in [0, 255] into a float tensor with values in [0, 1], whereas your original NumPy images are uint8 with values in [0, 255]. Casting the float values to np.uint8 truncates them, so everything below 1.0 becomes 0 and only values of exactly 1.0 become 1. On a [0, 255] scale that is essentially black, so the whole image looks black.

You need to multiply the values by 255 to bring them back into the range [0, 255] (and permute the dimensions as described above):
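The truncation is easy to reproduce with a few float intensities in [0, 1] (made-up values, not taken from the tissue images):

```python
import numpy as np

# Intensities as TF.to_tensor would produce them: floats in [0, 1]
values = np.array([0.0, 0.25, 0.5, 0.99, 1.0], dtype=np.float32)

# Casting truncates toward zero, so everything below 1.0 becomes 0
as_uint8 = values.astype(np.uint8)
print(as_uint8.tolist())  # [0, 0, 0, 0, 1]
```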

img = img.permute(1, 2, 0) * 255
img = img.numpy().astype(np.uint8)

This conversion is also done automatically when you convert a tensor to a PIL image with transforms.ToPILImage (or TF.to_pil_image, if you prefer the functional version), and a PIL image can be converted directly to a NumPy array. That way you don't have to worry about the dimensions, value range, or dtype, and the code above can be replaced with:

img = np.array(TF.to_pil_image(img))
