简体   繁体   English

显示使用 pytorch 数据加载器加载的图像

[英]Displaying images loaded with pytorch dataloader

I am working with some lidar data images that I cannot post here due to a reputation restriction on posting images.我正在处理一些由于发布图像的声誉限制而无法在此处发布的激光雷达数据图像。 However, when loading the same images using pytorch ImageFolder and Dataloader with the only transform being converting the images to tensors there seems to be some extreme thresholding and I can't seem to locate the cause of this.但是,当使用 pytorch ImageFolder 和 Dataloader 加载相同的图像时,唯一的转换是将图像转换为张量,似乎存在一些极端的阈值,我似乎无法找到原因。

Below is how I'm displaying the first image:下面是我显示第一张图片的方式:

dataset = gdal.Open(dir)

print(dataset.RasterCount)
img = dataset.GetRasterBand(1).ReadAsArray() 

f = plt.figure() 
plt.imshow(img) 
print(img.shape)
plt.show() 

and here is how I am using the data loader and displaying the thresholded image:这是我使用数据加载器并显示阈值图像的方式:

data_transforms = {
        'train': transforms.Compose([
            transforms.ToTensor(),
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
        ]),
    }

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x]) for x in ['train', 'val']}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=1,
                                                 shuffle=True,
                                                 num_workers=2) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

for image in dataloders["train"]:
  f = plt.figure() 
  print(image[0].shape)
  plt.imshow(image[0].squeeze()[0,:,:]) 
  plt.show() 
  break

Any help on an alternative way to display the images or any mistakes I am making would be greatly appreciated.任何有关显示图像的替代方式的帮助或我所犯的任何错误将不胜感激。

If you want to visualize images loaded by Dataloader, I suggest this script:如果您想可视化 Dataloader 加载的图像,我建议使用此脚本:

for batch in train_data_loader:
    inputs, targets = batch
    for img in inputs:
        image  = img.cpu().numpy()
        # transpose image to fit plt input
        image = image.T
        # normalise image
        data_min = np.min(image, axis=(1,2), keepdims=True)
        data_max = np.max(image, axis=(1,2), keepdims=True)
        scaled_data = (image - data_min) / (data_max - data_min)
        # show image
        plt.imshow(scaled_data)
        plt.show()

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM