
Displaying images loaded with pytorch dataloader

I am working with some lidar data images that I can't post here because my reputation is too low to post images. When I open them with GDAL and plot them with matplotlib, they look as expected. However, when I load the same images with PyTorch's ImageFolder and DataLoader, using ToTensor() as the only transform, they appear heavily thresholded, and I can't locate the cause.

Below is how I display the first image with GDAL:

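from osgeo import gdal
import matplotlib.pyplot as plt

dataset = gdal.Open(image_path)  # image_path: path to one of the lidar images

print(dataset.RasterCount)  # number of bands in the file
img = dataset.GetRasterBand(1).ReadAsArray()  # band 1 as a 2-D NumPy array

f = plt.figure()
plt.imshow(img)
print(img.shape)
plt.show()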
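A likely cause, offered as an assumption since the images can't be seen: ImageFolder's default loader opens each file with PIL and converts it to RGB, which holds only 8 bits per channel, so lidar rasters stored as 16-bit integers or floats cannot keep their full range. A quick check is to look at the dtype and value range of the GDAL array. The helper below (summarize_raster is a hypothetical name, not part of any library) reports those, plus the share of pixels outside 0 to 255:

```python
import numpy as np


def summarize_raster(arr):
    """Report dtype, value range and the share of pixels outside 0..255.

    Pixels outside 0..255 cannot be stored in an 8-bit-per-channel RGB image,
    so a large share suggests an 8-bit conversion will distort the data.
    """
    arr = np.asarray(arr)
    outside = (arr < 0) | (arr > 255)
    return {
        "dtype": str(arr.dtype),
        "min": float(arr.min()),
        "max": float(arr.max()),
        "frac_outside_8bit": float(outside.mean()),
    }


# Usage with the GDAL array from above (not run here):
# print(summarize_raster(img))
```

If the maximum is far above 255, or the dtype is a float type, an 8-bit conversion is a plausible source of the thresholded look.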

And here is how I use the DataLoader and display the thresholded image:

import os

import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x]) for x in ['train', 'val']}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=2) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Each batch is [inputs, labels]; inputs has shape (1, C, H, W)
for inputs, labels in dataloaders["train"]:
    f = plt.figure()
    print(inputs.shape)
    plt.imshow(inputs[0, 0])  # first image in the batch, first channel
    plt.show()
    break

Any help with an alternative way to display the images, or pointers to mistakes I am making, would be greatly appreciated.
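One possible workaround, assuming the thresholding comes from ImageFolder's default loader converting each file to 8-bit RGB with PIL: ImageFolder accepts a loader argument, so the files can be read with GDAL instead and kept as floats. In the sketch below, gdal_float_loader and scale_to_unit are hypothetical names, and per-image min-max scaling is an assumed choice. Because the returned array is float32 rather than uint8, ToTensor() converts it to a (1, H, W) tensor without dividing by 255.

```python
import numpy as np


def scale_to_unit(arr):
    """Min-max scale an array to float32 in [0, 1]; constant arrays map to 0."""
    arr = np.asarray(arr, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi == lo:
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


def gdal_float_loader(path):
    """Hypothetical ImageFolder loader: read band 1 with GDAL, keep full range.

    Returns an (H, W, 1) float32 array instead of an 8-bit PIL RGB image.
    """
    from osgeo import gdal  # imported lazily so scale_to_unit works without GDAL

    raster = gdal.Open(path)
    band = raster.GetRasterBand(1).ReadAsArray()
    return scale_to_unit(band)[:, :, np.newaxis]


# Usage (not run here):
# image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
#                                           data_transforms[x],
#                                           loader=gdal_float_loader)
#                   for x in ['train', 'val']}
```

Per-image scaling discards absolute values across images; if those matter for training, dividing by a fixed dataset-wide constant may be a better choice. The loader is a module-level function, so it can be pickled for num_workers > 0.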

To visualize images loaded by a DataLoader, I suggest the following script. It assumes train_data_loader yields (inputs, targets) batches of (C, H, W) images:

import numpy as np
import matplotlib.pyplot as plt

for batch in train_data_loader:
    inputs, targets = batch
    for img in inputs:
        image = img.cpu().numpy()
        # reorder (C, H, W) -> (H, W, C) for plt.imshow
        # (image.T would give (W, H, C) and swap rows and columns)
        image = image.transpose(1, 2, 0)
        # min-max normalise each channel to [0, 1]
        data_min = np.min(image, axis=(0, 1), keepdims=True)
        data_max = np.max(image, axis=(0, 1), keepdims=True)
        scaled_data = (image - data_min) / np.maximum(data_max - data_min, 1e-12)
        # show image
        plt.imshow(scaled_data)
        plt.show()
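Note that .T on a (C, H, W) array reverses all three axes and gives (W, H, C), which draws the image with rows and columns swapped; transpose(1, 2, 0) gives the (H, W, C) layout matplotlib expects. The same steps can be wrapped in a small reusable function that also handles single-channel data such as a one-band lidar raster. chw_to_display is a hypothetical helper, sketched here under the assumption that each channel should be scaled independently:

```python
import numpy as np


def chw_to_display(chw):
    """Convert a (C, H, W) array into something plt.imshow can draw.

    Each channel is min-max scaled to [0, 1] on its own; single-channel
    input is returned as (H, W) so imshow applies a colormap.
    """
    hwc = np.asarray(chw, dtype=np.float64).transpose(1, 2, 0)  # (H, W, C)
    lo = hwc.min(axis=(0, 1), keepdims=True)
    hi = hwc.max(axis=(0, 1), keepdims=True)
    span = np.where(hi > lo, hi - lo, 1.0)  # avoid 0/0 on constant channels
    scaled = (hwc - lo) / span
    return scaled[:, :, 0] if scaled.shape[2] == 1 else scaled


# Usage inside the loop above (not run here):
# plt.imshow(chw_to_display(img.cpu().numpy()))
# plt.show()
```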
