Tensorflow: transfer learning from vgg16 .tfmodel file

I'm trying to build a TensorFlow image classifier (Python 3.5 on Windows 10, TF 0.12), so I'm re-using existing models as described here, but without all the weird Bazel stuff. After fixing a Python 2-to-3 bug on this line (wrapping the keys() call in list()), it ran nicely on my 10 folders of different categories. However, performance is lacking: training accuracy is around 83% and validation accuracy never gets above 60%. So I'd like to do some transfer learning from a VGG16 model (one I've used before in Caffe on Ubuntu); I've found one here, ready to be downloaded.

My question now is: how do you load a .tfmodel file in TensorFlow? The script expects to download a tar.gz, fair enough. That archive apparently contains a file called classify_image_graph_def.pb, which is not a .tfmodel file. Looking at some example code, I see that it's pretty easy to load a .tfmodel file, so I've modified the create_inception_graph function to point straight at the vgg16-20160129.tfmodel file. Upon running this, I get this error:

File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 450, in import_graph_def
    ret.append(name_to_op[operation_name].outputs[output_index])
KeyError: 'pool_3/_reshape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "retrain.py", line 995, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "retrain.py", line 713, in main
    create_inception_graph())
  File "retrain.py", line 235, in create_inception_graph
    RESIZED_INPUT_TENSOR_NAME]))
  File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 453, in import_graph_def
    'Requested return_element %r not found in graph_def.' % name)
ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.

And this is the loading code:

def create_inception_graph():
  """Creates a graph from saved GraphDef file and returns a Graph object.
  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Session() as sess:
    #model_filename = os.path.join(
    #    FLAGS.model_dir, 'classify_image_graph_def.pb')
    model_filename = os.path.join(
        FLAGS.model_dir, 'vgg16-20160129.tfmodel')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor

Something seems to be going awry in the tf.import_graph_def call, but, weirdly, I can't find any documentation for that function. Is what I'm trying even possible? The script defines bottleneck, JPEG data and resized input tensor names whose purpose I don't understand, and the example code doesn't replicate them.

As the traceback says, the call fails with a ValueError:

ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.

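Your graph file, 'vgg16-20160129.tfmodel', does not contain a node with this name. Re-check the constants BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME and RESIZED_INPUT_TENSOR_NAME: the script passes them as return_elements to tf.import_graph_def, so each one must name a tensor that actually exists in the architecture of the network you are loading.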
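
One way to find suitable names is to list the nodes the file actually contains and compare them with the names the script requests. The following is only a minimal sketch, not part of retrain.py: the helper names (split_tensor_name, missing_return_elements, load_node_names) are made up for this example, and it assumes the .tfmodel file is a serialized GraphDef, as the loading code in the question already does. tf.GraphDef is the TF 0.12/1.x spelling; in TF 2.x the same class should be reachable as tf.compat.v1.GraphDef.

```python
def split_tensor_name(tensor_name):
    """Split 'node_name:index' into (node_name, index)."""
    node_name, sep, index = tensor_name.rpartition(':')
    if not sep:  # no ':' present, so the whole string is the node name
        return tensor_name, 0
    return node_name, int(index)


def missing_return_elements(node_names, requested):
    """Return the requested tensor names whose node is absent from the graph."""
    available = set(node_names)
    return [name for name in requested
            if split_tensor_name(name)[0] not in available]


def load_node_names(model_path):
    """Read a serialized GraphDef and return the names of all its nodes."""
    # Imported lazily so the two helpers above work without TensorFlow.
    import tensorflow as tf
    graph_def = tf.GraphDef()  # assumption: TF 0.12/1.x API
    with open(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return [node.name for node in graph_def.node]


# Example use (adjust the path to wherever the model file lives):
# names = load_node_names('vgg16-20160129.tfmodel')
# print(names)  # inspect these to pick input and bottleneck nodes
# print(missing_return_elements(names, ['pool_3/_reshape:0']))
```

Any name that missing_return_elements reports will trigger the same ValueError in tf.import_graph_def, so the three constants need to be replaced with names taken from the printed list (followed by ':0' for a node's first output).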
