简体   繁体   中英

How to fit an LSTM model using tf.keras

I get an error that I can't identify while trying to fit a model: KeyError: 168 (keras 2.6.0, tensorflow 2.6.0). My data has been one-hot encoded and scaled with MinMaxScaler, and it contains no NaN or inf values. My code sample is below:

# Split without shuffling (shuffle=False), holding out 20% of the rows as the test set.
x_train,x_test,y_train,y_test=train_test_split(features,target,test_size=0.2,random_state=123,shuffle=False)
from keras.preprocessing.sequence import TimeseriesGenerator
win_length = 168  # rows per input window (168 hourly rows = one week)
batch_size = 32
num_features = 40  # presumably the number of feature columns in x_train/x_test — TODO confirm
# Each generator yields batches of (window of win_length consecutive rows, target at the
# row immediately after the window).
# NOTE(review): the generator fetches each target as targets[row] with an integer row;
# when targets is a pandas Series this is a label lookup on its index, not a positional one.
train_generator = TimeseriesGenerator(x_train,y_train,length=win_length,sampling_rate=1,batch_size=batch_size)
test_generator = TimeseriesGenerator(x_test,y_test,length=win_length,sampling_rate=1,batch_size=batch_size)
import tensorflow as tf
# Stacked LSTM model: two 128-unit LSTM layers that return full sequences, then a
# 64-unit LSTM that returns only its last output, then a single-unit Dense output.
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(128,input_shape=(win_length,num_features),return_sequences=True))
model.add(tf.keras.layers.LeakyReLU(alpha=0.5))
model.add(tf.keras.layers.LSTM(128,return_sequences=True))
model.add(tf.keras.layers.LeakyReLU(alpha=0.5))
model.add(tf.keras.layers.Dropout(0.3))
# return_sequences=False collapses the time dimension before the Dense head.
model.add(tf.keras.layers.LSTM(64,return_sequences=False))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(1))
model.summary()

Model: "sequential_2"


Layer (type) Output Shape Param #


lstm_6 (LSTM) (None, 168, 128) 86528


leaky_re_lu_4 (LeakyReLU) (None, 168, 128) 0


lstm_7 (LSTM) (None, 168, 128) 131584


leaky_re_lu_5 (LeakyReLU) (None, 168, 128) 0


dropout_4 (Dropout) (None, 168, 128) 0


lstm_8 (LSTM) (None, 64) 49408


dropout_5 (Dropout) (None, 64) 0


dense_2 (Dense) (None, 1) 65


Total params: 267,585 Trainable params: 267,585 Non-trainable params: 0

# Stop training once val_loss has failed to improve for 5 epochs in a row.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, mode='min')
# Mean squared error loss, Adam with its default arguments, mean absolute error as an extra metric.
model.compile(loss=tf.losses.MeanSquaredError(),
              optimizer=tf.optimizers.Adam(),
              metrics=[tf.metrics.MeanAbsoluteError()])
# test_generator is used as the validation data; it is first read after the training
# batches of epoch 1 have run.
history = model.fit(train_generator,epochs=50,validation_data=test_generator,shuffle=False,callbacks=[early_stopping])

And then I got this error:

Epoch 1/50
136/136 [==============================] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0243
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/anaconda3/envs/python3/lib/python3.6/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2897             try:
-> 2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 168

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
<ipython-input-86-b78994610c64> in <module>
      3               optimizer=tf.optimizers.Adam(),
      4               metrics=[tf.metrics.MeanAbsoluteError()])
----> 5 history = model.fit(train_generator,epochs=50,validation_data=test_generator,shuffle=False,callbacks=[early_stopping])

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1212                 use_multiprocessing=use_multiprocessing,
   1213                 model=self,
-> 1214                 steps_per_execution=self._steps_per_execution)
   1215           val_logs = self.evaluate(
   1216               x=val_x,

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras/engine/data_adapter.py in get_data_handler(*args, **kwargs)
   1381   if getattr(kwargs["model"], "_cluster_coordinator", None):
   1382     return _ClusterCoordinatorDataHandler(*args, **kwargs)
-> 1383   return DataHandler(*args, **kwargs)
   1384 
   1385 

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras/engine/data_adapter.py in __init__(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model, steps_per_execution, distribute)
   1148         use_multiprocessing=use_multiprocessing,
   1149         distribution_strategy=tf.distribute.get_strategy(),
-> 1150         model=model)
   1151 
   1152     strategy = tf.distribute.get_strategy()

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, shuffle, workers, use_multiprocessing, max_queue_size, model, **kwargs)
    922         max_queue_size=max_queue_size,
    923         model=model,
--> 924         **kwargs)
    925 
    926   @staticmethod

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, workers, use_multiprocessing, max_queue_size, model, **kwargs)
    792     # Since we have to know the dtype of the python generator when we build the
    793     # dataset, we have to look at a batch to infer the structure.
--> 794     peek, x = self._peek_and_restore(x)
    795     peek = self._standardize_batch(peek)
    796     peek = _process_tensorlike(peek)

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras/engine/data_adapter.py in _peek_and_restore(x)
    926   @staticmethod
    927   def _peek_and_restore(x):
--> 928     return x[0], x
    929 
    930   def _handle_multiprocessing(self, x, workers, use_multiprocessing,

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras_preprocessing/sequence.py in __getitem__(self, index)
    372         samples = np.array([self.data[row - self.length:row:self.sampling_rate]
    373                             for row in rows])
--> 374         targets = np.array([self.targets[row] for row in rows])
    375 
    376         if self.reverse:

~/anaconda3/envs/python3/lib/python3.6/site-packages/keras_preprocessing/sequence.py in <listcomp>(.0)
    372         samples = np.array([self.data[row - self.length:row:self.sampling_rate]
    373                             for row in rows])
--> 374         targets = np.array([self.targets[row] for row in rows])
    375 
    376         if self.reverse:

~/anaconda3/envs/python3/lib/python3.6/site-packages/pandas/core/series.py in __getitem__(self, key)
    880 
    881         elif key_is_scalar:
--> 882             return self._get_value(key)
    883 
    884         if is_hashable(key):

~/anaconda3/envs/python3/lib/python3.6/site-packages/pandas/core/series.py in _get_value(self, label, takeable)
    988 
    989         # Similar to Index.get_value, but we do not fall back to positional
--> 990         loc = self.index.get_loc(label)
    991         return self.index._get_values_for_loc(self, loc, label)
    992 

~/anaconda3/envs/python3/lib/python3.6/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:
-> 2900                 raise KeyError(key) from err
   2901 
   2902         if tolerance is not None:

KeyError: 168

My data is aggregated by the hour, so a window of 168 rows covers one full week. How can I fit my model properly?

My best guess without the data is that the generator is retrieving the target value by label from the index of y_test, which looks like a pd.Series. After the split, y_test probably still has the old index values from target, because y_test is just a view into a portion of target. If I'm right, the first/smallest index value in your y_test series is probably ~17,512. You can test that hypothesis with:

# Smallest index label in y_test; a value far above 0 would mean y_test kept target's original index.
print(y_test.index.min())

When the generator looks up the label 168 in y_test, it doesn't find it because the index labels start ~17K higher. If that's the case, this should fix it:

# Replace y_test's index with a fresh 0..n-1 range, discarding the old labels (drop=True),
# so that integer lookups such as y_test[168] resolve.
# Re-create test_generator afterwards: the existing one still holds the old y_test object.
y_test = y_test.copy().reset_index(drop=True)

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM