簡體   English   中英

Tensorflow:如何從rnn_cell.BasicLSTM&rnn_cell.MultiRNNCell獲取所有變量

[英]Tensorflow: How to get all variables from rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

我有一個設置,我需要在主要初始化后使用tf.initialize_all_variables()初始化LSTM。 即我想調用tf.initialize_variables([var_list])

有沒有辦法收集所有內部可訓練變量:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

這樣我可以初始化JUST這些參數嗎?

我想要這個的主要原因是因為我不想重新初始化一些訓練過的值。

解決問題的最簡單方法是使用變量范圍。 范圍內變量的名稱將以其名稱為前綴。 這是一個簡短的片段:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

它與MultiRNNCell工作方式相同。

編輯:將tf.trainable_variables更改為tf.all_variables()

您還可以使用tf.get_collection()

cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

(部分復制自拉法爾的回答)

請注意,最后一行等同於Rafal代碼中的列表推導。

基本上,tensorflow存儲全局變量集合,可以通過tf.all_variables()tf.get_collection(tf.GraphKeys.VARIABLES) 如果在tf.get_collection()函數中指定scope (范圍名稱),則只能獲取范圍在指定范圍內的集合中的張量(在本例中為變量)。

編輯:您也可以使用tf.GraphKeys.TRAINABLE_VARIABLES來獲取可訓練的變量。 但由於vanilla BasicLSTMCell沒有初始化任何不可訓練的變量,因此兩者在功能上都是等價的。 有關默認圖表集合的完整列表,請查看此項

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM