
How to set custom initial weights for a BiLSTM layer in Keras?

I am currently building a BiLSTM-with-Attention model whose BiLSTM layer weights are optimised using the Antlion Algorithm. The Antlion Algorithm is implemented in MATLAB, and I am able to call it from Python to receive the optimised weights, as seen below:

import matlab.engine
import numpy as np

# LSTM hidden nodes
hidden_nodes = 11

# Call optimised_weights.m, which returns the four optimised weight matrices
eng = matlab.engine.start_matlab()
[forward_kernel, backward_kernel,
 forward_recurrent, backward_recurrent] = eng.optimised_weights(int(hidden_nodes), nargout=4)
eng.quit()

# Convert the MATLAB arrays to NumPy arrays
forward_kernel = np.array(forward_kernel)
backward_kernel = np.array(backward_kernel)
forward_recurrent = np.array(forward_recurrent)
backward_recurrent = np.array(backward_recurrent)

I am currently facing issues with setting the weights and biases of the BiLSTM layer created in the model below (shown here without custom initial weights):

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class attention(Layer):

    def __init__(self, return_sequences=True, **kwargs):
        self.return_sequences = return_sequences
        super(attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Trainable attention weight and bias
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], 1),
                                 initializer="random_normal")
        self.b = self.add_weight(name="att_bias", shape=(input_shape[1], 1),
                                 initializer="zeros")
        super(attention, self).build(input_shape)

    def call(self, x):
        # Attention scores over the timestep axis
        e = K.tanh(K.dot(x, self.W) + self.b)
        a = K.softmax(e, axis=1)
        output = x * a

        if self.return_sequences:
            return output
        return K.sum(output, axis=1)

    def get_config(self):
        # For serialization with 'custom_objects'
        config = super().get_config()
        config['return_sequences'] = self.return_sequences
        return config

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Bidirectional, LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

model = Sequential()
model.add(Input(shape=(5, 1)))
model.add(Bidirectional(LSTM(hidden_nodes, return_sequences=True)))
model.add(attention(return_sequences=False))  # custom attention layer defined above
model.add(Dense(104, activation="sigmoid"))
model.add(Dropout(0.2))
model.add(Dense(1, activation="sigmoid"))

model.compile(optimizer=tf.keras.optimizers.Adam(epsilon=1e-08, learning_rate=0.01), loss='mse')

es = EarlyStopping(monitor='val_loss', mode='min', verbose=2, patience=50)
mc = ModelCheckpoint('model.h5', monitor='val_loss',
                     mode='min', verbose=2, save_best_only=True)

I have tried the following method:

model.add(Bidirectional(LSTM(hidden_nodes, return_sequences=True,
                             weights=[forward_kernel, forward_recurrent, np.zeros(20,),
                                      backward_kernel, backward_recurrent, np.zeros(20,)])))

but the weights and biases are changed once the model is compiled, even if the kernel, recurrent and bias initialisers are set to None.
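
For reference, this is the kind of check I use to see whether the custom values survive compilation (a minimal sketch against the model above):

# Minimal check: compare the BiLSTM weights before and after compiling
bilstm = model.layers[0]
weights_before = [w.copy() for w in bilstm.get_weights()]

model.compile(optimizer=tf.keras.optimizers.Adam(epsilon=1e-08, learning_rate=0.01), loss='mse')

weights_after = bilstm.get_weights()
print([np.allclose(b, a) for b, a in zip(weights_before, weights_after)])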

I have referred to this link: https://keras.io/api/layers/initializers/ but could not relate it to my issue.

I would really appreciate any insight into solving this issue, and into any fundamental parts that I may have missed. I would be glad to share more details if required.

Thanks!

Use tf.constant_initializer to provide your custom weights as an np.array. Also, as you are using a Bidirectional layer, you need to specify the backward layer with your custom weights explicitly.

layer = Bidirectional(
    LSTM(
        hidden_nodes,
        return_sequences=True,
        kernel_initializer=tf.constant_initializer(forward_kernel),
        recurrent_initializer=tf.constant_initializer(forward_recurrent),
    ),
    backward_layer=LSTM(
        hidden_nodes,
        return_sequences=True,
        kernel_initializer=tf.constant_initializer(backward_kernel),
        recurrent_initializer=tf.constant_initializer(backward_recurrent),
        go_backwards=True,
    ),
)
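
The reason for the explicit backward_layer: if it is not supplied, Bidirectional builds the backward direction by copying the forward layer's configuration (with go_backwards flipped), so both directions would end up with the same initializers. Passing backward_layer yourself is what lets each direction receive its own constant weights; it must use go_backwards=True and otherwise match the forward layer's configuration.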

Pay attention to the expected shapes of the weights. As the input of the layer is (batch, timesteps, features), your weights should have the following shapes, to account for the 4 gates in the LSTM cell (a shape-check sketch follows the list):

  • kernel: (features, 4*hidden_nodes)
  • recurrent: (hidden_nodes, 4*hidden_nodes)
  • bias: (4*hidden_nodes,)
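
As a sanity check, here is a minimal sketch along those lines, assuming hidden_nodes=11 and a (5, 1) input as in the question; the weight arrays are random stand-ins for the MATLAB output, and forward_bias/backward_bias are hypothetical. Note that unit_forget_bias=False is needed if you want a full (4*hidden_nodes,) bias array to be used as-is, since the default unit_forget_bias=True constructs the bias from separate per-gate pieces:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Bidirectional, Input, LSTM
from tensorflow.keras.models import Sequential

hidden_nodes, features = 11, 1

# Random stand-ins shaped the way the LSTM expects (4 gates concatenated)
forward_kernel = np.random.rand(features, 4 * hidden_nodes).astype("float32")
forward_recurrent = np.random.rand(hidden_nodes, 4 * hidden_nodes).astype("float32")
forward_bias = np.zeros(4 * hidden_nodes, dtype="float32")    # hypothetical bias values
backward_kernel = np.random.rand(features, 4 * hidden_nodes).astype("float32")
backward_recurrent = np.random.rand(hidden_nodes, 4 * hidden_nodes).astype("float32")
backward_bias = np.zeros(4 * hidden_nodes, dtype="float32")   # hypothetical bias values

model = Sequential([
    Input(shape=(5, features)),
    Bidirectional(
        LSTM(hidden_nodes, return_sequences=True,
             kernel_initializer=tf.constant_initializer(forward_kernel),
             recurrent_initializer=tf.constant_initializer(forward_recurrent),
             bias_initializer=tf.constant_initializer(forward_bias),
             unit_forget_bias=False),
        backward_layer=LSTM(hidden_nodes, return_sequences=True,
                            kernel_initializer=tf.constant_initializer(backward_kernel),
                            recurrent_initializer=tf.constant_initializer(backward_recurrent),
                            bias_initializer=tf.constant_initializer(backward_bias),
                            unit_forget_bias=False,
                            go_backwards=True),
    ),
])
model.compile(optimizer='adam', loss='mse')

# The initial weights should match what was passed in, even after compiling
np.testing.assert_allclose(model.layers[0].forward_layer.get_weights()[0], forward_kernel)
np.testing.assert_allclose(model.layers[0].backward_layer.get_weights()[0], backward_kernel)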
