[英]Plot augmented images using matplotlib in python3.6
我正在嘗試從訓練目錄中繪制一堆增強圖像。 我正在使用Keras和Tensorflow。 可視庫是matplotlib。 我正在使用下面的代碼在6行和6列中繪制256 X 256 X 1
灰度圖像。 我得到的錯誤是
Invalid Dimensions for image data.
這是我正在使用的代碼:-
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import keras
from keras.preprocessing.image import ImageDataGenerator
train_set = '/home/ai/IPI/Data/v1_single_model/Train/' # Use your own path
batch_size = 4
gen = ImageDataGenerator(rescale = 1. / 255)
train_batches = gen.flow_from_directory(
'data/train',
target_size=(256, 256),
batch_size=batch_size,
class_mode='binary')
def plot_images(img_gen, img_title):
fig, ax = plt.subplots(6,6, figsize=(10,10))
plt.suptitle(img_title, size=32)
plt.setp(ax, xticks=[], yticks=[])
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
for (img, label) in img_gen:
for i in range(6):
for j in range(6):
if i*6 + j < 256:
ax[i][j].imshow(img[i*6 + j])
break
plot_images(train_batches, "Augmented Images")
下面是錯誤和python追溯的快照:-
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-79-81bdb7f0d12e> in <module>()
----> 1 plot_images(train_batches, "Augmented Images")
<ipython-input-78-d1d4bba983d3> in plot_images(img_gen, img_title)
8 for j in range(6):
9 if i*6 + j < 32:
---> 10 ax[i][j].imshow(img[i*6 + j])
11 break
~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
1896 warnings.warn(msg % (label_namer, func.__name__),
1897 RuntimeWarning, stacklevel=2)
-> 1898 return func(ax, *args, **kwargs)
1899 pre_doc = inner.__doc__
1900 if pre_doc is None:
~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
5122 resample=resample, **kwargs)
5123
-> 5124 im.set_data(X)
5125 im.set_alpha(alpha)
5126 if im.get_clip_path() is None:
~/anaconda3/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
598 if (self._A.ndim not in (2, 3) or
599 (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))):
--> 600 raise TypeError("Invalid dimensions for image data")
601
602 self._imcache = None
TypeError: Invalid dimensions for image data
我究竟做錯了什么 ?
錯誤告訴您出了什么問題。 您的圖像的形狀為(1,n,m,1)
,在第一個循環運行中,選擇img[0]
,這將導致數組的形狀為(n,m,1)
self._A.ndim == 3 and self._A.shape[-1] not in (3, 4)
來自matplotlib.pyplot.imshow(X, ...)
文檔
X
:類數組,形狀(n,m)或(n,m,3)或(n,m,4)
但不是(n,m,1)
。 除此之外,只要i*6 + j > 0
, img[i*6 + j]
就會失敗。
圖像img
尺寸為(samples, height, width, channels)
。 img
是單個樣本,因此samples = 1
; 它是灰度的,因此channels = 1
。 要獲得形狀(n,m)
的圖像,您需要像選擇它
imshow(img[0,:,:,0])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.