How do I set custom weights for my sequential model?
I want to set my model's weights to very large numbers drawn from a random normal distribution. This is my current solution:
weights = tf.keras.initializers.random_normal()
weights = weights(shape=(2, 5)).numpy() * 100
model = tf.keras.Sequential([
    tf.keras.layers.Dense(5, activation="tanh", input_shape=(X_train.shape[1],), kernel_initializer=weights),
    tf.keras.layers.Dense(2, activation="tanh"),
    tf.keras.layers.Dense(2, activation="tanh"),
    tf.keras.layers.Dense(1, activation="sigmoid")
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
              loss="mse",
              metrics=["accuracy"])
history = model.fit(X_train, y_train, epochs=100, validation_data=[X_test, y_test])
This produces the following output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-145-2307b7a2c402> in <module>()
3
4 model = tf.keras.Sequential([
----> 5 tf.keras.layers.Dense(5, activation="tanh", input_shape=(X_train.shape[1],), kernel_initializer=weights),
6 tf.keras.layers.Dense(2, activation="tanh"),
7 tf.keras.layers.Dense(2, activation="tanh"),
1 frames
/usr/local/lib/python3.7/dist-packages/keras/initializers/__init__.py in get(identifier)
191 else:
192 raise ValueError('Could not interpret initializer identifier: ' +
--> 193 str(identifier))
ValueError: Could not interpret initializer identifier: [[ 1.8304478 -1.3845474 -2.438812 -7.1097493 6.8744435 ]
[ 3.2775316 0.75484884 -0.7150349 1.852715 -8.842371 ]]
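Using tf.keras.initializers.random_normal() this way will not work for a Keras layer. Calling the initializer and then .numpy() gives you a plain array of values, whereas kernel_initializer expects an initializer: a string identifier, an initializer instance, or a callable that takes shape and dtype. See the Keras initializers documentation for examples. Also, you should not hard-code the shape of your weights beforehand; it is inferred from the input to your model. You could try something like this: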
import tensorflow as tf

def random_normal_init(shape, dtype=None):
    # Keras calls this with the kernel shape it infers for the layer.
    return tf.random.normal(shape) * 100

model = tf.keras.Sequential([
    tf.keras.layers.Dense(5, activation="tanh", input_shape=(5,), kernel_initializer=random_normal_init),
    tf.keras.layers.Dense(2, activation="tanh"),
    tf.keras.layers.Dense(2, activation="tanh"),
    tf.keras.layers.Dense(1, activation="sigmoid")
])

samples = 20
print(model(tf.random.normal((samples, 5))))
tf.Tensor(
[[0.2567306 ]
[0.79331714]
[0.74326944]
[0.35187328]
[0.18808913]
[0.81191087]
[0.6069946 ]
[0.74326944]
[0.65107304]
[0.39300534]
[0.6069946 ]
[0.81191087]
[0.61664075]
[0.35496145]
[0.81191087]
[0.2567306 ]
[0.38335925]
[0.2567306 ]
[0.50955486]
[0.74326944]], shape=(20, 1), dtype=float32)
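If you would rather not write a custom function, two other routes may be worth a look. This is only a sketch under a few assumptions: n_features = 3 is a made-up stand-in for X_train.shape[1], the seeds are arbitrary, and stddev=5.0 is my guess at the scale the question was aiming for (the default stddev of random_normal is 0.05, so multiplying by 100 gives roughly 5, while tf.random.normal(shape) * 100 gives a stddev of 100). The first option passes a built-in initializer with an explicit stddev; the second overwrites the weights of an already-built layer with set_weights.

```python
import numpy as np
import tensorflow as tf

# Assumption: 3 input features; use X_train.shape[1] with real data.
n_features = 3

# Option 1: built-in initializer with an explicit standard deviation.
# stddev=5.0 corresponds to the default 0.05 scaled by 100.
big_normal = tf.keras.initializers.RandomNormal(mean=0.0, stddev=5.0, seed=0)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(n_features,)),
    tf.keras.layers.Dense(5, activation="tanh", kernel_initializer=big_normal),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Option 2: assign explicit arrays to a layer that has already been built.
first = model.layers[0]
kernel_shape = first.get_weights()[0].shape  # read the inferred shape, (n_features, 5)
rng = np.random.default_rng(0)
new_kernel = rng.normal(0.0, 5.0, size=kernel_shape).astype("float32")
new_bias = np.zeros(5, dtype="float32")
first.set_weights([new_kernel, new_bias])  # order: kernel, then bias
```

With set_weights the arrays have to match the layer's existing weight shapes exactly, which is why the sketch reads the shape from the layer instead of hard-coding it.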