簡體   English   中英

如何計算從PB文件加載的Tensorflow模型的觸發器

[英]How to calculate the flops of a tensorflow model loaded from pb file

我有一個保存在PB文件中的模型。 我希望能計算出它的觸發器。 我的示例代碼如下:

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

pb_file = 'themodel.pb'

run_meta = tf.RunMetadata()
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(pb_path,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
            options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))

打印信息很奇怪。 我的模型有幾十層,但是在打印的信息中只報告了18個觸發器。 我非常確定模型已正確加載,因為如果嘗試按以下方式打印每個圖層的名稱,則:

print([n.name for n in tf.get_default_graph().as_graph_def().node])

打印信息顯示正確的網絡。

我的代碼有什么問題?

謝謝!

我想我找到了問題的原因和解決方案。 以下代碼可以打印給定pb文件的觸發器。

import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

pb_path = 'mymodel.pb'

run_meta = tf.RunMetadata()
with tf.Graph().as_default():
    output_graph_def = graph_pb2.GraphDef()
    with open(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(output_graph_def, name="")
        print('model loaded!')
    all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
    # for k in all_keys:
    #   print(k)

    with tf.Session() as sess:
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
            options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))

問題中打印的觸發器只有18個的原因是,在生成pb文件時,我將輸入圖像的形狀設置為[None, None, 3] 如果我將其更改為[500, 500, 3] ,那么印刷的拖鞋將是正確的。

不知道在不知道輸入和輸出的情況下如何計算性能指標:也許它需要CallableOptions 我會使用trace_next_step和一個Session而不是手動計算它們。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM