[英]How do I set custom weights for my sequential model?
我想將我的 model 的權重設置為隨機正態分布中的非常大的數字。 這是我目前的解決方案:
weights = tf.keras.initializers.random_normal()
weights = weights(shape=(2, 5)).numpy() * 100
model = tf.keras.Sequential([
tf.keras.layers.Dense(5, activation="tanh", input_shape=(X_train.shape[1],), kernel_initializer=weights),
tf.keras.layers.Dense(2, activation="tanh"),
tf.keras.layers.Dense(2, activation="tanh"),
tf.keras.layers.Dense(1, activation="sigmoid")
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
loss="mse",
metrics=["accuracy"])
history = model.fit(X_train, y_train, epochs=100, validation_data=[X_test, y_test])
這導致以下 output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-145-2307b7a2c402> in <module>()
3
4 model = tf.keras.Sequential([
----> 5 tf.keras.layers.Dense(5, activation="tanh", input_shape=(X_train.shape[1],), kernel_initializer=weights),
6 tf.keras.layers.Dense(2, activation="tanh"),
7 tf.keras.layers.Dense(2, activation="tanh"),
1 frames
/usr/local/lib/python3.7/dist-packages/keras/initializers/__init__.py in get(identifier)
191 else:
192 raise ValueError('Could not interpret initializer identifier: ' +
--> 193 str(identifier))
ValueError: Could not interpret initializer identifier: [[ 1.8304478 -1.3845474 -2.438812 -7.1097493 6.8744435 ]
[ 3.2775316 0.75484884 -0.7150349 1.852715 -8.842371 ]]
嘗試將其用於Keras
層時,使用tf.keras.initializers.random_normal()
將不起作用。 例如,查看此處的文檔。 此外,您不應該事先對權重的形狀進行硬編碼。 它將根據 model 的輸入進行推斷。 你可以嘗試這樣的事情:
import tensorflow as tf
def random_normal_init(shape, dtype=None):
return tf.random.normal(shape) * 100
model = tf.keras.Sequential([
tf.keras.layers.Dense(5, activation="tanh", input_shape=(5,), kernel_initializer=random_normal_init),
tf.keras.layers.Dense(2, activation="tanh"),
tf.keras.layers.Dense(2, activation="tanh"),
tf.keras.layers.Dense(1, activation="sigmoid")
])
samples = 20
print(model(tf.random.normal((samples, 5))))
tf.Tensor(
[[0.2567306 ]
[0.79331714]
[0.74326944]
[0.35187328]
[0.18808913]
[0.81191087]
[0.6069946 ]
[0.74326944]
[0.65107304]
[0.39300534]
[0.6069946 ]
[0.81191087]
[0.61664075]
[0.35496145]
[0.81191087]
[0.2567306 ]
[0.38335925]
[0.2567306 ]
[0.50955486]
[0.74326944]], shape=(20, 1), dtype=float32)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.