简体   繁体   English

tensorflow 中是否有类似的 function,如 Pytorch 中的 load_state_dict()?

[英]Is there a similar function in tensorflow like load_state_dict() in Pytorch?

Like it has been described, I am wondering is there a similar function in tensorflow for load_state_dict() like the one does in Pytorch.就像已经描述的那样,我想知道 tensorflow 中是否有类似的 function 用于 load_state_dict(),就像 Pytorch 中的一样。 To demonstrate a scenario, please refer to the code following:为了演示一个场景,请参考以下代码:

# Suppose we have two correctly initialized neural networks: net2 and net1
# Using Pytorch
net2.load_state_dict(net1.state_dict())

Does anyone have any idea?有人有什么主意吗?

Below code may help in achieveing the same in tensorflow:下面的代码可能有助于在 tensorflow 中实现相同的目标:

Save the model保存 model

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta


To Restore the model恢复 model

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM