
Reuse trained weights in TensorFlow model without reinitialization

I have a TensorFlow model that looks roughly like:

class MyModel():

    def predict(self, x):
        with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
            W_1 = tf.get_variable("weight", shape=[64, 1], dtype=tf.float64)
            b_1 = tf.get_variable("bias", shape=[1], dtype=tf.float64)
            y_hat = tf.matmul(x, W_1) + b_1
        return y_hat

    def train_step(self, x, y):
        with tf.variable_scope("optimization"):
            y_hat = self.predict(x)
            loss = tf.losses.mean_squared_error(y, y_hat)
            optimizer = tf.train.AdamOptimizer()
            train_step = optimizer.minimize(loss)
        return train_step

    def __call__(self, x):
        return self.predict(x)

I can instantiate the model with my_model = MyModel() and then train it using sess.run(my_model.train_step(x, y)), but if I then want to predict on a different tensor after training, e.g. sess.run(my_model.predict(x_new)), I get a FailedPreconditionError.

It seems like the object's __call__ function is not reusing the weights as intended, but instead adds new weights to the graph, which are then uninitialized. Is there a way to avoid this behaviour?
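One way to check whether that is what is happening (a diagnostic sketch, not part of the original question; it assumes the same graph and an open session named sess as above) is to list the variables in the graph and ask the session which ones are still uninitialized:

import tensorflow as tf  # assumes TF 1.x graph mode

# Print every variable currently in the graph; duplicated "prediction/weight"
# entries would mean predict() created fresh weights instead of reusing them.
for v in tf.global_variables():
    print(v.name, v.shape)

# Ask the session which variables it has not initialized yet; running any op
# that depends on one of these raises FailedPreconditionError.
print(sess.run(tf.report_uninitialized_variables()))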

The convention is to define the weights as attributes of the network rather than inside the predict function; the same goes for the optimizer and the train step. That may help here, because train_step = optimizer.minimize(loss) looks at the whole graph.
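A minimal sketch of that layout, assuming the same TF 1.x graph-mode API, shapes and dtypes as the question (under TF 2.x the same calls exist via tf.compat.v1 with eager execution disabled):

import tensorflow as tf  # TF 1.x graph mode

class MyModel():

    def __init__(self):
        # Create the variables and the optimizer once, at construction time,
        # so every later call only adds ops that reuse them.
        with tf.variable_scope("prediction"):
            self.W_1 = tf.get_variable("weight", shape=[64, 1], dtype=tf.float64)
            self.b_1 = tf.get_variable("bias", shape=[1], dtype=tf.float64)
        self.optimizer = tf.train.AdamOptimizer()

    def predict(self, x):
        # No new variables here, only ops that use the existing attributes.
        return tf.matmul(x, self.W_1) + self.b_1

    def train_step(self, x, y):
        loss = tf.losses.mean_squared_error(y, self.predict(x))
        return self.optimizer.minimize(loss)

    def __call__(self, x):
        return self.predict(x)

With this layout, build every op you need (the train op and the prediction on x_new) before running tf.global_variables_initializer(), so the initializer also sees the slot variables Adam creates, and only then call sess.run for training and prediction.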
