[英]How do I copy specific layer weights from pretrained models using Tensorflow Keras api?
[英]Upgraded to Tensorflow 2.5 now get a Lambda Layer error when using pretrained Keras Applications Models
我按照本教程为我的问题构建了一个连体网络。 我使用的是 Tensorflow 2.4.1 现在升级了
这段代码以前工作得很好
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
现在每个 resnet 层或 mobilenet 或高效网络(都尝试过)都会抛出这些错误:
WARNING:tensorflow:
The following Variables were used a Lambda layer's call (tf.nn.convolution_620), but
are not present in its tracked objects:
<tf.Variable 'stem_conv/kernel:0' shape=(3, 3, 3, 48) dtype=float32>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.
它编译并且似乎适合。
但是我们必须在 2.5 中以不同的方式初始化模型吗?
感谢您的任何指点!
此处无需恢复到TF2.4.1
。 我总是建议尝试使用最新版本,因为它解决了许多性能问题和新功能。
我能够在TF2.5
中执行上述代码而没有任何问题。
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, Model
img_width, img_height = 224, 224
target_shape = (img_width, img_height, 3)
base_cnn = ResNet50(
weights="imagenet", input_shape=target_shape, include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
Output:
2.5.0
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 1s 0us/step
根据@Olli,重新启动并清除 session kernel 已解决问题。
我不确定您的问题的主要原因是什么,因为它通常无法重现。 但这里有一些关于该警告信息的说明。 您的问题中显示的回溯不是来自ResNet
而是来自EfficientNet
。
现在,我们知道存在Lambda
层,因此在构建顺序和功能API 模型时,可以使用任意表达式作为Layer
。 Lambda
层最适合简单操作或快速实验。 虽然可以将变量与 Lambda 层一起使用,但不鼓励这种做法,因为它很容易导致错误。 例如:
import tensorflow as tf
x_input = tf.range(12.).numpy().reshape(-1, 4)
weights = tf.Variable(tf.random.normal((4, 2)), name='w')
bias = tf.ones((1, 2), name='b')
# lambda custom layer
mylayer1 = tf.keras.layers.Lambda(lambda x: tf.add(tf.matmul(x, weights),
bias), name='lambda1')
mylayer1(x_input)
WARNING:tensorflow:
The following Variables were used a Lambda layer's call (lambda1), but
are not present in its tracked objects:
<tf.Variable 'w:0' shape=(4, 2) dtype=float32, numpy=
array([[-0.753139 , -1.1668463 ],
[-1.3709341 , 0.8887151 ],
[ 0.3157893 , 0.01245957],
[-1.3878908 , -0.38395467]], dtype=float32)>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[ -3.903028 , 0.7617702],
[-16.687727 , -1.8367348],
[-29.472424 , -4.43524 ]], dtype=float32)>
这是因为mylayer1
层不直接跟踪tf.Variables
,因此这些参数不会出现在mylayer1.trainable_weights
中。
mylayer1.trainable_weights
[]
一般来说, Lambda
层可以方便简单的无状态计算,但更复杂的应该使用子类层来代替。 从您的回溯来看, step_conv
层似乎可能存在这种情况。
for layer in EfficientNetB0(weights=None).layers:
if layer.name == 'stem_conv':
print(layer)
<tensorflow.python.keras.layers.convolutional.Conv2D object..
快速调查tf.compat.v1.nn.conv2d的源代码,导致可能是原因的lambda 表达式。
pip 安装 tensorflow==2.3.0,为我工作而不是 tf 2.5 我遇到了与使用 Lambda 层相关的问题
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.