[英]Tensorflow LSTM Gate weights
您好我有關於Tensorflow的問題。 我訓練了一些LSTM模型,我可以訪問突觸連接的權重和偏差,但是我似乎無法訪問LSTM單元的輸入,新輸入,輸出和忘記門權重。 我可以獲得門限張量,但當我在會話中嘗試.eval()時,我會得到錯誤。 我正在使用tensorflow / python / ops / rnn_cell.py中的BasicLSTMCell類來為我的網絡
`
class BasicLSTMCell(RNNCell):
"""Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline.
For advanced models, please use the full LSTMCell that follows.
"""
def __init__(self, num_units, forget_bias=1.0, input_size=None,
state_is_tuple=True, activation=tanh):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
input_size: Deprecated and unused.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
activation: Activation function of the inner states.
"""
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation
@property
def state_size(self):
return (LSTMStateTuple(self._num_units, self._num_units)
if self._state_is_tuple else 2 * self._num_units)
@property
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell (LSTM)."""
with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(1, 2, state)
concat = _linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(1, 4, concat)
new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat(1, [new_c, new_h])
return new_h, new_state
def _get_concat_variable(name, shape, dtype, num_shards):
"""Get a sharded variable concatenated into one tensor."""
sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
if len(sharded_variable) == 1:
return sharded_variable[0]
concat_name = name + "/concat"
concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
if value.name == concat_full_name:
return value
concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
concat_variable)
return concat_variable
def _get_sharded_variable(name, shape, dtype, num_shards):
"""Get a list of sharded variables with the given dtype."""
if num_shards > shape[0]:
raise ValueError("Too many shards: shape=%s, num_shards=%d" %
(shape, num_shards))
unit_shard_size = int(math.floor(shape[0] / num_shards))
remaining_rows = shape[0] - unit_shard_size * num_shards
shards = []
for i in range(num_shards):
current_size = unit_shard_size
if i < remaining_rows:
current_size += 1
shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
dtype=dtype))
return shards
`
我可以看到在def 調用中使用了i,j,f,o門,但是當我執行它們時,我得到了張量,當我在會話中嘗試.eval()時,我得到了錯誤。 我也試過tf.getVariable但是無法提取權重矩陣。 我的問題:有沒有辦法評估i,j,f和o門權重/矩陣?
首先,澄清一些混淆:i,j,f和o張量不是權重矩陣; 它們是依賴於特定LSTM單元輸入的中間計算步驟。 LSTM單元的所有權重都存儲在變量self._kernel和self._bias中,並存儲在常量self._forget_bias中。
因此,為了回答你的問題的兩種可能的解釋,我將展示如何在每一步打印self._kernel和self._bias的值,以及i,j,f和o張量的值。
假設我們有以下圖表:
import numpy as np
import tensorflow as tf
timesteps = 7
num_input = 4
num_units = 3
x_val = np.random.normal(size=(1, timesteps, num_input))
lstm = tf.nn.rnn_cell.BasicLSTMCell(num_units = num_units)
X = tf.placeholder("float", [1, timesteps, num_input])
inputs = tf.unstack(X, timesteps, 1)
outputs, state = tf.contrib.rnn.static_rnn(lstm, inputs, dtype=tf.float32)
如果我們知道它的名字,我們可以找到任何張量的值。 找到張量名稱的一種方法是查看TensorBoard。
init = tf.global_variables_initializer()
graph = tf.get_default_graph()
with tf.Session(graph=graph) as sess:
train_writer = tf.summary.FileWriter('./graph', sess.graph)
sess.run(init)
現在我們可以通過terminal命令啟動TensorBoard
tensorboard --logdir=graph --host=localhost
並發現產生i,j,f,o張量的操作名稱為'rnn / basic_lstm_cell / split',而內核和偏差稱為'rnn / basic_lstm_cell / kernel'和'rnn / basic_lstm_cell / bias':
tf.contrib.rnn.static_rnn函數調用我們的基本lstm單元7次,每次執行一次。 當要求Tensorflow以相同的名稱創建多個操作時,它會為它們添加后綴,如下所示:rnn / basic_lstm_cell / split,rnn / basic_lstm_cell / split_1,...,rnn / basic_lstm_cell / split_6。 這些是我們運營的名稱。
tensorflow中張量的名稱由生成張量的操作的名稱,后跟冒號,后跟生成此張量的操作輸出的索引。 內核和偏置操作具有單個輸出,因此張量名稱將是
kernel = graph.get_tensor_by_name("rnn/basic_lstm_cell/kernel:0")
bias = graph.get_tensor_by_name("rnn/basic_lstm_cell/bias:0")
拆分操作產生四個輸出:i,j,f和o,因此這些張量的名稱將是:
i_list = []
j_list = []
f_list = []
o_list = []
for suffix in ["", "_1", "_2", "_3", "_4", "_5", "_6"]:
i_list.append(graph.get_tensor_by_name(
"rnn/basic_lstm_cell/split{}:0".format(suffix)
))
j_list.append(graph.get_tensor_by_name(
"rnn/basic_lstm_cell/split{}:1".format(suffix)
))
f_list.append(graph.get_tensor_by_name(
"rnn/basic_lstm_cell/split{}:2".format(suffix)
))
o_list.append(graph.get_tensor_by_name(
"rnn/basic_lstm_cell/split{}:3".format(suffix)
))
現在我們可以找到所有張量的值:
with tf.Session(graph=graph) as sess:
train_writer = tf.summary.FileWriter('./graph', sess.graph)
sess.run(init)
weights = sess.run([kernel, bias])
print("Weights:\n", weights)
i_values, j_values, f_values, o_values = sess.run([i_list, j_list, f_list, o_list],
feed_dict = {X:x_val})
print("i values:\n", i_values)
print("j values:\n", j_values)
print("f_values:\n", f_values)
print("o_values:\n", o_values)
或者,我們可以通過查看圖中所有張量的列表來找到張量名稱,這可以通過以下方式生成:
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]
print(tensor_names)
或者,對於所有操作的較短列表:
print([node.name for node in graph.get_operations()])
第三種方法是讀取源代碼並找出分配給哪些張量的名稱。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.