
PyTorch is tiling images when loaded with DataLoader

I am trying to load an image dataset using the PyTorch DataLoader, but the transformed images come out tiled instead of being center-cropped as I expect.

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

dataset = datasets.ImageFolder('ml-models/downloads/', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)


import matplotlib.pyplot as plt

images, labels = next(iter(dataloader))
plt.imshow(images[6].reshape(224, 224, 3))

The resulting image is tiled, not center-cropped (Jupyter snapshot: https://i.stack.imgur.com/HtrIa.png).

Is there something wrong with the provided transformation?

PyTorch stores tensors in channel-first format, so a 3-channel image is a tensor of shape (3, H, W). Matplotlib expects data in channel-last format, i.e. (H, W, 3). Reshaping does not rearrange the dimensions; for that you need Tensor.permute.

plt.imshow(images[6].permute(1, 2, 0))
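To see the difference concretely, here is a minimal sketch (the toy tensor and variable names are illustrative, not part of the original question) contrasting reshape and permute on a channel-first tensor:

import torch

# Toy (C, H, W) tensor: 3 channels, 2 rows, 4 columns.
img = torch.arange(3 * 2 * 4).reshape(3, 2, 4)

# reshape keeps the flat memory order and only reinterprets it,
# so channel data ends up scattered across the spatial grid (the "tiling").
reshaped = img.reshape(2, 4, 3)

# permute actually reorders the axes to (H, W, C), which is what imshow wants.
permuted = img.permute(1, 2, 0)

print(reshaped[0, 0])  # tensor([0, 1, 2]) -- three positions from channel 0
print(permuted[0, 0])  # tensor([ 0,  8, 16]) -- the 3 channel values of pixel (0, 0)

The same logic applies to the batch above: images[6].permute(1, 2, 0) gives matplotlib a proper (224, 224, 3) image, while reshape just reinterprets the channel-first buffer.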
