繁体   English   中英

plt.imshow()给出TypeError:无法使用PyTorch将dtype对象的图像数据转换为float

[英]plt.imshow() gives TypeError: Image data of dtype object cannot be converted to float using PyTorch

我正在尝试为一组图像创建自定义数据集处理器。 但是,当我尝试查看数据集中的图像时,遇到TypeError的错误:dtype对象的图像数据无法转换为float。

我试图检查是否将PIL图像传递到plt.imshow()函数中。

class DatasetProcessing(Dataset):

    def __init__(self, input_data, output_data, transform=None): 
        self.transform = transform
        self.input_data = 
        input_data.reshape((-1,64,64)).astype(np.float32)[:,:,:,None]
        self.output_data = output_data 

    def __getitem__(self, index): 
        return self.transform(self.input_data[index]), self.output_data[index]

    def __len__(self): 
        return len(list(self.input_data))

transform = transforms.Compose([transforms.ToPILImage()])


dset_train = DatasetProcessing(X_slices_train, Y_train, transform)


train_loader = torch.utils.data.DataLoader(dset_train, batch_size=4, 
                                      shuffle=True, num_workers=4) 

plt.figure(figsize = (16, 4))
for num, x in enumerate(dset_train):
    plt.subplot(1,6,num+1)
    plt.axis('off')
    print(x)
    plt.imshow(np.asarray(x))
    plt.title(y_train[num])

我希望得到我的数据集的图片,但是却收到以下错误消息:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-8b8caac49d97> in <module>
      4     plt.axis('off')
      5     print(x)
----> 6     plt.imshow(np.asarray(x))
      7     plt.title(y_train[num])

~/anaconda3/lib/python3.7/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, data, **kwargs)
   2675         filternorm=filternorm, filterrad=filterrad, imlim=imlim,
   2676         resample=resample, url=url, **({"data": data} if data is not
-> 2677         None else {}), **kwargs)
   2678     sci(__ret)
   2679     return __ret

~/anaconda3/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1587     def inner(ax, *args, data=None, **kwargs):
   1588         if data is None:
-> 1589             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1590 
   1591         bound = new_sig.bind(ax, *args, **kwargs)

~/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    367                 f"%(removal)s.  If any parameter follows {name!r}, they "
    368                 f"should be pass as keyword, not positionally.")
--> 369         return func(*args, **kwargs)
    370 
    371     return wrapper

~/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    367                 f"%(removal)s.  If any parameter follows {name!r}, they "
    368                 f"should be pass as keyword, not positionally.")
--> 369         return func(*args, **kwargs)
    370 
    371     return wrapper

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
   5658                               resample=resample, **kwargs)
   5659 
-> 5660         im.set_data(X)
   5661         im.set_alpha(alpha)
   5662         if im.get_clip_path() is None:

~/anaconda3/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
    676                 not np.can_cast(self._A.dtype, float, "same_kind")):
    677             raise TypeError("Image data of dtype {} cannot be converted to "
--> 678                             "float".format(self._A.dtype))
    679 
    680         if not (self._A.ndim == 2

TypeError: Image data of dtype object cannot be converted to float

如果正确理解self.transform(self.input_data[index]), self.output_data[index]您的dset_train产生self.transform(self.input_data[index]), self.output_data[index]self.transform(self.input_data[index])是图像张量(数据)而self.output_data[index]是标签,但在这里:

plt.imshow(np.asarray(x))

您正在传递未包装的x ,它实际上是(数据,标签)

因此,您需要先将其打开包装:

plt.figure(figsize = (16, 4))
for num, x in enumerate(dset_train):
    data, label = x
    plt.subplot(1,6,num+1)
    plt.axis('off')
    print(x)
    plt.imshow(np.asarray(data))
    plt.title(y_train[num])

编辑:

为什么我必须打开x的包装?

您继承自PyTorch的Dataset ,并且根据docs

代表从键到数据样本的映射的所有数据集都应将其子类化。 所有子类都应覆盖__getitem__() ,以支持获取给定键的数据样本。

在您定义的DatasetProcessing类中, __getitem__() DatasetProcessing __getitem__()返回一个包含2个项的元组: self.transform(self.input_data[index])self.output_data[index] ,第一个是数据,第二个是适当的标签。 这就是为什么您需要像data, y = x将其DatasetProcessing ,因为您的DatasetProcessing数据DatasetProcessing产生数据和标签。

您可以将我链接到任何文档/教程吗?

我可以向您推荐此链接:

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM