Saving a TensorFlow Keras model (Encoder - Decoder) to SavedModel format

I've written (with help from the TF tutorials) an image captioning model that uses an encoder-decoder architecture with attention.

Now I want to convert it to TFLite and eventually deploy it in Flutter.

I'm trying to save the encoder and decoder models in the SavedModel format, which I can then convert to TFLite.

Attention model:

    import tensorflow as tf

    class BahdanauAttention(tf.keras.Model):
        def __init__(self, units):
            super(BahdanauAttention, self).__init__()
            self.W1 = tf.keras.layers.Dense(units)
            self.W2 = tf.keras.layers.Dense(units)
            self.V = tf.keras.layers.Dense(1)

        def call(self, features, hidden):
            # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)

            # hidden shape == (batch_size, hidden_size)
            # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
            hidden_with_time_axis = tf.expand_dims(hidden, 1)

            # score shape == (batch_size, 64, hidden_size)
            score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))

            # attention_weights shape == (batch_size, 64, 1)
            # you get 1 at the last axis because you are applying score to self.V
            attention_weights = tf.nn.softmax(self.V(score), axis=1)

            # context_vector shape after sum == (batch_size, embedding_dim)
            context_vector = attention_weights * features
            context_vector = tf.reduce_sum(context_vector, axis=1)

            return context_vector, attention_weights
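For reference, the `call` above is additive (Bahdanau) attention. Writing $f_i$ for the $i$-th of the 64 feature vectors and $h$ for the decoder hidden state, `W1`, `W2` and `V` play the roles of $W_1$, $W_2$ and $v$ below (the bias terms of the `Dense` layers are omitted for brevity), and the softmax runs over the 64 positions (`axis=1`):

```latex
e_i = v^\top \tanh\left(W_1 f_i + W_2 h\right), \qquad
\alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^{64} \exp(e_j)}, \qquad
c = \sum_{i=1}^{64} \alpha_i f_i
```

Because $c$ is a weighted sum of the $f_i$, the context vector's last dimension is the encoder's `embedding_dim`, not the decoder's `units`.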

Encoder model:

    class CNN_Encoder(tf.keras.Model):
        # This encoder passes the extracted features through a Fully connected layer
        def __init__(self, embedding_dim):
            super(CNN_Encoder, self).__init__()
            # shape after fc == (batch_size, 64, embedding_dim)
            self.fc = tf.keras.layers.Dense(embedding_dim)

        @tf.function
        def call(self, x):
            inp = x
            y = self.fc(inp)
            z = tf.nn.relu(y)
            return z

Decoder model:

    class RNN_Decoder(tf.keras.Model):
        def __init__(self, embedding_dim, units, vocab_size):
            super(RNN_Decoder, self).__init__()
            self.units = units

            self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
            self.gru = tf.keras.layers.GRU(self.units,
                                           return_sequences=True,
                                           return_state=True,
                                           recurrent_initializer='glorot_uniform')
            self.fc1 = tf.keras.layers.Dense(self.units)
            self.fc2 = tf.keras.layers.Dense(vocab_size)

            self.attention = BahdanauAttention(self.units)

        @tf.function
        def call(self, x, features1, hidden):
            # defining attention as a separate model
            features1 = features1
            hidden1 = hidden
            context_vector, attention_weights = self.attention(features1, hidden1)

            # x shape after passing through embedding == (batch_size, 1, embedding_dim)
            x = self.embedding(x)

            # x shape after concatenation == (batch_size, 1, embedding_dim + embedding_dim)
            x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

            # passing the concatenated vector to the GRU
            output, state = self.gru(x)

            # shape == (batch_size, max_length, hidden_size)
            x = self.fc1(output)

            # x shape == (batch_size * max_length, hidden_size)
            x = tf.reshape(x, (-1, x.shape[2]))

            # output shape == (batch_size * max_length, vocab)
            x = self.fc2(x)

            return x, state, attention_weights

        def reset_state(self, batch_size):
            return tf.zeros((batch_size, self.units))
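One way to confirm the shapes that the `TensorSpec`s in the save call have to match is to run a single decoding step eagerly. This is a minimal sketch that re-declares the three classes from the question without the `@tf.function` decorators; the sizes (`embedding_dim=256`, `units=512`, 2048-dimensional image features) are assumed from the specs used in the question, and `VOCAB_SIZE` is a hypothetical placeholder.

```python
import tensorflow as tf

# Sizes assumed from the question's TensorSpecs; VOCAB_SIZE is a hypothetical placeholder.
EMBEDDING_DIM, UNITS, VOCAB_SIZE, FEATURES_SHAPE = 256, 512, 5000, 2048


class BahdanauAttention(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights


class CNN_Encoder(tf.keras.Model):
    def __init__(self, embedding_dim):
        super().__init__()
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        return tf.nn.relu(self.fc(x))


class RNN_Decoder(tf.keras.Model):
    def __init__(self, embedding_dim, units, vocab_size):
        super().__init__()
        self.units = units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(units, return_sequences=True, return_state=True,
                                       recurrent_initializer="glorot_uniform")
        self.fc1 = tf.keras.layers.Dense(units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(units)

    def call(self, x, features1, hidden):
        context_vector, attention_weights = self.attention(features1, hidden)
        x = self.embedding(x)                                           # (1, 1, embedding_dim)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)  # (1, 1, 2 * embedding_dim)
        output, state = self.gru(x)
        x = self.fc1(output)
        x = tf.reshape(x, (-1, x.shape[2]))
        return self.fc2(x), state, attention_weights

    def reset_state(self, batch_size):
        return tf.zeros((batch_size, self.units))


encoder = CNN_Encoder(EMBEDDING_DIM)
decoder = RNN_Decoder(EMBEDDING_DIM, UNITS, VOCAB_SIZE)

img_features = tf.random.normal([1, 64, FEATURES_SHAPE])  # stand-in for InceptionV3 features
features = encoder(img_features)                           # (1, 64, 256)
hidden = decoder.reset_state(batch_size=1)                 # (1, 512)
dec_input = tf.constant([[1]], dtype=tf.int32)             # one token id, shape (1, 1)

# One decoding step: logits (1, VOCAB_SIZE), state (1, 512), attention (1, 64, 1)
logits, state, attn = decoder(dec_input, features, hidden)
```

Matching shapes here only rules out a spec mismatch; this eager run does not exercise the SavedModel serialization path at all.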

Now, when saving the models, the encoder is saved as a SavedModel without errors, but the decoder is not:

    tf.saved_model.save(decoder, 'decoder_model', signatures=decoder.call.get_concrete_function(
            [
                tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name='x'), 
                tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32, name="features1"),
                tf.TensorSpec(shape=[1, 512], dtype=tf.float32, name="hidden"),
            ]
    ))

Error:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-66-da4712d61d18> in <module>
          3             tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name='x'),
          4             tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32, name="features1"),
    ----> 5             tf.TensorSpec(shape=[1, 512], dtype=tf.float32, name="hidden"),
          6         ]
          7 ))

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in 
    get_concrete_function(self, *args, **kwargs)
        913       # In this case we have created variables on the first call, so we run the
        914       # defunned version which is guaranteed to never create variables.
    --> 915       return self._stateless_fn.get_concrete_function(*args, **kwargs)
        916     elif self._stateful_fn is not None:
        917       # In this case we have not created variables on the first call. So we can

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in 
    get_concrete_function(self, *args, **kwargs)
      2432       args, kwargs = None, None
      2433     with self._lock:
    -> 2434       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
      2435       if self.input_signature:
      2436         args = self.input_signature

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in 
    _maybe_define_function(self, args, kwargs)
       2701 
       2702       self._function_cache.missed.add(call_context_key)
    -> 2703       graph_function = self._create_graph_function(args, kwargs)
       2704       self._function_cache.primary[cache_key] = graph_function
       2705       return graph_function, args, kwargs

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in 
    _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
       2591             arg_names=arg_names,
       2592             override_flat_arg_shapes=override_flat_arg_shapes,
    -> 2593             capture_by_value=self._capture_by_value),
       2594         self._function_attributes,
       2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\func_graph.py in 
    func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, 
    autograph_options, add_control_dependencies, arg_names, op_return_value, collections, 
    capture_by_value, override_flat_arg_shapes)
        976                                           converted_func)
        977 
    --> 978       func_outputs = python_func(*func_args, **func_kwargs)
        979 
        980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in 
    wrapped_fn(*args, **kwds)
        437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
        438         # the function a weak reference to itself to avoid a reference cycle.
    --> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
        440     weak_wrapped_fn = weakref.ref(wrapped_fn)
        441 

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in 
    bound_method_wrapper(*args, **kwargs)
       3209     # However, the replacer is still responsible for attaching self properly.
       3210     # TODO(mdan): Is it possible to do it here instead?
    -> 3211     return wrapped_fn(*args, **kwargs)
       3212   weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
       3213 

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\func_graph.py in 
    wrapper(*args, **kwargs)
        966           except Exception as e:  # pylint:disable=broad-except
        967             if hasattr(e, "ag_error_metadata"):
    --> 968               raise e.ag_error_metadata.to_exception(e)
        969             else:
        970               raise

    TypeError: in converted code:


        TypeError: tf__call() missing 2 required positional arguments: 'features' and 'hidden'

I've spent the last 4 days trying to get around this error, but to no avail :(

Any help on this would be highly appreciated!

Edit:

I fixed the code as suggested by palazzo-train, and that resolved this error, but now another error comes up (the problem lies in the attention part):

    WARNING:tensorflow:Skipping full serialization of Keras model <__main__.RNN_Decoder object at 0x0000023F61D37278>, because its inputs are not defined.
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-45-4e1cfeda04ea> in <module>
          2             tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name='x'),
          3             tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32, name="features1"),
    ----> 4             tf.TensorSpec(shape=[1, 512], dtype=tf.float32, name="hidden"),
          5 ))

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\saved_model\save.py in save(obj, export_dir, signatures, options)
        897   # Note we run this twice since, while constructing the view the first time
        898   # there can be side effects of creating variables.
    --> 899   _ = _SaveableView(checkpoint_graph_view)
        900   saveable_view = _SaveableView(checkpoint_graph_view)
        901

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\saved_model\save.py in __init__(self, checkpoint_view)
        163     self.checkpoint_view = checkpoint_view
        164     trackable_objects, node_ids, slot_variables = (
    --> 165         self.checkpoint_view.objects_ids_and_slot_variables())
        166     self.nodes = trackable_objects
        167     self.node_ids = node_ids

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\training\tracking\graph_view.py in objects_ids_and_slot_variables(self)
        413       A tuple of (trackable objects, object -> node id, slot variables)
        414     """
    --> 415     trackable_objects, path_to_root = self._breadth_first_traversal()
        416     object_names = object_identity.ObjectIdentityDictionary()
        417     for obj, path in path_to_root.items():

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\training\tracking\graph_view.py in _breadth_first_traversal(self)
        197             % (current_trackable,))
        198       bfs_sorted.append(current_trackable)
    --> 199       for name, dependency in self.list_dependencies(current_trackable):
        200         if dependency not in path_to_root:
        201           path_to_root[dependency] = (

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\saved_model\save.py in list_dependencies(self, obj)
        107   def list_dependencies(self, obj):
        108     """Overrides a parent method to include `add_object` objects."""
    --> 109     extra_dependencies = self.list_extra_dependencies(obj)
        110     extra_dependencies.update(self._extra_dependencies.get(obj, {}))
        111

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\saved_model\save.py in list_extra_dependencies(self, obj)
        134   def list_extra_dependencies(self, obj):
        135     return obj._list_extra_dependencies_for_serialization(  # pylint: disable=protected-access
    --> 136         self._serialization_cache)
        137
        138   def list_functions(self, obj):

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in _list_extra_dependencies_for_serialization(self, serialization_cache)
       2414   def _list_extra_dependencies_for_serialization(self, serialization_cache):
       2415     return (self._trackable_saved_model_saver
    -> 2416             .list_extra_dependencies_for_serialization(serialization_cache))
       2417
       2418   def _list_functions_for_serialization(self, serialization_cache):

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\base_serialization.py in list_extra_dependencies_for_serialization(self, serialization_cache)
         76       of attributes are listed in the `saved_model._LayerAttributes` class.
         77     """
    ---> 78     return self.objects_to_serialize(serialization_cache)
         79
         80   def list_functions_for_serialization(self, serialization_cache):

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\layer_serialization.py in objects_to_serialize(self, serialization_cache)
         74   def objects_to_serialize(self, serialization_cache):
         75     return (self._get_serialized_attributes(
    ---> 76         serialization_cache).objects_to_serialize)
         77
         78   def functions_to_serialize(self, serialization_cache):

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
         93
         94     object_dict, function_dict = self._get_serialized_attributes_internal(
    ---> 95         serialization_cache)
         96
         97     serialized_attr.set_and_validate_objects(object_dict)

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
         51     objects, functions = (
         52         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
    ---> 53             serialization_cache))
         54     functions['_default_save_signature'] = default_signature
         55     return objects, functions

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
        102     """Returns dictionary of serialized attributes."""
        103     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
    --> 104     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
        105     # Attribute validator requires that the default save signature is added to
        106     # function dict, even if the value is None.

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\save_impl.py in wrap_layer_functions(layer, serialization_cache)
        198     for fn in fns.values():
        199       if fn is not None and fn.input_signature is not None:
    --> 200         fn.get_concrete_function()
        201
        202   # Restore overwritten functions and losses

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\save_impl.py in get_concrete_function(self, *args, **kwargs)
        554   def get_concrete_function(self, *args, **kwargs):
        555     if not self.call_collection.tracing:
    --> 556       self.call_collection.add_trace(*args, **kwargs)
        557     return super(LayerCall, self).get_concrete_function(*args, **kwargs)
        558

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\save_impl.py in add_trace(self, *args, **kwargs)
        429         trace_with_training(False)
        430       else:
    --> 431         fn.get_concrete_function(*args, **kwargs)
        432     self.tracing = False
        433

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\save_impl.py in get_concrete_function(self, *args, **kwargs)
        555     if not self.call_collection.tracing:
        556       self.call_collection.add_trace(*args, **kwargs)
    --> 557     return super(LayerCall, self).get_concrete_function(*args, **kwargs)
        558
        559

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in get_concrete_function(self, *args, **kwargs)
        907       if self._stateful_fn is None:
        908         initializers = []
    --> 909         self._initialize(args, kwargs, add_initializers_to=initializers)
        910         self._initialize_uninitialized_variables(initializers)
        911

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in _initialize(self, args, kwds, add_initializers_to)
        495     self._concrete_stateful_fn = (
        496         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    --> 497             *args, **kwds))
        498
        499     def invalid_creator_scope(*unused_args, **unused_kwds):

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
       2387       args, kwargs = None, None
       2388     with self._lock:
    -> 2389       graph_function, _, _ = self._maybe_define_function(args, kwargs)
       2390     return graph_function
       2391

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _maybe_define_function(self, args, kwargs)
       2701
       2702       self._function_cache.missed.add(call_context_key)
    -> 2703       graph_function = self._create_graph_function(args, kwargs)
       2704       self._function_cache.primary[cache_key] = graph_function
       2705       return graph_function, args, kwargs

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
       2591             arg_names=arg_names,
       2592             override_flat_arg_shapes=override_flat_arg_shapes,
    -> 2593             capture_by_value=self._capture_by_value),
       2594         self._function_attributes,
       2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
        976                                           converted_func)
        977
    --> 978       func_outputs = python_func(*func_args, **func_kwargs)
        979
        980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in wrapped_fn(*args, **kwds)
        437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
        438         # the function a weak reference to itself to avoid a reference cycle.
    --> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
        440     weak_wrapped_fn = weakref.ref(wrapped_fn)
        441

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\save_impl.py in wrapper(*args, **kwargs)
        532         saving=True):
        533       with base_layer_utils.autocast_context_manager(layer._compute_dtype):  # pylint: disable=protected-access
    --> 534         ret = method(*args, **kwargs)
        535     _restore_layer_losses(original_losses)
        536     return ret

    ~\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\save_impl.py in call_and_return_conditional_losses(inputs, *args, **kwargs)
        574   layer_call = _get_layer_call_method(layer)
        575   def call_and_return_conditional_losses(inputs, *args, **kwargs):
    --> 576     return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
        577   return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
        578

    TypeError: call() missing 1 required positional argument: 'hidden1'

The arguments of `decoder.call.get_concrete_function` are not a list. If you have 3 spec objects, you should call `get_concrete_function(s1, s2, s3)` instead of `get_concrete_function([s1, s2, s3])`.

Try removing the `[]`:

    tf.saved_model.save(decoder, 'decoder_model', signatures=decoder.call.get_concrete_function(
            tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name='x'),
            tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32, name="features1"),
            tf.TensorSpec(shape=[1, 512], dtype=tf.float32, name="hidden")
    ))
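To see why the brackets matter, here is a minimal sketch in which a hypothetical plain `tf.function` named `step` stands in for `decoder.call`. Passing a list binds the whole list to the first parameter, `x`, so the remaining parameters are reported as missing (the same kind of TypeError as in the question), while unpacking the specs binds one spec to each parameter.

```python
import tensorflow as tf


# Hypothetical stand-in for decoder.call with the same three positional parameters.
@tf.function
def step(x, features1, hidden):
    context = tf.reduce_mean(features1, axis=1)  # (1, 256)
    return tf.cast(x, tf.float32), context, hidden


specs = (
    tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name="x"),
    tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32, name="features1"),
    tf.TensorSpec(shape=[1, 512], dtype=tf.float32, name="hidden"),
)

# Wrong: the list is bound to `x`, so `features1` and `hidden` are missing.
try:
    step.get_concrete_function(list(specs))
    list_call_failed = False
except (TypeError, ValueError):
    list_call_failed = True

# Right: unpack so each spec binds to its own parameter.
concrete = step.get_concrete_function(*specs)
out_x, out_ctx, out_hidden = concrete(
    tf.constant([[3]], dtype=tf.int32), tf.zeros([1, 64, 256]), tf.ones([1, 512])
)
```

The exact error text differs between TensorFlow versions, which is why the sketch only records that the list form fails.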

I think I found the solution to the other error as well.

In addition to @palazzo-train's answer, you can also declare the `@tf.function` input signatures inside the class.

For example:

    class CNN_Encoder(tf.keras.Model):
        # This encoder passes the extracted features through a Fully connected layer
        def __init__(self, embedding_dim):
            super(CNN_Encoder, self).__init__()
            # shape after fc == (batch_size, 64, embedding_dim)
            self.fc = tf.keras.layers.Dense(embedding_dim)

        # features_shape must be defined beforehand (e.g. 2048 for InceptionV3's 8x8x2048 feature map)
        @tf.function(input_signature=[tf.TensorSpec(shape=[1, 64, features_shape], dtype=tf.float32)])
        def call(self, x):
            inp = x
            y = self.fc(inp)
            z = tf.nn.relu(y)
            return z

And in the decoder:

    # (...)
        @tf.function(input_signature=[
            tf.TensorSpec(shape=[1, 1], dtype=tf.int32),
            tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32),
            tf.TensorSpec(shape=[1, 512], dtype=tf.float32),
        ])
        def __call__(self, x, features1, hidden):
            # defining attention as a separate model
            features1 = features1
            hidden1 = hidden
            # (...)

The `BahdanauAttention` class doesn't need `@tf.function`, because it is traced as part of the decoder's call. Once you have defined all your input signatures like this, you only need to call `tf.saved_model.save(decoder, 'decoder')`.
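A minimal sketch of what declaring the signature in the class buys you, using a hypothetical `tf.Module` named `Scaler` instead of the question's Keras models. Once `input_signature` is fixed, `get_concrete_function()` needs no arguments (the Keras saving code in the traceback above calls it exactly that way for functions that have an `input_signature`), and inputs that don't match the spec are rejected instead of triggering a new trace.

```python
import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy module with its input signature declared in the class."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32)])
    def encode(self, x):
        return tf.nn.relu(x * self.scale)


m = Scaler()

# With a fixed input_signature, no example arguments are needed to trace.
concrete = m.encode.get_concrete_function()

# A matching input runs normally...
y = m.encode(tf.ones([1, 64, 256]))

# ...while a batch of 2 does not match shape [1, 64, 256] and is rejected.
try:
    m.encode(tf.zeros([2, 64, 256]))
    rejected = False
except (TypeError, ValueError):
    rejected = True
```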

Furthermore, even after setting the input signatures properly, you still get the error:

    TypeError: call() missing 1 required positional argument: 'hidden1'

I believe this is a bug in TensorFlow. To work around it, you have to use `__call__` instead of `call` in your classes. My errors were resolved only after I made that change in my code; then I could save the encoder and decoder and successfully run the TensorFlow conversion script on the model.
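The same idea (a `tf.function` with an `input_signature` on `__call__`) can be sketched without Keras at all. This is a stand-in built on assumptions, not the poster's model: `DecoderStep` is hypothetical, it replaces the GRU with a plain tanh recurrence and the attention with a mean over the 64 feature vectors, and it only reuses the three `TensorSpec`s from the question. It saves the module with an explicit serving signature, loads it back, and converts the SavedModel with `tf.lite.TFLiteConverter.from_saved_model`.

```python
import tempfile

import tensorflow as tf

# Spec shapes come from the question; VOCAB_SIZE is a hypothetical placeholder.
VOCAB_SIZE, EMBEDDING_DIM, UNITS = 100, 256, 512


class DecoderStep(tf.Module):
    """Hypothetical one-step decoder: tanh recurrence instead of GRU, mean-pooling instead of attention."""

    def __init__(self):
        super().__init__()
        init = tf.random_normal_initializer(stddev=0.05)
        self.embedding = tf.Variable(init([VOCAB_SIZE, EMBEDDING_DIM]))
        self.w_in = tf.Variable(init([2 * EMBEDDING_DIM, UNITS]))  # [embedding; context] -> hidden
        self.w_rec = tf.Variable(init([UNITS, UNITS]))             # previous hidden -> hidden
        self.w_out = tf.Variable(init([UNITS, VOCAB_SIZE]))        # hidden -> logits

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name="x"),
        tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32, name="features1"),
        tf.TensorSpec(shape=[1, 512], dtype=tf.float32, name="hidden"),
    ])
    def __call__(self, x, features1, hidden):
        emb = tf.gather(self.embedding, x[:, 0])      # (1, 256)
        context = tf.reduce_mean(features1, axis=1)   # (1, 256)
        step_in = tf.concat([emb, context], axis=-1)  # (1, 512)
        new_hidden = tf.tanh(tf.matmul(step_in, self.w_in) + tf.matmul(hidden, self.w_rec))
        logits = tf.matmul(new_hidden, self.w_out)    # (1, VOCAB_SIZE)
        return {"logits": logits, "state": new_hidden}


module = DecoderStep()
export_dir = tempfile.mkdtemp()

# Explicit serving signature; no example tensors are needed because the spec is fixed.
tf.saved_model.save(
    module, export_dir,
    signatures={"serving_default": module.__call__.get_concrete_function()},
)

# Keep a reference to the loaded object so its variables stay alive while serving.
loaded = tf.saved_model.load(export_dir)
serve = loaded.signatures["serving_default"]
out = serve(
    x=tf.constant([[1]], dtype=tf.int32),
    features1=tf.zeros([1, 64, 256]),
    hidden=tf.zeros([1, 512]),
)

# Convert the same SavedModel to a TFLite flatbuffer.
tflite_model = tf.lite.TFLiteConverter.from_saved_model(export_dir).convert()
```

This does not reproduce or explain the Keras `call()` error; it only shows the save, load and convert path for a module whose inputs are fully specified.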

I couldn't find any documentation explaining why this happens, but I hope this helps someone else as well.

Disclaimer: technical posts on this site are published under the CC BY-SA 4.0 license. If you repost, please credit this site's URL or the original post's address. For any questions, contact: yoyou2525@163.com.