
Loading a saved BertClassifier model

I am using the following example from this colab notebook: https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb

It saves a fine-tuned model using model.save().

I am trying to load that same model using tf.keras.models.load_model() but am getting the following error:

KeyError: 'name'
KeyError                                  Traceback (most recent call last)
in engine
----> 1 model = tf.keras.models.load_model('./saved_model/test')

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
    183     if isinstance(filepath, six.string_types):
    184       loader_impl.parse_saved_model(filepath)
--> 185       return saved_model_load.load(filepath, compile)
    186 
    187   raise IOError(

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load(path, compile)
    114   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
    115   # TODO(kathywu): Add code to load from objects that contain all endpoints
--> 116   model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
    117 
    118   # pylint: disable=protected-access

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py in load_internal(export_dir, tags, loader_cls)
    605       loader = loader_cls(object_graph_proto,
    606                           saved_model_proto,
--> 607                           export_dir)
    608       root = loader.get(0)
    609       if isinstance(loader, Loader):

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in __init__(self, *args, **kwargs)
    186     self._models_to_reconstruct = []
    187 
--> 188     super(KerasObjectLoader, self).__init__(*args, **kwargs)
    189 
    190     # Now that the node object has been fully loaded, and the checkpoint has

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py in __init__(self, object_graph_proto, saved_model_proto, export_dir)
    121       self._concrete_functions[name] = _WrapperFunction(concrete_function)
    122 
--> 123     self._load_all()
    124     self._restore_checkpoint()
    125 

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_all(self)
    207     # loaded from config may create variables / other objects during
    208     # initialization. These are recorded in `_nodes_recreated_from_config`.
--> 209     self._layer_nodes = self._load_layers()
    210 
    211     # Load all other nodes and functions.

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_layers(self)
    307         continue
    308 
--> 309       layers[node_id] = self._load_layer(proto.user_object, node_id)
    310 
    311     for node_id, proto in metric_list:

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_layer(self, proto, node_id)
    333     # Detect whether this object can be revived from the config. If not, then
    334     # revive from the SavedModel instead.
--> 335     obj, setter = self._revive_from_config(proto.identifier, metadata, node_id)
    336     if obj is None:
    337       obj, setter = revive_custom_object(proto.identifier, metadata)

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _revive_from_config(self, identifier, metadata, node_id)
    350     else:
    351       obj = (
--> 352           self._revive_graph_network(metadata, node_id) or
    353           self._revive_layer_from_config(metadata, node_id))
    354 

/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _revive_graph_network(self, metadata, node_id)
    386     else:
    387       model = models_lib.Functional(
--> 388           inputs=[], outputs=[], name=config['name'])
    389 
    390     # Record this model and its layers. This will later be used to reconstruct

KeyError: 'name'

Can someone please advise what I am doing wrong, or whether I should be doing something else to load this model? Thanks!

As the error states, name is not defined. model = tf.keras.models.load_model('./saved_model/test') points to the path of the folder ./saved_model/test, not to the path of the model.

I've documented my experience of that KeyError (and the solution I found) here:

https://github.com/tensorflow/models/issues/8692#issuecomment-727033061

It was caused by a bad implementation of a custom layer's get_config() that did not use the parent's base config. The base config contains the name key, which Keras needs in order to reload the model.
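As a minimal sketch of the pattern described above (not the code from the linked issue), a custom layer can start from the parent's config and add its own constructor arguments to it. The layer name ScaledDense and its units and scale parameters are hypothetical and exist only for illustration; it assumes TensorFlow 2.x is installed.

```python
import tensorflow as tf


class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical custom layer, used only to illustrate get_config()."""

    def __init__(self, units, scale=1.0, **kwargs):
        # Forward name, dtype, trainable, etc. to the base Layer.
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs) * self.scale

    def get_config(self):
        # Start from the parent's config, which already holds 'name',
        # then add this layer's own constructor arguments.
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config


layer = ScaledDense(4, scale=0.5, name="scaled_dense")
config = layer.get_config()

# The default from_config() calls the constructor with the config as kwargs.
restored = ScaledDense.from_config(config)
print(config["name"], restored.units, restored.scale)
```

Returning only the units and scale entries would drop the name key from the config, which is the situation the answer above describes.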

Since the provided notebook is no longer available, I will assume that you were using this (official?) BERT implementation for TensorFlow, since I was having the same issue with it.


TL;DR: You can't load the model because proper get_config methods are missing in the subclassed model and (possibly) this layer. You need to edit the BERT encoder class (if that is an option for you).


I came across this while trying to modify the BERT model and apply Masked Language Modeling training. I found that I was not able to load the checkpoints due to the mentioned KeyError: 'name'. I found three issues:

  1. After a lot of testing I found that several layers were producing a warning during saving. The warning stated: Custom mask layers require a config and must override... I fixed the layers in my implementation, but the problem persisted.
  2. Later I realized that the KeyError: 'name' came from the model's get_config implementation, which was not returning the original tf.keras.Model config along with the new config dict. In this particular case it was not due to a missing layer configuration (as is often the case). I patched the subclassed model's get_config method in my implementation:

def get_config(self):
    # Implement get_config to enable serialization. This is NOT optional.
    base_config = super(LayoutEncoder, self).get_config()
    return dict(list(self._config.items()) + list(base_config.items()))
  3. Solving the KeyError: 'name' (using the code in point 2) produced a new and worse error (also during saving/loading). More than 500 lines of errors/warnings containing WARNING:tensorflow:Unresolved object in checkpoint... and Inconsistent references when loading the checkpoint into this object graph... appeared, and loading still failed. I found that this was caused by the way the BERT model was keeping track of its layers (lines 484-491). These layers were also serialized during save. I'm not sure why these members are added to the encoder class, but by removing them I was able to save and load the model without any further problems.
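The mechanism behind point 2 can be illustrated without TensorFlow. The following is only a sketch: BaseModel, BrokenEncoder, PatchedEncoder and revive are hypothetical stand-ins, not Keras classes or functions. The revive function mimics the name=config['name'] lookup visible in the traceback, and the patched class merges the two dicts the same way as the get_config shown in point 2.

```python
class BaseModel:
    """Hypothetical stand-in for tf.keras.Model: its config carries 'name'."""

    def __init__(self, name):
        self.name = name

    def get_config(self):
        return {"name": self.name}


class BrokenEncoder(BaseModel):
    """Returns only its own settings, so the 'name' key is lost."""

    def __init__(self, name, hidden_size):
        super().__init__(name)
        self._config = {"hidden_size": hidden_size}

    def get_config(self):
        return dict(self._config)


class PatchedEncoder(BrokenEncoder):
    """Merges its own settings with the parent's base config."""

    def get_config(self):
        base_config = BaseModel.get_config(self)
        return dict(list(self._config.items()) + list(base_config.items()))


def revive(config):
    # Mimics the lookup seen in the traceback: name=config['name'].
    return "revived model named " + config["name"]


try:
    revive(BrokenEncoder("encoder", 768).get_config())
except KeyError as err:
    print("KeyError:", err)

print(revive(PatchedEncoder("encoder", 768).get_config()))
```

The broken variant fails at the same lookup as the real loader, while the patched variant succeeds because the merged dict contains both the encoder settings and the base name entry.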

Sorry for the long story but, imho, this was the only way to explain it...
