
saved_model.prune() in TF2.0

I am trying to prune nodes of a SavedModel that was generated with tf.keras. The pruning script is as follows:

svmod = tf.saved_model.load(fn) #version 1
#svmod = tfk.experimental.load_from_saved_model(fn) #version 2
feeds = ['foo:0']
fetches = ['bar:0']
svmod2 = svmod.prune(feeds=feeds, fetches=fetches)
tf.saved_model.save(svmod2, '/tmp/saved_model/') #version 1
#tfk.experimental.export_saved_model(svmod2, '/tmp/saved_model/') #version 2

If I use version #1, pruning works but saving fails with ValueError: Expected a Trackable object for export. In version 2, there is no prune() method.

How can I prune a TF2.0 Keras SavedModel?

It looks like the way you are pruning the model in version 1 is fine; according to your error message, the resulting pruned model cannot be saved because it is not "trackable", which is a necessary condition for saving a model with tf.saved_model.save. One way to make a trackable object is to inherit from the tf.Module class, as described in the guides for using the SavedModel format and concrete functions. Below is an example of trying to save a tf.function object (which fails because the object is not trackable), inheriting from tf.Module, and saving the resulting object:

(Using Python version 3.7.6, TensorFlow version 2.1.0, and NumPy version 1.18.1)

import tensorflow as tf, numpy as np

# Define a random TensorFlow function and generate a reference output
conv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)
@tf.function
def conv_model(x):
    return tf.nn.conv2d(x, conv_filter, 1, "SAME")

input_tensor = tf.ones([1, 2, 3, 4])
output_tensor = conv_model(input_tensor)
print("Original model outputs:", output_tensor, sep="\n")

# Try saving the model: it won't work because a tf.function is not trackable
export_dir = "./tmp/"
try: tf.saved_model.save(conv_model, export_dir)
except ValueError: print(
    "Can't save {} object because it's not trackable".format(type(conv_model)))

# Now define a trackable object by inheriting from the tf.Module class
class MyModule(tf.Module):
    @tf.function
    def __call__(self, x): return conv_model(x)

# Instantiate the trackable object, and call once to trace-compile a graph
module_func = MyModule()
module_func(input_tensor)
tf.saved_model.save(module_func, export_dir)

# Restore the model and verify that the outputs are consistent
restored_model = tf.saved_model.load(export_dir)
restored_output_tensor = restored_model(input_tensor)
print("Restored model outputs:", restored_output_tensor, sep="\n")
if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):
    print("Outputs are consistent :)")
else: print("Outputs are NOT consistent :(")

Console output:

Original model outputs:
tf.Tensor(
[[[[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]

  [[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Can't save <class 'tensorflow.python.eager.def_function.Function'> object
because it's not trackable
Restored model outputs:
tf.Tensor(
[[[[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]

  [[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Outputs are consistent :)

Therefore you should try modifying your code as follows:

svmod = tf.saved_model.load(fn) #version 1
svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])

class Exportable(tf.Module):
    @tf.function
    def __call__(self, model_inputs): return svmod2(model_inputs)

svmod2_export = Exportable()
svmod2_export(typical_input)    # call once with typical input to trace-compile
tf.saved_model.save(svmod2_export, '/tmp/saved_model/')
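As a quick sanity check, you can reload the re-exported model and run it once, following the same reload-and-verify pattern as in the example above (a minimal sketch, assuming the same typical_input tensor used for tracing):

reloaded = tf.saved_model.load('/tmp/saved_model/')
print("Re-exported pruned model output:", reloaded(typical_input), sep="\n")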

If you don't want to inherit from tf.Module, you can alternatively just instantiate a tf.Module object and add a tf.function method/callable attribute by replacing that section of code as follows:

to_export = tf.Module()
to_export.call = tf.function(conv_model)
to_export.call(input_tensor)
tf.saved_model.save(to_export, export_dir)

restored_module = tf.saved_model.load(export_dir)
restored_func = restored_module.call
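Applied to your pruned model, this attribute-based alternative would look roughly like the sketch below (assuming, as above, that svmod2 is the result of svmod.prune() and typical_input is a representative input tensor; svmod2_module is just an illustrative name):

svmod2_module = tf.Module()
svmod2_module.call = tf.function(lambda model_inputs: svmod2(model_inputs))
svmod2_module.call(typical_input)    # call once to trace-compile a graph
tf.saved_model.save(svmod2_module, '/tmp/saved_model/')

As in the examples above, the tf.function needs to be traced (called once, or given an input_signature) before saving so that a concrete graph is captured in the SavedModel.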

Since you can prune successfully in version #1, I suggest you try pickle to save the model. Try the steps below to save the model.

import pickle
with open('<model_name.pkl>', 'wb') as f:
    pickle.dump(<your_model>, f)

Read the model back as:

with open('<model_name.pkl>', 'rb') as f:
    model = pickle.load(f)

In your case, for version #1, your_model in the code snippet is svmod2.
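Concretely, that substitution would look something like the sketch below ('svmod2.pkl' is just a placeholder filename, and whether the pruned object is actually picklable depends on the TensorFlow objects it holds):

import pickle

with open('svmod2.pkl', 'wb') as f:    # placeholder filename
    pickle.dump(svmod2, f)

with open('svmod2.pkl', 'rb') as f:
    svmod2_restored = pickle.load(f)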
