![](/img/trans.png)
[英]How do I copy specific layer weights from pretrained models using Tensorflow Keras api?
[英]Upgraded to Tensorflow 2.5 now get a Lambda Layer error when using pretrained Keras Applications Models
我按照本教程為我的問題構建了一個連體網絡。 我使用的是 Tensorflow 2.4.1 現在升級了
這段代碼以前工作得很好
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
現在每個 resnet 層或 mobilenet 或高效網絡(都嘗試過)都會拋出這些錯誤:
WARNING:tensorflow:
The following Variables were used a Lambda layer's call (tf.nn.convolution_620), but
are not present in its tracked objects:
<tf.Variable 'stem_conv/kernel:0' shape=(3, 3, 3, 48) dtype=float32>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.
它編譯並且似乎適合。
但是我們必須在 2.5 中以不同的方式初始化模型嗎?
感謝您的任何指點!
此處無需恢復到TF2.4.1
。 我總是建議嘗試使用最新版本,因為它解決了許多性能問題和新功能。
我能夠在TF2.5
中執行上述代碼而沒有任何問題。
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, Model
img_width, img_height = 224, 224
target_shape = (img_width, img_height, 3)
base_cnn = ResNet50(
weights="imagenet", input_shape=target_shape, include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
Output:
2.5.0
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 1s 0us/step
根據@Olli,重新啟動並清除 session kernel 已解決問題。
我不確定您的問題的主要原因是什么,因為它通常無法重現。 但這里有一些關於該警告信息的說明。 您的問題中顯示的回溯不是來自ResNet
而是來自EfficientNet
。
現在,我們知道存在Lambda
層,因此在構建順序和功能API 模型時,可以使用任意表達式作為Layer
。 Lambda
層最適合簡單操作或快速實驗。 雖然可以將變量與 Lambda 層一起使用,但不鼓勵這種做法,因為它很容易導致錯誤。 例如:
import tensorflow as tf
x_input = tf.range(12.).numpy().reshape(-1, 4)
weights = tf.Variable(tf.random.normal((4, 2)), name='w')
bias = tf.ones((1, 2), name='b')
# lambda custom layer
mylayer1 = tf.keras.layers.Lambda(lambda x: tf.add(tf.matmul(x, weights),
bias), name='lambda1')
mylayer1(x_input)
WARNING:tensorflow:
The following Variables were used a Lambda layer's call (lambda1), but
are not present in its tracked objects:
<tf.Variable 'w:0' shape=(4, 2) dtype=float32, numpy=
array([[-0.753139 , -1.1668463 ],
[-1.3709341 , 0.8887151 ],
[ 0.3157893 , 0.01245957],
[-1.3878908 , -0.38395467]], dtype=float32)>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[ -3.903028 , 0.7617702],
[-16.687727 , -1.8367348],
[-29.472424 , -4.43524 ]], dtype=float32)>
這是因為mylayer1
層不直接跟蹤tf.Variables
,因此這些參數不會出現在mylayer1.trainable_weights
中。
mylayer1.trainable_weights
[]
一般來說, Lambda
層可以方便簡單的無狀態計算,但更復雜的應該使用子類層來代替。 從您的回溯來看, step_conv
層似乎可能存在這種情況。
for layer in EfficientNetB0(weights=None).layers:
if layer.name == 'stem_conv':
print(layer)
<tensorflow.python.keras.layers.convolutional.Conv2D object..
快速調查tf.compat.v1.nn.conv2d的源代碼,導致可能是原因的lambda 表達式。
pip 安裝 tensorflow==2.3.0,為我工作而不是 tf 2.5 我遇到了與使用 Lambda 層相關的問題
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.