![](/img/trans.png)
[英]tf.nn.rnn_cell.MultiRNNCell creates variable shape mismatch?
[英]Tensorflow: How to get all variables from rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell
我有一個設置,我需要在主要初始化后使用tf.initialize_all_variables()
初始化LSTM。 即我想調用tf.initialize_variables([var_list])
有沒有辦法收集所有內部可訓練變量:
這樣我可以初始化JUST這些參數嗎?
我想要這個的主要原因是因為我不想重新初始化一些訓練過的值。
解決問題的最簡單方法是使用變量范圍。 范圍內變量的名稱將以其名稱為前綴。 這是一個簡短的片段:
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
# Retrieve just the LSTM variables.
lstm_variables = [v for v in tf.all_variables()
if v.name.startswith(vs.name)]
# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)
它與MultiRNNCell
工作方式相同。
編輯:將tf.trainable_variables
更改為tf.all_variables()
您還可以使用tf.get_collection()
:
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
(部分復制自拉法爾的回答)
請注意,最后一行等同於Rafal代碼中的列表推導。
基本上,tensorflow存儲全局變量集合,可以通過tf.all_variables()
或tf.get_collection(tf.GraphKeys.VARIABLES)
。 如果在tf.get_collection()
函數中指定scope
(范圍名稱),則只能獲取范圍在指定范圍內的集合中的張量(在本例中為變量)。
編輯:您也可以使用tf.GraphKeys.TRAINABLE_VARIABLES
來獲取可訓練的變量。 但由於vanilla BasicLSTMCell沒有初始化任何不可訓練的變量,因此兩者在功能上都是等價的。 有關默認圖表集合的完整列表,請查看此項 。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.