繁体   English   中英

在张量流中为转移学习提供图像数据

[英]Feeding image data in tensorflow for transfer learning

我正在尝试使用tensorflow进行转移学习。 我从教程中下载了预训练模型inception3。 在代码中,用于预测:

prediction = sess.run(softmax_tensor,{'DecodeJpeg/contents:0'}:image_data})

有没有办法喂png图像。 我尝试将DecodeJpeg更改为DecodePng但它不起作用。 除此之外,如果我想要像numpy数组或一批数组那样提供解码图像文件,我应该改变什么?

谢谢!!

classify_image.py使用的InceptionV3图表仅支持开箱即用的JPEG图像。 有两种方法可以将此图表与PNG图像一起使用:

  1. 将PNG图像转换为height x width x 3(通道)Numpy数组,例如使用PIL ,然后输入'DecodeJpeg:0'张量:

     import numpy as np from PIL import Image # ... image = Image.open("example.png") image_array = np.array(image)[:, :, 0:3] # Select RGB channels only. prediction = sess.run(softmax_tensor, {'DecodeJpeg:0': image_array}) 

    也许令人困惑的是, 'DecodeJpeg:0'DecodeJpeg操作的输出 ,因此通过提供此张量,您可以提供原始图像数据。

  2. tf.image.decode_png() op添加到导入的图形中。 简单地将馈送张量的名称从'DecodeJpeg/contents:0'切换到'DecodePng/contents:0'不起作用,因为出货图中没有'DecodePng'操作。 您可以使用tf.import_graph_def()input_map参数将此类节点添加到图形中:

     png_data = tf.placeholder(tf.string, shape=[]) decoded_png = tf.image.decode_png(png_data, channels=3) # ... graph_def = ... softmax_tensor = tf.import_graph_def( graph_def, input_map={'DecodeJpeg:0': decoded_png}, return_elements=['softmax:0']) sess.run(softmax_tensor, {png_data: ...}) 

以下代码应处理这两种情况。

import numpy as np
from PIL import Image

image_file = 'test.jpeg'
with tf.Session() as sess:

    #     softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    if image_file.lower().endswith('.jpeg'):
        image_data = tf.gfile.FastGFile(image_file, 'rb').read()
        prediction = sess.run('final_result:0', {'DecodeJpeg/contents:0': image_data})
    elif image_file.lower().endswith('.png'):
        image = Image.open(image_file)
        image_array = np.array(image)[:, :, 0:3]
        prediction = sess.run('final_result:0', {'DecodeJpeg:0': image_array})

    prediction = prediction[0]    
    print(prediction)

或带有直接字符串的较短版本:

image_file = 'test.png' # or 'test.jpeg'
image_data = tf.gfile.FastGFile(image_file, 'rb').read()
ph = tf.placeholder(tf.string, shape=[])

with tf.Session() as sess:        
    predictions = sess.run(output_layer_name, {ph: image_data} )

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM