How to add dropout and attention to an LSTM in Keras (Python)

I have a dataset of about 1000 nodes, where each node has 4 time series. Each time series is exactly 6 steps long. The label is 0 or 1 (i.e. binary classification).

More precisely my dataset looks as follows.

node, time_series1, time_series2, time_series3, time_series4, label
n1, [1.2, 2.5, 3.7, 4.2, 5.6, 8.8], [6.2, 5.5, 4.7, 3.2, 2.6, 1.8], …, 1
n2, [5.2, 4.5, 3.7, 2.2, 1.6, 0.8], [8.2, 7.5, 6.7, 5.2, 4.6, 1.8], …, 0
and so on.

I normalise my time series before feeding them into my LSTM model for classification.
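A common pitfall with this layout is the axis order. Keras expects `(samples, timesteps, features)`, so the four length-6 series have to be stacked along the last axis to get `(1000, 6, 4)`. Here is one way to do that with NumPy and normalise each feature. It assumes z-score normalisation with statistics computed on a training split only. The random arrays, the 800/200 split and names such as `series`, `train_idx` and `data_norm` are hypothetical stand-ins for the real data.

```python
import numpy as np

# Hypothetical toy input mirroring the table: 4 series of 6 values per node.
# Random numbers stand in for the real measurements.
rng = np.random.default_rng(0)
num_nodes = 1000
series = [rng.normal(size=(num_nodes, 6)) for _ in range(4)]  # 4 arrays of shape (1000, 6)
labels = rng.integers(0, 2, size=num_nodes)                   # binary targets

# Stack along the LAST axis so every time step carries 4 features:
# (nodes, timesteps, features) = (1000, 6, 4), as LSTM(..., input_shape=(6, 4)) expects.
data = np.stack(series, axis=-1)

# Hypothetical split: first 800 nodes for training, the rest for testing.
train_idx = np.arange(800)
test_idx = np.arange(800, num_nodes)

# Per-feature z-score using training statistics only, to avoid test-set leakage.
mean = data[train_idx].mean(axis=(0, 1), keepdims=True)       # shape (1, 1, 4)
std = data[train_idx].std(axis=(0, 1), keepdims=True) + 1e-8  # epsilon avoids division by zero
data_norm = (data - mean) / std

print(data.shape)  # (1000, 6, 4)
```

Normalising each node's series independently is another valid choice. Which one works better depends on whether the absolute level of a series carries information about the label.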

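from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

model = Sequential()
model.add(LSTM(10, input_shape=(6,4)))     # 10 units; input is 6 time steps x 4 features
model.add(Dense(32))                       # no activation given, so this layer is linear
model.add(Dense(1, activation='sigmoid'))  # probability of class 1
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

print(data.shape) # (1000, 6, 4)
model.fit(data, target)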
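Note that `model.fit(data, target)` trains for a single epoch, which is Keras's default. The following is a minimal sketch of a more typical training call. It uses random placeholder data in place of the real `data`/`target`. The epoch cap, batch size, 20% validation split and `patience=5` are illustrative choices, not tuned values. It also gives the hidden `Dense(32)` layer a ReLU activation, since without one that layer is purely linear.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Placeholder data with the shapes from the question (random, illustration only).
rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 6, 4)).astype("float32")
target = rng.integers(0, 2, size=(1000,)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(6, 4)),
    layers.LSTM(10),
    layers.Dense(32, activation="relu"),    # nonlinearity for the hidden layer
    layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

# Stop once validation loss stops improving and roll back to the best weights.
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history = model.fit(
    data, target,
    epochs=50,              # upper bound; early stopping usually ends sooner
    batch_size=32,
    validation_split=0.2,   # the last 20% of samples are held out for validation
    callbacks=[early_stop],
    verbose=0,
)
print(len(history.history["loss"]), "epochs run")
```

Because `validation_split` takes the last samples before shuffling, the data should be shuffled beforehand if it is sorted (for example by label).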

I am new to Keras, which is why I started with the simplest possible LSTM model. Now, however, I would like to bring it up to a level where I could use it in industry.

I have read that it is good practice to add dropout and attention layers to LSTM models. Do you think adding such layers is applicable to my problem, and if so, how would I do it? :)

Note: I am not limited to dropout and attention layers and am happy to receive other suggestions for improving my model.

I am happy to provide more details if needed.

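If you want to apply dropout inside the LSTM layer itself, you can use its `dropout` argument, which drops a fraction of the layer's input connections during training: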

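from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

model = Sequential()
# dropout=0.5 drops half of the LSTM's input connections, during training only
model.add(LSTM(10, input_shape=(6,4), dropout=0.5))
model.add(Dense(32))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

print(data.shape) # (1000, 6, 4)
model.fit(data, target)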
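Keras's `LSTM` layer accepts two dropout arguments: `dropout` for the input connections and `recurrent_dropout` for the recurrent state. This sketch sets both (the 0.2 rates are arbitrary) and checks that dropout only acts in training mode. A small random batch stands in for real data.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    keras.Input(shape=(6, 4)),
    # dropout: fraction of input connections dropped at each step
    # recurrent_dropout: fraction of recurrent-state connections dropped
    layers.LSTM(10, dropout=0.2, recurrent_dropout=0.2),
    layers.Dense(1, activation="sigmoid"),
])

x = np.random.default_rng(0).normal(size=(8, 6, 4)).astype("float32")

# In inference mode dropout is a no-op, so repeated calls give identical outputs.
infer_a = model(x, training=False).numpy()
infer_b = model(x, training=False).numpy()
# In training mode random connections are dropped on every call.
train_out = model(x, training=True).numpy()
print(np.allclose(infer_a, infer_b))  # True
```

One trade-off: with a nonzero `recurrent_dropout`, TensorFlow cannot use its fast cuDNN LSTM kernel on GPU, so training there may be noticeably slower.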

Or, to apply dropout between stacked LSTM layers, consider the following:

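from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout

model = Sequential()
# return_sequences=True passes all 6 time steps (a 3D tensor) to the next LSTM;
# without it the second LSTM receives 2D input and raises a shape error
model.add(LSTM(10, input_shape=(6,4), return_sequences=True))
model.add(Dropout(0.5))  # dropout between the two LSTM layers
model.add(LSTM(10))      # input_shape is only needed on the first layer
model.add(Dense(32))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

print(data.shape) # (1000, 6, 4)
model.fit(data, target)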
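The answer above only covers dropout. For the attention part of the question, one option is to keep every time step of the LSTM output (`return_sequences=True`), apply Keras's built-in dot-product `Attention` layer to it as self-attention, and pool the result before the sigmoid output. This is a minimal sketch with assumed layer sizes and dropout rates, not a tuned architecture. It uses the functional API because the attention layer takes two inputs (query and value).

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(6, 4))                  # 6 time steps, 4 series
seq = layers.LSTM(32, return_sequences=True)(inputs)  # (batch, 6, 32): one vector per step
seq = layers.Dropout(0.3)(seq)
# Dot-product self-attention: each step attends over all 6 LSTM outputs.
attended = layers.Attention()([seq, seq])           # (batch, 6, 32)
pooled = layers.GlobalAveragePooling1D()(attended)  # (batch, 32)
pooled = layers.Dropout(0.3)(pooled)
outputs = layers.Dense(1, activation="sigmoid")(pooled)

model = keras.Model(inputs, outputs)
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

x_demo = np.random.default_rng(0).normal(size=(16, 6, 4)).astype("float32")
print(model.predict(x_demo, verbose=0).shape)  # (16, 1)
```

With only 6 time steps per node, attention may add little over a plain LSTM. It is worth comparing both on the same validation split before keeping it.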
