繁体   English   中英

Tensorflow:从vgg16 .tfmodel文件转移学习

[英]Tensorflow: transfer learning from vgg16 .tfmodel file

我正在尝试对图像分类器进行TF实现(使用py3.5和Windows 10,TF 0.12),因此我将按此处所述重用现有模型但没有所有奇怪的Bazel内容。 修复此行上的py2至3错误(将keys()包装在list() )后,它可以很好地在我的10个不同类别的文件夹上运行。 但是,缺乏性能。 培训成功率大约为83%,而验证集的充其量也永远不会超过60%。 因此,我想从vgg16模型(这是我之前在Caffe / ubuntu中使用过的模型)进行一些转移学习; 我找到的一个可以在这里下载了。

现在我的问题是,如何在Tensorflow中加载.tfmodel文件? 该脚本期望下载tar.gz,足够公平。 它显然包含一个名为classify_image_graph_def.pb的文件,它不是.tfmodel文件。 查看一些示例代码,我发现加载.tfmodel文件非常容易,因此我修改了create_inception_graph函数,使其直接指向vgg16-20160129.tfmodel文件。 运行此命令后,出现以下错误:

File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 450, in import_graph_def
    ret.append(name_to_op[operation_name].outputs[output_index])
KeyError: 'pool_3/_reshape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "retrain.py", line 995, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "retrain.py", line 713, in main
    create_inception_graph())
  File "retrain.py", line 235, in create_inception_graph
    RESIZED_INPUT_TENSOR_NAME]))
  File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 453, in import_graph_def
    'Requested return_element %r not found in graph_def.' % name)
ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.

这是加载代码:

def create_inception_graph():
  """"Creates a graph from saved GraphDef file and returns a Graph object.
  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Session() as sess:
    #model_filename = os.path.join(
    #    FLAGS.model_dir, 'classify_image_graph_def.pb')
    model_filename = os.path.join(
        FLAGS.model_dir, 'vgg16-20160129.tfmodel')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor

tf.import_graph_def调用中似乎有些问题,但是很奇怪,没有该函数的文档。 我正在尝试的可能吗? 有很多瓶颈张量和jpeg数据以及调整了输入张量名称的大小,但我不知道它们的用途,该示例无法复制。

就像在跟踪中说的那样,存在ValueError。

ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.

您的图形文件-'vgg16-20160129.tfmodel'中没有此节点。 重新检查变量BOTTLENECK_TENSOR_NAME,JPEG_DATA_TENSOR_NAME,RESIZED_INPUT_TENSOR_NAME。 这些应该与您的网络体系结构相对应。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM