[英]Tensorflow 2 variable not trainable
我在 tf2 中創建了一個簡單的模型,它將輸入“a”乘以變量“b”(初始化為 1)並返回輸出“c”。 然后我嘗試在簡單的數據集 a=1, c=5 上訓練它。 我希望它學習 b=5。
import tensorflow as tf
from tensorflow.keras.models import Model
a = Input(shape=(1,))
b = tf.Variable(1., trainable=True)
c = a*b
model = Model(a,c)
loss = tf.keras.losses.MeanAbsoluteError()
model.compile(optimizer='adam', loss=loss)
model.fit([1.],[5.],batch_size=1, epochs=1)
但是,tf2 不認為變量 'b' 是可訓練的。 摘要顯示沒有可訓練的參數。
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 1)] 0
_________________________________________________________________
tf_op_layer_mul (TensorFlowO [(None, 1)] 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
為什么變量“b”沒有訓練?
Keras 模型是Layer類的包裝器。 您必須將此變量包裝為 keras 層,以便將其顯示為模型中的可訓練參數。
你可以像這樣創建一個小的自定義層:
class MyLayer(tf.keras.layers.Layer):
def __init__(self):
super(MyLayer, self).__init__()
#your variable goes here
self.variable = tf.Variable(1., trainable=True, dtype=tf.float64)
def call(self, inputs, **kwargs):
# your mul operation goes here
x = inputs * self.variable
return x
這里call
方法會做乘法運算。 我們可以像使用 out 模型中的任何其他層一樣使用此層。 在這里,我創建了一個Sequential模型,將 aboce 乘法運算添加為模型層。
model = tf.keras.models.Sequential()
mylayer_object = MyLayer()
model.add(mylayer_object)
loss = tf.keras.losses.MeanAbsoluteError()
model.compile("adam", loss)
model.fit([1.],[5.],batch_size=1, epochs=1)
model.summary()
'''
Train on 1 samples
1/1 [==============================] - 0s 426ms/sample - loss: 4.0000
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
my_layer (MyLayer) multiple 1
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
'''
在此之后,如果您可以列出模型的可訓練參數。
print(model.trainable_variables)
# [<tf.Variable 'Variable:0' shape=() dtype=float64, numpy=1.0009999968852092>]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.