
Weights and Bias dimensions in TensorFlow for LSTM

I am confused about the dimensions of the hidden weight array for an LSTM. I understand the input weight and bias dimension just fine. Now I just recently started learning about RNNs and LSTMs so maybe I do not fully understand their operation. Here is my understanding. A LSTM layer has a number of cells which hold weights and biases corresponding to gates. A sequence of data is inputted to the first cell, time-step by time-step. During this process, hidden states are created at each step of the sequence which get passed back into the cell. At the end of the sequence, a cell state is calculated which gets passed to the next cell, and the process repeats on this next cell with the same input data. I hope I have this right. Now for an example to explain my confusion.

from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model

inputs = Input(shape=(100, 4))
x = LSTM(2)(inputs)
outputs = Dense(1)(x)

model = Model(inputs=inputs, outputs=outputs)

This model will have the following summary:


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 100, 4)]            0         
                                                                 
 lstm (LSTM)                 (None, 2)                 56        
                                                                 
 dense (Dense)               (None, 1)                 3        
                                                                 
=================================================================
Total params: 59
Trainable params: 59
Non-trainable params: 0
_________________________________________________________________

From my understanding, each LSTM cell has 8 weights and 4 biases. 4 of the weights are applied to the input and the other 4 are applied to the cell's hidden state. These weights and biases are held constant throughout the input sequence. Starting with the biases: each cell has 4 biases, so the total number of biases in the LSTM layer is (number_of_cells*4). In this case (2*4)=8.

print(model.layers[1].get_weights()[2].shape)

(8,)

That checks out.

Now the input weights. Each cell has 4 input weight vectors, each containing one weight per input feature, and that has to be multiplied by the number of cells. So the total number of input weights is (number_of_features*4*number_of_cells). In this case (4*4*2)=32.

print(model.layers[1].get_weights()[0].shape)

(4, 8)

That also checks out.

Now here's the confusing part. There are also 4 hidden weights per cell. And the number of hidden states is equal to the sequence length. But the sequence length shouldn't matter because the weights are held constant throughout it. What I found was that the total number of hidden weights is (number_of_cells*4*number_of_cells). In this case (2*4*2)=16.

print(model.layers[1].get_weights()[1].shape)

(2, 8)

Considering there are 2 cells, this means there must be 8 hidden weights per cell. But this conflicts with what I know: there are only 4 hidden weights per cell. What accounts for the extra 4 hidden weights per cell? From what I understand, the last hidden state of a cell has no effect on the next cell. There is, however, a cell state that does get passed to the next cell at the end of the sequence, but I can't find anywhere online that mentions a weight for the cell state.

I hope I made it clear where my confusion is. Thank you.

Here are the relevant equations from the Wikipedia article on the LSTM.
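In the standard formulation with a forget gate, W are the input weights, U the recurrent (hidden) weights, b the biases, \sigma the sigmoid, and \odot the Hadamard product:

f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)
i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)
o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)
c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_c x_t + U_c h_{t-1} + b_c)
h_t = o_t \odot \tanh(c_t)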

Notice that, as you said, there are 4 sets of input weights (W), hidden weights (U), and biases (b). Note that, due to the Hadamard product, i, f, o, c, h and all the biases must have identical dimensions.

The dimension of your input vector is (4,), and of the hidden vector, (2,).

Each input weight matrix is (4, 2), and there are 4 of them, so 4x2x4 = 32 input weights in total.

The hidden weight matrices are (2, 2), and there are 4 of them, so 2x2x4 = 16 hidden weights in total.

Each bias is a (2,) vector, so 2x4 = 8 biases together.
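
To see how those per-gate matrices are packed into the layer, here is a minimal sketch that splits the arrays returned by get_weights() back into the four W, U and b sets. It assumes the usual tf.keras layout, where the gates are concatenated along the last axis in the order i, f, c, o:

import numpy as np

# kernel: (input_dim, 4*units), recurrent_kernel: (units, 4*units), bias: (4*units,)
kernel, recurrent_kernel, bias = model.layers[1].get_weights()

# Assumed gate order in the concatenation: input (i), forget (f), cell candidate (c), output (o)
W_i, W_f, W_c, W_o = np.split(kernel, 4, axis=1)            # each (4, 2)
U_i, U_f, U_c, U_o = np.split(recurrent_kernel, 4, axis=1)  # each (2, 2)
b_i, b_f, b_c, b_o = np.split(bias, 4)                      # each (2,)

print(W_i.shape, U_i.shape, b_i.shape)  # (4, 2) (2, 2) (2,)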

The cell state is gated from the hidden state and the input. It serves as the long-term memory, while the hidden state is the short-term one. Note the 4th equation: it is a simple linear update, c_t = f_t \odot c_{t-1} + k_t, which is not the case for the hidden state.
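
To make that concrete, here is a rough sketch of one time step for a single sample, using the split matrices from the snippet above (sigmoid/tanh activations assumed, which are the tf.keras defaults):

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One LSTM time step: x_t has shape (4,), h_prev and c_prev have shape (2,)
def lstm_step(x_t, h_prev, c_prev):
    i = sigmoid(x_t @ W_i + h_prev @ U_i + b_i)
    f = sigmoid(x_t @ W_f + h_prev @ U_f + b_f)
    o = sigmoid(x_t @ W_o + h_prev @ U_o + b_o)
    k = i * np.tanh(x_t @ W_c + h_prev @ U_c + b_c)
    c = f * c_prev + k           # linear cell-state update (the 4th equation)
    h = o * np.tanh(c)           # hidden state: a gated nonlinearity of the cell state
    return h, c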

I suggest you check Colah's blog post on LSTMs. It is probably the best resource on the subject that I have ever stumbled upon.
