繁体   English   中英

我如何生成清晰的图像Cifar-10

[英]How do I generate sharp images Cifar-10

我正在使用tensorflow并尝试在Cifar-10上可视化自动编码器的输入/输出。

我在这里遵循以下答案: 为什么使用matplotlib无法正确显示CIFAR-10图像?

这是经过稍加修改即可运行其代码的结果(将figsize更改为5,5):

可视化图像

但是,这仍然没有原始页面中的图像清晰和清晰: https : //www.cs.toronto.edu/~kriz/cifar.html

我该如何做得更好?

这里可能有两个问题:

问题1:

您的颜色通道(红色,绿色,蓝色)似乎混合在一起。 那可以解释为什么颜色如此怪异。 如果是这种情况,您将需要交换阵列中的颜色通道,如下所示。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cbook import get_sample_data

rgb_image = plt.imread(get_sample_data("grace_hopper.png", asfileobj=False))

# correct color channels (R, G, B)
plt.figure()
plt.imshow(rgb_image)
plt.axis('off')

本色

# swapped color channels (R, B, G)
rgb_image = rgb_image[:, :, [0, 2, 1]]
plt.figure()
plt.imshow(rgb_image)
plt.axis('off')

交换颜色通道

问题2:

Matplotlib的plt.imshow具有关键字参数interpolation ,如果未指定,则默认为None 然后,Matplotlib会参考您的本地样式表来确定默认的插值行为。 根据您的样式表,这可能会导致应用插值,从而导致图像失真。 有关更多详细信息,请参见imshow文档

如果要保证Matplotlib不对图像进行插值,则应在plt.imshow指定plt.imshow interpolation="none" 这很混乱,因为默认的NoneType值None与字符串值"none"产生不同的行为。

red = np.zeros((100, 100, 3), dtype=np.uint8)
red[:, :, 0] = 255
red[40:60, 40:60, :] = 255

# with interpolation
plt.figure()
plt.imshow(red, interpolation='bicubic') 
plt.axis('off')

带插值

# without interpolation
plt.figure()
plt.imshow(red, interpolation='none') 
plt.axis('off')

没有插值

也许您应该做这样的事情。 图像非常小,高度和宽度均为32px,因此只有在缩略图大小时它们才会更清晰。 我在这里使用双三次变换对其进行了插值。 但是您可以将其更改为“无”,这样就可以得到像素化图像而不是模糊图像。

def unpickle(file):
    with open(file, 'rb') as fo:
        dict1 = pickle.load(fo, encoding='bytes')
    return dict1

pd_tr = pd.DataFrame()
tr_y = pd.DataFrame()

for i in range(1,6):
    data = unpickle('data/data_batch_' + str(i))
    pd_tr = pd_tr.append(pd.DataFrame(data[b'data']))
    tr_y = tr_y.append(pd.DataFrame(data[b'labels']))
    pd_tr['labels'] = tr_y

tr_x = np.asarray(pd_tr.iloc[:, :3072])
tr_y = np.asarray(pd_tr['labels'])
ts_x = np.asarray(unpickle('data/test_batch')[b'data'])
ts_y = np.asarray(unpickle('data/test_batch')[b'labels'])    
labels = unpickle('data/batches.meta')[b'label_names']

def plot_CIFAR(ind):
    arr = tr_x[ind]
    R = arr[0:1024].reshape(32,32)/255.0
    G = arr[1024:2048].reshape(32,32)/255.0
    B = arr[2048:].reshape(32,32)/255.0

    img = np.dstack((R,G,B))
    title = re.sub('[!@#$b]', '', str(labels[tr_y[ind]]))
    fig = plt.figure(figsize=(3,3))
    ax = fig.add_subplot(111)
    ax.imshow(img,interpolation='bicubic')
    ax.set_title('Category = '+ title,fontsize =15)

plot_CIFAR(4)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM