![](/img/trans.png)
[英]tf.keras.models.model vs tf.keras.model
[英]How to copy a tf.keras.models.Model subclass?
I need to copy a Keras model, but I know of no way to do this when the model is a `tf.keras.models.Model` subclass.
注意:使用 `copy.deepcopy()` 本身可以正常运行而不会报错,但之后只要使用该副本,就会导致另一个错误。
例子:
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Two Dense layers with optional dropout, built via the subclassing API.

    Used here to reproduce the error raised when a subclassed model is
    passed to ``tf.keras.models.clone_model``.
    """

    def __init__(self):
        super().__init__()
        # Layers are assigned in this order so weight tracking order is stable.
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
        self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        hidden = self.dense1(inputs)
        # Dropout is only inserted on the training path.
        hidden = self.dropout(hidden, training=training) if training else hidden
        return self.dense2(hidden)
# Reproduction of the problem: tf.keras.models.clone_model accepts only
# functional or Sequential models, so calling it on the subclassed MyModel
# raises the ValueError shown in the traceback below.
if __name__ == '__main__':
    model1 = MyModel()
    model2 = tf.keras.models.clone_model(model1)
结果是:
Traceback (most recent call last):
File "/Users/emadboctor/Library/Application Support/JetBrains/PyCharm2020.3/scratches/scratch.py", line 600, in <module>
model2 = tf.keras.models.clone_model(model1)
File "/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/models.py", line 430, in clone_model
return _clone_functional_model(
File "/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/models.py", line 171, in _clone_functional_model
raise ValueError('Expected `model` argument '
ValueError: Expected `model` argument to be a functional `Model` instance, but got a subclass model instead.
Currently, `tf.keras.models.clone_model` cannot be used with models built via the subclassing API, whereas it works for models built with the Sequential and functional APIs. 从文档:
> `model`: Instance of `Model` (could be a functional model or a Sequential model).
这是您需要的解决方法。复制模型通常是为了保留一个已经训练好的 model 及其优化后的参数。因此,主要任务是通过复制现有的 model 来创建新的 model。目前这种场景最方便的方法是 `get` 训练好的权重,再将其 `set` 到新创建的 model 实例上。让我们首先构建一个 model 并对其进行训练,然后将其权重矩阵设置到新的 model 上。
import tensorflow as tf
import numpy as np
class ModelSubClassing(tf.keras.Model):
    """Conv2D -> GlobalAveragePooling2D -> Dense classifier (subclassing API).

    Args:
        num_classes: width of the final Dense layer. That layer has no
            activation, so the model outputs unnormalised scores (logits).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.dense = tf.keras.layers.Dense(num_classes)

    def call(self, input_tensor, training=False):
        # `training` is accepted for the Keras call signature; none of these
        # layers change behaviour between training and inference.
        pooled = self.gap(self.conv1(input_tensor))
        return self.dense(pooled)

    def build_graph(self, raw_shape):
        """Wrap the forward pass in a functional Model (e.g. for plot_model).

        Args:
            raw_shape: per-sample input shape, without the batch axis.
        """
        graph_input = tf.keras.layers.Input(shape=raw_shape)
        graph_output = self.call(graph_input)
        return tf.keras.Model(inputs=[graph_input], outputs=graph_output)
# compile
sub_classing_model = ModelSubClassing(10)
sub_classing_model.compile(
    # ModelSubClassing's final Dense layer has no activation, so it emits
    # logits; the loss must be told so instead of assuming probabilities.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    # Keras documents `metrics` as a list of metrics.
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(),
)

# plot for debug
# x_train is only loaded further below, so use the per-sample shape it will
# have after np.expand_dims(..., axis=-1): 28x28 MNIST images, 1 channel.
tf.keras.utils.plot_model(
    sub_classing_model.build_graph((28, 28, 1)),
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    expand_nested=False,
    dpi=96,
)
数据集
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# Inputs: add a trailing channel axis, cast to float32 and scale by 1/255.
x_train = x_train[..., np.newaxis].astype('float32') / 255

# Targets: integer labels -> one-hot vectors over 10 classes.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)

# Single training epoch.
sub_classing_model.fit(x_train, y_train, batch_size=128, epochs=1)
# Example log from one run:
# 469/469 [==============================] - 2s 2ms/step - loss: 8.2821
全新 Model / 副本
对于子类化的 model,我们必须重新实例化该 class 对象。
# A subclassed model cannot be cloned, so create a fresh instance of the same
# class and transfer the trained weights into it.
sub_classing_model_copy = ModelSubClassing(10)
# Build with a batch-agnostic shape, i.e. (None,) + per-sample shape. Passing
# x_train.shape would record the whole dataset size as a fixed batch size.
sub_classing_model_copy.build((None,) + x_train.shape[1:])
# Same class and same layer creation order, so the weight lists line up.
sub_classing_model_copy.set_weights(sub_classing_model.get_weights())  # <- get and set wg

# plot for debug ; same graph as the original plot,
# but now the layer names are no longer the same,
# e.g. old: conv2d_40 , new/copy: conv2d_41
tf.keras.utils.plot_model(
    sub_classing_model_copy.build_graph(x_train.shape[1:]),
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    expand_nested=False,
    dpi=96,
)
def clones(module, N):
    """Create N identical layers stacked in a Sequential model.

    :param module: layer/module to clone; each copy is made with copy.deepcopy.
    :param N: number of copies.
    :return: keras Sequential model holding the N copies in order, with each
        copy's name suffixed by its index (0 .. N-1).
    """
    import copy  # the visible file never imports `copy` at module level

    # NOTE(review): KM is not defined in the visible source; presumably it is
    # the Keras `models` module (e.g. `tf.keras.models`) — confirm.
    seqm = KM.Sequential()
    for i in range(N):
        m = copy.deepcopy(module)
        new_name = m.name + str(i)
        try:
            m.name = new_name
        except AttributeError:
            # In tf.keras, Layer.name is a read-only property backed by `_name`.
            m._name = new_name
        seqm.add(m)
    return seqm
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.