Loading a saved BertClassifier model
I am using the following example from this colab notebook: https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb
It saves a fine-tuned model using the model.save() functionality.
I am trying to load that same model using tf.keras.models.load_model() but am getting the following error:
KeyError: 'name'
KeyError Traceback (most recent call last)
in engine
----> 1 model = tf.keras.models.load_model('./saved_model/test')
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
183 if isinstance(filepath, six.string_types):
184 loader_impl.parse_saved_model(filepath)
--> 185 return saved_model_load.load(filepath, compile)
186
187 raise IOError(
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load(path, compile)
114 # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
115 # TODO(kathywu): Add code to load from objects that contain all endpoints
--> 116 model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
117
118 # pylint: disable=protected-access
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py in load_internal(export_dir, tags, loader_cls)
605 loader = loader_cls(object_graph_proto,
606 saved_model_proto,
--> 607 export_dir)
608 root = loader.get(0)
609 if isinstance(loader, Loader):
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in __init__(self, *args, **kwargs)
186 self._models_to_reconstruct = []
187
--> 188 super(KerasObjectLoader, self).__init__(*args, **kwargs)
189
190 # Now that the node object has been fully loaded, and the checkpoint has
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py in __init__(self, object_graph_proto, saved_model_proto, export_dir)
121 self._concrete_functions[name] = _WrapperFunction(concrete_function)
122
--> 123 self._load_all()
124 self._restore_checkpoint()
125
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_all(self)
207 # loaded from config may create variables / other objects during
208 # initialization. These are recorded in `_nodes_recreated_from_config`.
--> 209 self._layer_nodes = self._load_layers()
210
211 # Load all other nodes and functions.
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_layers(self)
307 continue
308
--> 309 layers[node_id] = self._load_layer(proto.user_object, node_id)
310
311 for node_id, proto in metric_list:
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_layer(self, proto, node_id)
333 # Detect whether this object can be revived from the config. If not, then
334 # revive from the SavedModel instead.
--> 335 obj, setter = self._revive_from_config(proto.identifier, metadata, node_id)
336 if obj is None:
337 obj, setter = revive_custom_object(proto.identifier, metadata)
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _revive_from_config(self, identifier, metadata, node_id)
350 else:
351 obj = (
--> 352 self._revive_graph_network(metadata, node_id) or
353 self._revive_layer_from_config(metadata, node_id))
354
/home/cdsw/.local/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _revive_graph_network(self, metadata, node_id)
386 else:
387 model = models_lib.Functional(
--> 388 inputs=[], outputs=[], name=config['name'])
389
390 # Record this model and its layers. This will later be used to reconstruct
KeyError: 'name'
Can someone please advise what I am doing wrong, or whether I should be doing something else to load this model? Thanks!
As the error states, name is not defined. model = tf.keras.models.load_model('./saved_model/test') points to the path of the folder ./saved_model/test, not the path of the model.
I've documented my experience of that KeyError (and the solution I found) here:
https://github.com/tensorflow/models/issues/8692#issuecomment-727033061
It was due to a bad implementation of a custom layer's get_config() that didn't use the parent's base config. The base config is what contains the name key, which Keras needs when reloading.
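To make the mechanism concrete, here is a minimal sketch in plain Python. It does not use TensorFlow: BaseLayer, BrokenLayer, FixedLayer and revive are hypothetical stand-ins I made up to mimic the pattern, and revive only imitates the config['name'] lookup visible at the bottom of the traceback. The real Keras loader does considerably more.

```python
# Stand-in for a Keras-like base class. This is NOT tf.keras.layers.Layer,
# just an assumption-laden imitation of the part that matters here.
class BaseLayer:
    def __init__(self, name=None):
        self.name = name or self.__class__.__name__.lower()

    def get_config(self):
        # The base config carries the 'name' key that reloading relies on.
        return {"name": self.name, "trainable": True}


class BrokenLayer(BaseLayer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        # Bug: the parent's config is discarded, so 'name' is lost.
        return {"units": self.units}


class FixedLayer(BaseLayer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        # Start from the parent's config, then add this layer's own arguments.
        config = super().get_config()
        config.update({"units": self.units})
        return config


def revive(config):
    # Hypothetical loader step: looks up config['name'] like the traceback does.
    return "revived " + config["name"]


broken = BrokenLayer(units=4, name="head")
fixed = FixedLayer(units=4, name="head")

try:
    revive(broken.get_config())
    broken_error = None
except KeyError as err:
    broken_error = repr(err)

fixed_result = revive(fixed.get_config())
```

With a real custom Keras layer the fix should have the same shape: call super().get_config() first and add the layer's own constructor arguments to the returned dict, instead of building a fresh dict from scratch.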
Since the provided notebook is no longer available, I will assume that you were using this (official?) BERT implementation for TensorFlow, since I was having the same issue with it.
TL;DR: You can't load the model because the subclassed model and (possibly) this layer lack proper get_config methods. You need to edit the BERT encoder class (if this is a possibility for you).
I came across this while trying to modify the BERT model and apply Masked Language Modeling training. I found that I was not able to load the checkpoints due to the mentioned KeyError: 'name'. I found three issues:
1. After a lot of testing, I found several layers that produced warnings during saving, stating Custom mask layers require a config and must override... I fixed the layers in my implementation, however the problem persisted.
2. Later I realized that the KeyError: 'name' came from the model's get_config implementation, which was not returning the original tf.keras.Model config along with the new config dict. In this particular case it was not due to missing layer configuration (as is often the case). I patched the subclassed model's get_config method in my implementation:
def get_config(self):
    # Implement get_config to enable serialization. This is NOT optional.
    base_config = super(LayoutEncoder, self).get_config()
    return dict(list(self._config.items()) + list(base_config.items()))
3. Solving the KeyError: 'name' (using the code in point 2) created a new and more horrible error (also during loading/saving). More than 500 lines of errors/warnings containing WARNING:tensorflow:Unresolved object in checkpoint... and Inconsistent references when loading the checkpoint into this object graph... appeared and the loading still failed. I found that this was caused by the way that the BERT model was keeping track of the layers (lines 484-491). These layers were also serialized during save. I'm not sure why these members are added to the encoder class, however by removing them I was able to save and load the model without any other problem.
Sorry for the long story but, imho, this was the only way to explain it...
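One detail about the get_config patch in point 2 that may be worth knowing: when the two dicts are concatenated as item lists, a key present in both takes its value from the dict listed last, i.e. the base config. A small sketch in plain Python, where own_config and base_config are hypothetical stand-ins for self._config and the parent's config (the keys and values are made up for illustration):

```python
# Hypothetical stand-ins for the two dictionaries merged in get_config.
own_config = {"vocab_size": 30522, "num_layers": 12, "name": "stale_name"}
base_config = {"name": "layout_encoder", "trainable": True}

# Same merge expression as in the patch: later items win on duplicate keys.
merged = dict(list(own_config.items()) + list(base_config.items()))

# Equivalent, shorter spelling using dict unpacking.
same = {**own_config, **base_config}
```

Here merged["name"] ends up as the base config's value, which is probably what you want for reloading. If the subclass's own values should take priority instead, swap the order of the two dicts.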