繁体   English   中英

由于形状不完整,计算 keras model 的 FLOPS 返回没有 flops 的操作

[英]Calculating FLOPS of a keras model returns ops with no flops due to incomplete shapes

我正在尝试计算我的 model 的 FLOPS,这是一个 tf.keras model。

As a workaround I am dealing with my model as being a pure tensorflow one, since I am not aware of a way to calculate FLOPS directly in a keras model.

我面临的问题是(显然)在某些层上,形状被认为是未定义的,我得到了一个错误。

import tensorflow as tf
import numpy as np

model = tf.keras.applications.ResNet50(
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
nparams = np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()])
options = tf.profiler.ProfileOptionBuilder.float_operation()
options['output'] = 'none'
flops = tf.profiler.profile(tf.get_default_graph(), options=options).total_float_ops
flops = flops // 2

由于形状不完整,111 次操作没有失败统计。

另一方面,如果我查看之前 model 的摘要,除了批量大小之外,我似乎在层中找不到任何未定义的形状。 而且我认为我无法明确定义批量大小。

model.summary()
 Model: "resnet50"

input_1 (InputLayer) [(无, 224, 224, 3) 0
...

问题是,当我得到它时,返回的 FLOPS 不准确。 那么,我怎样才能得到我的 model 的实际 FLOPS?

我的 tensorflow 是 1.15,Keras 是 2.2.5,Keras-Applications 是 1.0.8

经过一番研究,我终于设法找到了解决方案。 对此的一些观察:

1)这里的问题似乎是探查器的这个None足以导致这些错误。 应使用硬编码形状调用 model,例如:

ResNet50(include_top=True, weights="imagenet", input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)), input_shape=None, pooling=None, classes=1000)

该解决方案似乎仅对 tensorflow < 2 有效。在 tf 2.0+ 中使用它的解决方法是:

def get_flops(model_h5_path):
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()
        

    with graph.as_default():
        with session.as_default():
            model = tf.keras.models.load_model(model_h5_path)

            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        
            # We use the Keras session graph in the call to the profiler.
            flops = tf.compat.v1.profiler.profile(graph=graph,
                                                  run_meta=run_meta, cmd='op', options=opts)
        
            return flops.total_float_ops

取自这里

3) 实际解决方案仅适用于冻结模型。 好消息是,这就是所有工作首先测量它的方式(准确地说是通过冻结 model 的推断)。 因此,一个可行的解决方案是:

import keras.backend as K
from keras.applications.resnet50 import ResNet50


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


def load_pb(pb):
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


run_meta = tf.RunMetadata()
with tf.Session(graph=tf.Graph()) as sess:
    K.set_session(sess)
    net = ResNet50(include_top=True, weights="imagenet", input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)), input_shape=None, pooling=None, classes=1000)

    frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in net.outputs])
    with tf.gfile.GFile('graph.pb', "wb") as f:
        f.write(frozen_graph.SerializeToString())

    g2 = load_pb('./graph.pb')
    with g2.as_default():
        flops = tf.profiler.profile(g2, options=tf.profiler.ProfileOptionBuilder.float_operation())
        print('FLOP after freezing {} MFLOPS'.format(float(flops.total_float_ops//2) * 1e-6))

最后:

冻结后的 FLOP 80.87084 MFLOPS

它设法计算冻结的 model 的 FLOPS 并创建磁盘上保存的 pb model ( graph.pb ) 的副产品(当然可以在之后删除)。

该解决方案大量借鉴了这些答案的代码(公平起见)。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM