
Wrap Keras LSTM layers in a function the correct way

I'm constructing an LSTM whose layer construction is fairly verbose thanks to an ongoing hyperparameter search, so I wanted to wrap the creation of each layer in a function. However, the resulting model is not behaving the way I expect.

This is my complete function (there are some helper functions I've left undefined that basically return functions from strings). I'm putting the hyperparameter suggestions for the first trial in comments next to the inputs to help you understand how the function is being called:

 def _get_model(
  self, 
  encoder_width: int,  # 513
  decoder_width: int,  # 513
  input_seq_len: int,  # 12
  output_seq_len: int,  # 1
  n_features_input: int,  # x_train.shape[1]
  n_features_output: int,  # y_train.shape[1]
  num_encoder_layers: int,  # 6
  num_decoder_layers: int,  # 6
  dropout: float,  # .25
  recurrent_dropout: float,  # .25
  recurrent_regularizer: str, 
  kernel_regularizer: str,
  activation: str,
  recurrent_activation: str,
  ) -> tf.keras.models.Model:
  """Configures the regression model.

  Args:
    encoder_width: number of hidden units on the encoder layer
    decoder_width: number of hidden units on the decoder layer
    input_seq_len: number of past sequential observations as input
    output_seq_len: number of future outputs to predict, generally 1
    n_features_input: number of input features/dims
    n_features_output: number of output features/dims
    num_encoder_layers: number of layers to use in encoder
    num_decoder_layers: number of layers to use in the decoder
    dropout: the proportion of nodes to dropout during training
    recurrent_dropout: proportion of nodes to dropout for recurrent state
    recurrent_regularizer: regularizer to put on the recurrent_kernel weights matrix
    kernel_regularizer: regularizer function applied to kernel weights matrix
    activation: activation function to use
    recurrent_activation: activation function to use for the recurrent step

  Returns:
    a keras LSTM model
  """

  tf.keras.backend.clear_session()

  recurrent_regularizer = self._get_keras_regularizer_from_str(recurrent_regularizer)
  recurrent_activation = self._get_keras_activation_from_str(recurrent_activation)
  kernel_regularizer = self._get_keras_regularizer_from_str(kernel_regularizer)
  activation = self._get_keras_activation_from_str(activation)

  # we define a local LSTM layer to keep the code tidy
  def LSTM_layer(x: tf.Tensor, width: int,
                 return_sequences: bool, initial_state=None):
    
    encoder_layer = tf.keras.layers.LSTM(
        width, return_state=True, return_sequences=return_sequences,
        dropout=dropout, recurrent_dropout=recurrent_dropout, 
        recurrent_regularizer=recurrent_regularizer, 
        kernel_regularizer=kernel_regularizer,
        activation=activation, recurrent_activation=recurrent_activation)
    x, state_h, state_c = encoder_layer(x, initial_state=initial_state)
    encoder_states = [state_h, state_c]

    return x, encoder_states

  x = tf.keras.layers.Input(
      shape=(input_seq_len, n_features_input))
  
  if num_encoder_layers == 1:
    x, encoder_states_1 = LSTM_layer(x, encoder_width, return_sequences=False)
  else: 
    temp_encoder_layers = num_encoder_layers

    # we need to pass return_sequences to every succeeding LSTM layer
    while temp_encoder_layers > 1:
      x, encoder_states = LSTM_layer(x, encoder_width, return_sequences=True)
      
      # we would like to keep the first encoder state to initialize all decoder states
      if temp_encoder_layers == num_encoder_layers:
        encoder_states_1 = encoder_states
      temp_encoder_layers -= 1

    # we want the final layer not to return sequences for the RepeatVector
    x, encoder_states = LSTM_layer(x, encoder_width, return_sequences=False)

  # we need the repeat layer to separate the encoder and decoder 
  x = tf.keras.layers.RepeatVector(output_seq_len)(x) 

  for i in range(num_decoder_layers):
    x, encoder_states = LSTM_layer(x, decoder_width,
                                   return_sequences=True)  # , initial_state=encoder_states_1)

  decoder_outputs2 = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Dense(n_features_output))(x)

  return tf.keras.models.Model(x, decoder_outputs2)

As you can see, there's the LSTM_layer function, which I believe might be causing the trouble.

Then, when run with data, this throws the error:

ValueError: Input 0 of layer "model" is incompatible with the layer: expected shape=(None, 1, 513), found shape=(514, 12, 109)

The shape of x_train is:

 X_train shape
(1824, 12, 109)

And the model architecture is:

    Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 1, 513)]          0         
                                                                 
 time_distributed (TimeDistr  (None, 1, 80)            41120     
 ibuted)                                                         
                                                                 
=================================================================
Total params: 41,120
Trainable params: 41,120
Non-trainable params: 0
_________________________________________________________________

So I know the model must be getting confused: the shape of x_train seems to be getting checked against the very first layer, and the full 12 LSTM layers are certainly not being constructed. Am I doing this the wrong way, then? Is something Pythonic happening where the same layer is assigned over and over again each time the function is called?

Any help or advice much appreciated. Thank you!

Answer: the problem wasn't the function wrapping at all. The problem was in the final definition of the model, tf.keras.models.Model(x, decoder_outputs2), which passes x as the first argument. Since x is reassigned at every layer, by that point it refers only to the output of the last decoder LSTM, so the model Keras actually builds is just a link from that last decoder layer to the dense layer and nothing else. RIP.
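
A minimal sketch of the fix, assuming the rest of _get_model stays exactly as written above. Keep a separate reference to the Input layer (the name encoder_inputs below is my own, not from the original code) and pass that tensor, rather than the repeatedly rebound x, as the first argument to tf.keras.models.Model:

  # keep a handle on the Input tensor instead of losing it to reassignment
  encoder_inputs = tf.keras.layers.Input(
      shape=(input_seq_len, n_features_input))
  x = encoder_inputs  # x is still reused to chain the layer outputs

  # ... encoder layers, RepeatVector and decoder layers exactly as above,
  # rebinding x to each layer's output along the way ...

  decoder_outputs2 = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Dense(n_features_output))(x)

  # root the graph at the original Input tensor, not at the final x
  return tf.keras.models.Model(encoder_inputs, decoder_outputs2)

With the model rooted at the real Input layer, model.summary() should list every encoder LSTM, the RepeatVector and every decoder LSTM, and incoming batches are checked against the (input_seq_len, n_features_input) shape rather than against the decoder output shape.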

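As for the original worry about the wrapper itself: a function that constructs a tf.keras.layers.LSTM in its body returns a brand-new layer object on every call, so nothing is being silently reused. A quick standalone check (the names here are illustrative only, not from the post):

 import tensorflow as tf

 def make_lstm(width: int):
     # builds and returns a fresh LSTM layer on every call
     return tf.keras.layers.LSTM(width, return_sequences=True, return_state=True)

 layer_a = make_lstm(8)
 layer_b = make_lstm(8)
 print(layer_a is layer_b)          # False: two distinct layer objects
 print(layer_a.name, layer_b.name)  # auto-generated names differ, e.g. lstm, lstm_1

So wrapping the layer construction in LSTM_layer is perfectly fine; the bug was only in which tensor was handed to tf.keras.models.Model.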