![](/img/trans.png)
[英]How to save the structure and weights of trained tensorflow model?
[英]Reuse trained weights in TensorFlow model without reinitialization
我有一個 TensorFlow 模型,大致如下:
class MyModel():
def predict(self, x):
with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
W_1 = tf.get_variable("weight", shape=[64,1], dtype=tf.float64)
b_1 = tf.get_variable("bias", shape=[1], dtype=tf.float64)
y_hat = tf.matmul(x, W_1) + b_1
return y_hat
def train_step(self, x, y):
with tf.variable_scope("optimization"):
y_hat = self.predict(x)
loss = tf.losses.mean_squared_error(y, y_hat)
optimizer = tf.train.AdamOptimizer()
train_step = optimizer.minimize(loss)
return train_step
def __call__(self, x):
return self.predict(x)
我可以像my_model = MyModel()
那樣實例化模型,然后使用sess.run(my_model.train_step(x, y))
訓練它,但是如果我想在訓練后預測不同的張量,如sess.run(my_model.predict(x_new))
,我得到一個FailedPreconditionError
。
似乎對象的__call__
函數沒有按預期重用權重,而是向圖中添加了新的權重,然后未初始化。 有沒有辦法避免這種行為?
約定是將權重定義為網絡的屬性,而不是在預測函數內部,優化器和 train_step 的注釋相同。 也許它會有所幫助,因為train_step = optimizer.minimize(loss)
查看整個圖。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.