
Keras: Understanding the number of trainable LSTM parameters

I have run a Keras LSTM demo containing the following code (after line 166):

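from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed

m = 1
model = Sequential()
dim_in = m   # one input feature per time step
dim_out = m  # one output value per time step
nb_units = 10

model.add(LSTM(input_shape=(None, dim_in),
               return_sequences=True,
               units=nb_units))
model.add(TimeDistributed(Dense(activation='linear', units=dim_out)))
model.compile(loss='mse', optimizer='rmsprop')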

When I add a call to model.summary() after this code, I see the following output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_4 (LSTM)                (None, None, 10)          480       
_________________________________________________________________
time_distributed_4 (TimeDist (None, None, 1)           11        
=================================================================
Total params: 491
Trainable params: 491
Non-trainable params: 0

I understand that the 11 parameters of the TimeDistributed layer are simply nb_units = 10 weights plus one bias value.
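As a quick check of that count, here is a minimal plain-Python sketch (it does not need Keras). It assumes the Dense layer stores a kernel of shape (input_dim, units) plus a bias of shape (units,), and that TimeDistributed reuses the same Dense weights at every time step; dense_params is a hypothetical helper name.

```python
def dense_params(input_dim: int, units: int) -> int:
    """Count the weights of a Dense layer (hypothetical helper)."""
    kernel = input_dim * units  # kernel of shape (input_dim, units)
    bias = units                # bias of shape (units,)
    return kernel + bias


# TimeDistributed applies the same Dense weights at every time step,
# so wrapping the layer does not add parameters of its own.
nb_units, dim_out = 10, 1
print(dense_params(input_dim=nb_units, units=dim_out))  # 11
```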

Now for the LSTM layer. These answers say:

params = 4 * ((input_size + 1) * output_size + output_size^2)

In my case, with input_size = 1 and output_size = 1, this yields only 12 parameters for each of the 10 units, for a total of 120 parameters. Compared to the reported 480, this is off by a factor of 4. Where is my error?
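To make the discrepancy concrete, the quoted formula can be evaluated under both readings of output_size. This is a plain-Python sketch; lstm_params is a hypothetical helper that just transcribes the formula above, with ^ written as **.

```python
def lstm_params(input_size: int, output_size: int) -> int:
    """Quoted formula: 4 * ((input_size + 1) * output_size + output_size^2)."""
    return 4 * ((input_size + 1) * output_size + output_size ** 2)


nb_units = 10

# Reading output_size as 1 per unit, then multiplying by the number of units:
per_unit_reading = nb_units * lstm_params(input_size=1, output_size=1)
# Reading output_size as the total number of units in the layer:
whole_layer_reading = lstm_params(input_size=1, output_size=nb_units)

print(per_unit_reading)     # 120
print(whole_layer_reading)  # 480
```

The per-unit reading grows only linearly with nb_units, while the whole-layer reading contains an nb_units**2 term.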

The params formula holds for the whole layer, not per Keras unit.

Quoting this answer:

[In Keras], the unit means the dimension of the inner cells in LSTM.

LSTM in Keras only define exactly one LSTM block, whose cells is of unit-length.

Directly setting output_size = 10 (as in this comment) correctly yields the 480 parameters.
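Another way to see where the 480 come from is to count the three weight arrays a Keras LSTM layer holds: a kernel of shape (input_dim, 4 * units), a recurrent kernel of shape (units, 4 * units) and a bias of shape (4 * units,), where the factor 4 stacks the four gates. The sketch below mirrors those shapes in plain Python so it runs without Keras; lstm_weight_shapes is a hypothetical helper. In a real model, the same shapes should be visible in model.layers[0].get_weights().

```python
import math


def lstm_weight_shapes(input_dim: int, units: int) -> dict:
    """Shapes of the three LSTM weight arrays, with the 4 gate blocks concatenated."""
    return {
        "kernel": (input_dim, 4 * units),        # input -> 4 gates
        "recurrent_kernel": (units, 4 * units),  # previous hidden state -> 4 gates
        "bias": (4 * units,),                    # one bias per gate and unit
    }


shapes = lstm_weight_shapes(input_dim=1, units=10)
sizes = {name: math.prod(shape) for name, shape in shapes.items()}
print(sizes)                # {'kernel': 40, 'recurrent_kernel': 400, 'bias': 40}
print(sum(sizes.values()))  # 480
```

Adding the 11 parameters of the TimeDistributed Dense layer gives the 491 total reported by model.summary().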

Your error lies in how you interpreted the terms on the page you quoted (which is admittedly misleading). The n in the reference corresponds to your nb_units. You can see this from the fact that this variable enters the formula quadratically: the squared term counts the recurrent connections, which run only between the nb_units LSTM cells.

So, setting output_size = n = 10 in your formula above for params will give the desired 480 parameters.
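The quadratic dependence shows up when the formula is evaluated for a few layer widths. This is a plain-Python sketch assuming input_size = 1 as in the question; lstm_params is a hypothetical helper transcribing the formula with output_size = n = nb_units.

```python
def lstm_params(input_size: int, n: int) -> int:
    """4 * ((input_size + 1) * n + n**2), with n = nb_units."""
    return 4 * ((input_size + 1) * n + n ** 2)


for n in (10, 20, 40):
    total = lstm_params(input_size=1, n=n)
    recurrent = 4 * n ** 2  # recurrent weights between the n LSTM cells
    print(n, total, recurrent)
# 10 480 400
# 20 1760 1600
# 40 6720 6400
```

The recurrent term dominates, so doubling nb_units roughly quadruples the parameter count, which a per-unit count (growing linearly) cannot reproduce.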


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM