繁体   English   中英

重新训练inception-v3的最后一层会大大减慢分类

[英]retraining last layer of inception-v3 significantly slowers the classification

为了尝试使用TF和PY3.5在Inception-v3上进行转移学习,我测试了两种方法:

1-重新训练最后一层,如下所示: https : //github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/image_retraining

2-在Inception-V3瓶颈之上应用线性SVM,如此处所示: https : //www.kernix.com/blog/image-classification-with-a-pre-trained-deep-neural-network_p11

可以预期的是,由于关键部分(瓶颈提取)是相同的,因此他们应该在分类阶段具有类似的运行时。 但实际上,进行分类时,经过重新训练的网络速度要慢大约8倍。

我的问题是是否有人对此有想法。

一些代码片段:

SVM位于顶部(速度更快):

def getTensors():
    graph_def = tf.GraphDef()
    f = open('classify_image_graph_def.pb', 'rb')
    graph_def.ParseFromString(f.read())
    tensorBottleneck, tensorsResizedImage = tf.import_graph_def(graph_def, name='', return_elements=['pool_3/_reshape:0', 'Mul:0'])
    return tensorBottleneck, tensorsResizedImage 

def calc_bottlenecks(imgFile, tensorBottleneck, tensorsResizedImage):
    """ - read, decode and resize to get <resizedImage> - """
    bottleneckValues = sess.run(tensorBottleneck, {tensorsResizedImage : resizedImage})
    return np.squeeze(bottleneckValues)

在我的(Windows)笔记本电脑上,这大约需要0.5秒,而SVM部分不需要任何时间。

重新训练最后一层-(由于代码较长,因此很难总结)

def loadGraph(pbFile):
    with tf.gfile.FastGFile(pbFile, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        softmaxTensor = sess.graph.get_tensor_by_name('final_result:0')

def labelImage(imageFile, softmaxTensor):
    with tf.Session() as sess:
        input_layer_name = 'DecodeJpeg/contents:0'
        predictions, = sess.run(softmax_tensor, {input_layer_name: image_data})

“ pbFile”是保存为再训练器的文件,假定具有相同的拓扑和权重(不包括分类层),则为“ classify_image_graph_def.pb”。 运行大约需要4秒钟(在我的同一台笔记本电脑上,没有加载)。

对性能差距有任何想法吗? 谢谢!

解决了。 问题在于为每个图像创建一个新的tf.Session()。 在读取图形并使用它时存储会话使运行时恢复到预期的状态。

def loadGraph(pbFile):
    ...
    with tf.Session() as sess:
        softmaxTensor = sess.graph.get_tensor_by_name('final_result:0')
        sessToStore = sess
    return softmaxTensor, sessToStore  

def labelImage(imageFile, softmaxTensor, sessToStore):
    input_layer_name = 'DecodeJpeg/contents:0'
    predictions, = sessToStore.run(softmax_tensor, {input_layer_name: image_data})

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM