
Understanding initialization of weights and biases in Tensorflow

Presently I am using stacked LSTMs with a fully connected (FC) layer at the end:

tf.contrib.rnn.BasicLSTMCell
tf.contrib.rnn.MultiRNNCell
tf.nn.dynamic_rnn
tf.contrib.layers.fully_connected

According to my understanding, if I try to use any architecture defined under the

tf.nn

module, then weight initialization like

W2 = tf.Variable(np.random.rand(state_size, num_classes), dtype=tf.float32)
b2 = tf.Variable(np.zeros((1, num_classes)), dtype=tf.float32)

needs to be done, and tf.matmul has to be used explicitly.
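The manual pattern described above can be sketched as follows. This is a minimal illustration of the shapes involved, not the questioner's actual model; it uses NumPy (rather than TensorFlow) so the matmul-plus-bias arithmetic is concrete, and `state_size`, `num_classes`, and the input `x` are placeholder names taken from the question:

```python
import numpy as np

state_size, num_classes, batch_size = 10, 2, 5

# Manual initialization, as in the question: random weights, zero biases
W2 = np.random.rand(state_size, num_classes).astype(np.float32)
b2 = np.zeros((1, num_classes), dtype=np.float32)

# x stands in for the last hidden state of the RNN: (batch_size, state_size)
x = np.ones((batch_size, state_size), dtype=np.float32)

# The explicit multiply-and-add that tf.matmul(x, W2) + b2 would perform;
# b2 has shape (1, num_classes) and broadcasts across the batch dimension
logits = np.matmul(x, W2) + b2
print(logits.shape)  # (5, 2)
```

The same shapes apply when the variables are `tf.Variable`s and the multiplication is `tf.matmul`.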

But for classes under

tf.contrib

weight initialization happens automatically, and tf.matmul may not be required.

Is my understanding correct? Please let me know.

In your case, the creation of the weight and bias variables happens inside the cell, and then inside the fully connected layer, so there is no need to define them explicitly. Also, when you are building the graph, Tensorflow doesn't initialize any of your variables; before you start executing nodes in the graph, you have to initialize the variables beforehand. Have a look:

import numpy as np
import tensorflow as tf

# Some input: (batch_size, time_steps, input_depth)
inputs = np.ones((5, 4, 3), dtype=np.float32)
# Cleaning the default graph before building it
tf.reset_default_graph()
# Creating a MultiRNNCell (the cells themselves create the weight and bias matrices)
cells = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(10) for _ in range(2)])
# A wrapper around the cells to ensure the recurrent relation
hidden_states, cell_state = tf.nn.dynamic_rnn(cell=cells, inputs=inputs, dtype=tf.float32)
# Fully connected layer (it creates the W and b matrices, then performs tf.matmul and tf.add).
# hidden_states is batch-major, so the last time step is hidden_states[:, -1]
logits = tf.contrib.layers.fully_connected(hidden_states[:, -1], num_outputs=2, activation_fn=None)
# A new session
sess = tf.Session()
# All variables created above (implicitly, by the cells and the fully connected layer) are initialized here
sess.run(tf.global_variables_initializer())
# Executing the last node in the graph
sess.run(logits)
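As a rough sanity check on what gets created implicitly, the shapes below recompute the parameters for the graph above without needing TensorFlow. This assumes the TF 1.x layout, where a BasicLSTMCell packs its four gates into a single kernel of shape (input_depth + num_units, 4 * num_units) plus a bias of shape (4 * num_units,), and fully_connected creates one weight matrix and one bias vector:

```python
# Shapes of the variables the graph above creates implicitly (TF 1.x layout).
input_depth = 3   # last dimension of `inputs`
num_units = 10    # units per BasicLSTMCell
num_outputs = 2   # size of the fully connected layer

# The first stacked cell sees the raw inputs; the second sees the first cell's output
cell0_kernel = (input_depth + num_units, 4 * num_units)  # (13, 40)
cell0_bias = (4 * num_units,)                            # (40,)
cell1_kernel = (num_units + num_units, 4 * num_units)    # (20, 40)
cell1_bias = (4 * num_units,)                            # (40,)

# fully_connected creates one weight matrix and one bias
fc_weights = (num_units, num_outputs)  # (10, 2)
fc_bias = (num_outputs,)               # (2,)

total_params = sum(a * b for a, b in (cell0_kernel, cell1_kernel, fc_weights)) \
               + cell0_bias[0] + cell1_bias[0] + fc_bias[0]
print(total_params)  # 1422
```

These are exactly the variables that `tf.global_variables_initializer()` fills in before `sess.run(logits)` can execute.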

