
How to feed Cifar10 trained model with my own image and get label as output?

I am trying to use the trained model based on the Cifar10 tutorial and would like to feed it with an external image 32x32 (jpg or png).
My goal is to get the label as the output. In other words, I want to feed the network a single 32 x 32, 3-channel JPEG image with no label as input and have the inference step give me tf.argmax(logits, 1).
Basically I would like to be able to use the trained cifar10 model on an external image and see what class it will spit out.

I have been trying to do that based on the Cifar10 tutorial and unfortunately keep running into issues, especially with the Session concept and the batch concept.

Any help doing that with Cifar10 would be greatly appreciated.

Here is the code implemented so far, which has compilation issues:

#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
import tensorflow as tf

import cifar10
import cifar10_input
import os
import faultnet_flags
from PIL import Image

FLAGS = tf.app.flags.FLAGS

def evaluate():

  filename_queue = tf.train.string_input_producer(['/home/tensor/.../inputImage.jpg'])

  reader = tf.WholeFileReader()
  key, value = reader.read(filename_queue)

  input_img = tf.image.decode_jpeg(value)

  init_op = tf.initialize_all_variables()

# Problem in here with Graph / session
  with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1): 
      image = input_img.eval()

    print(image.shape)
    Image.fromarray(np.asarray(image)).show()

# Problem in here is that I have only one image as input and have no label and would like to have
# it compatible with the Cifar10 network
    reshaped_image = tf.cast(image, tf.float32)
    height = FLAGS.resized_image_size
    width = FLAGS.resized_image_size
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, width, height)
    float_image = tf.image.per_image_whitening(resized_image)  # reshaped_image
    num_preprocess_threads = 1
    images = tf.train.batch(
      [float_image],
      batch_size=128,
      num_threads=num_preprocess_threads,
      capacity=128)
    coord.request_stop()
    coord.join(threads)

    logits = faultnet.inference(images)

    # Calculate predictions.
    #top_k_predict_op = tf.argmax(logits, 1)

    # print('Current image is: ')
    # print(top_k_predict_op[0])

    # this does not work since there is a problem with the session
    # and the Graph conflicting
    my_classification = sess.run(tf.argmax(logits, 1))

    print ('Predicted ', my_classification[0], " for your input image.")


def main(argv=None):
  evaluate()

if __name__ == '__main__':
  tf.app.run()

Some basics first:

  1. First you define your graph: image queue, image preprocessing, inference of the convnet, top-k accuracy
  2. Then you create a tf.Session() and work inside it: starting the queue runners, and calls to sess.run()

Here is what your code should look like:

# 1. GRAPH CREATION 
filename_queue = tf.train.string_input_producer(['/home/tensor/.../inputImage.jpg'])
...  # NO CREATION of a tf.Session here
float_image = ...
images = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size=1)
logits = faultnet.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)

# 2. TENSORFLOW SESSION
with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    top_indices = sess.run([top_k_pred])
    print ("Predicted ", top_indices[0], " for your input image.")

EDIT:

As @mrry suggests, if you only need to work on a single image, you can remove the queue runners:

# 1. GRAPH CREATION
input_img = tf.image.decode_jpeg(tf.read_file("/home/.../your_image.jpg"), channels=3)
reshaped_image = tf.image.resize_image_with_crop_or_pad(tf.cast(input_img, tf.float32), height, width)
float_image = tf.image.per_image_whitening(reshaped_image)
images = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size = 1)
logits = faultnet.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)
init_op = tf.initialize_all_variables()

# 2. TENSORFLOW SESSION
with tf.Session() as sess:
  sess.run(init_op)

  top_indices = sess.run([top_k_pred])
  print ("Predicted ", top_indices[0], " for your input image.")

The original source code in cifar10_eval.py can also be used to test your own individual images, as shown in the following console output:

nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ python cifar10_eval.py --run_once True 2>/dev/null
[ -0.63916457  -3.31066918   2.32452989   1.51062226  15.55279636
-0.91585422   1.26451302  -4.11891603  -7.62230825  -4.29096413]
deer
nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ python cifar2bin.py matchbox.png input.bin 
nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ python cifar10_eval.py --run_once True 2>/dev/null
[ -1.30562115  12.61497402  -1.34208572  -1.3238833   -6.13368177
-1.17441642  -1.38651907  -4.3274951    2.05489922   2.54187846]
automobile
nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ 

and in this code snippet:

#while step < num_iter and not coord.should_stop():
# predictions = sess.run([top_k_op])
print(sess.run(logits[0]))
classification = sess.run(tf.argmax(logits[0], 0))
cifar10classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(cifar10classes[classification])

#true_count += np.sum(predictions)
step += 1

# Compute precision @ 1.
precision = true_count / total_sample_count
# print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
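
The cifar2bin.py step in the console log above converts a PNG into the CIFAR-10 binary record layout (one label byte followed by 3072 pixel bytes: the 1024-byte red, green and blue planes of a 32x32 image). The script itself is not included in the post; a minimal sketch of such a conversion, with a hypothetical helper name and a dummy label of 0, could look like this:

import numpy as np
from PIL import Image

def png_to_cifar_bin(png_path, bin_path, label=0):
    """Write one CIFAR-10 style record: <1 label byte><1024 R><1024 G><1024 B>."""
    img = Image.open(png_path).convert('RGB').resize((32, 32))
    pixels = np.array(img, dtype=np.uint8)          # shape (32, 32, 3)
    planes = pixels.transpose(2, 0, 1).reshape(-1)  # R plane, then G plane, then B plane
    record = np.insert(planes, 0, np.uint8(label))  # prepend the label byte
    record.astype(np.uint8).tofile(bin_path)

png_to_cifar_bin('matchbox.png', 'input.bin')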

More details can be found in the post How can I test own image to Cifar-10 tutorial on Tensorflow?
