簡體   English   中英

如何在Tensorflow中為當前模型恢復預訓練的檢查點?

[英]How to restore pretrained checkpoint for current model in Tensorflow?

我有一個訓練有素的檢查站。 現在,我正在嘗試將此預訓練模型恢復到當前網絡。 但是,變量名不同。 Tensorflow文檔說使用字典像:

v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2": v2})

但是,當前網絡中的變量定義如下:

with tf.variable_scope('a'):
    b=tf.get_variable(......)

因此,變量名稱似乎是a/b 如何使字典像"v2": a/b

您可以使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)獲取當前圖中所有變量名稱的列表。 您還可以指定范圍。

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')

您可以使用tf.train.list_variables(ckpt_file)獲取檢查點中所有變量的列表。

假設您的檢查點中有變量b,並且您想在名稱為a/b tf.variable_scope('a')內部加載。 為此,您只需定義它

with tf.variable_scope('a'):
    b=tf.get_variable(......)

和負載

saver = tf.train.Saver({'v2': b})

with tf.Session() as sess:
    saver.restore(sess, ckpt_file))
    print(b)

這將輸出

<tf.Variable 'a/b:0' shape dtype>

編輯:如前所述,您可以使用

vars_dict = {}
for var_current in tf.global_variables():
    print(var_current)
    print(var_current.op.name) # this gets only name

for var_ckpt in tf.train.list_variables(ckpt):
    print(var_ckpt[0]) this gets only name

當您知道所有變量的確切名稱時,您可以分配所需的任何值,只要變量具有相同的形狀和dtype即可

vars_dict[var_ckpt[0]) = tf.get_variable(var_current.op.name, shape) # remember to specify shape, you can always get it from var_current 

您可以顯式地或以您認為合適的任何循環構造此字典。 然后將其傳遞給保護程序

saver = tf.train.Saver(vars_dict)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM