I'm trying to load a custom dataset for training a neural network, but before training I would like to verify that the images have been loaded correctly. So far it looks like they are not, but I can't figure out what gives the images the format they end up with.
This is the code I use to load the images and then display them:
```python
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader

# 2x2 grid of subplots to preview the first four images of a batch
f, axarr = plt.subplots(2, 2, figsize=(20, 20))

def load_dataset():
    data_path = 'processedData/HE/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = DataLoader(
        train_dataset, batch_size=64
    )
    return train_loader

x_train = load_dataset()
datathing = next(iter(x_train))  # (images, labels) of the first batch
for i, ax in enumerate(axarr.flat):
    ax.imshow(datathing[0][i].view(128, 128, 3))
    ax.axis('off')
plt.show()
```
When running this on the images, the output looks like this:
It is supposed to look something like these images:
I have tried several different image datasets, but all of them come out in the same format, so my question is: what is giving the images this format?
The `.view(128, 128, 3)` call is what is messing up the images.
As the documentation of the `ToTensor()` transform says:

> [...] Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] [...]
That is, the channel dimension is moved from the last position to the first. You can see this in the source code:
```python
# ...
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# ...
```
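To see the layout change concretely, here is a small NumPy sketch that mimics the documented behavior on a made-up 2x4 RGB array. The array and its values are hypothetical, and the sketch reproduces only the axis transpose and the scaling to [0.0, 1.0], not the rest of the torchvision implementation:

```python
import numpy as np

# Hypothetical 2x4 RGB image in H x W x C layout, as PIL/NumPy store it.
H, W, C = 2, 4, 3
pic = np.arange(H * W * C, dtype=np.uint8).reshape(H, W, C)

# Same axis move as the quoted source line: (H, W, C) -> (C, H, W).
chw = pic.transpose((2, 0, 1))
# Scale to float values in [0.0, 1.0], as the documentation describes.
chw = chw.astype(np.float32) / 255.0

print(pic.shape)  # (2, 4, 3)
print(chw.shape)  # (3, 2, 4)

# Channel 0 of the result now holds the red value of every pixel.
assert np.allclose(chw[0], pic[..., 0] / 255.0)
```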
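Therefore, you cannot simply call `.view(...)`: it only reinterprets the existing memory with a new shape and does not move any values, so the numbers that end up in each output pixel come from the wrong pixels and channels. You have to move the channel dimension back to the end instead. In PyTorch, you can use the `.permute(...)` function for that. Something like this: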
```python
ax.imshow(datathing[0][i].permute(1, 2, 0))
```
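The difference between the two calls can be checked on a tiny synthetic tensor. This sketch assumes a made-up 2x2 image with 3 channels whose values are simply their positions in memory; the `.contiguous()` call makes the C x H x W tensor store its values in that order, like the batch tensors returned by the DataLoader in the question:

```python
import torch

# Hypothetical 2x2 RGB "image" in H x W x C layout; each value is its flat index.
H, W, C = 2, 2, 3
hwc = torch.arange(H * W * C).reshape(H, W, C)

# Simulate ToTensor's axis move (H, W, C) -> (C, H, W), stored contiguously.
chw = hwc.permute(2, 0, 1).contiguous()

# .view() keeps the memory order and just re-slices it, mixing pixels and channels.
wrong = chw.view(H, W, C)
# .permute() actually moves the channel axis back to the end.
right = chw.permute(1, 2, 0)

print(hwc[0, 0].tolist())       # [0, 1, 2]  (R, G, B of the first pixel)
print(wrong[0, 0].tolist())     # [0, 3, 6]  (red values of three different pixels)
print(right[0, 0].tolist())     # [0, 1, 2]
print(torch.equal(right, hwc))  # True
```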