繁体   English   中英

在 Keras Tensorflow 中将 TimeseriesGenerator 与多元数据集结合使用

[英]Using TimeseriesGenerator with a multivariate dataset in Keras Tensorflow

我正在尝试对 Keras 中 TimeseriesGenerator 的输出进行建模,该输出将用作 LSTM 网络的输入,但一直面临问题。 数据集具有以下结构:

在此处输入图片说明

其中特征集显示为绿色(F1 到 F6),目标变量 (T) 显示为红色。 我将包含 3170 个观测值的总数据集划分为三组:

在此处输入图片说明

由于 Keras 中的 LSTM 需要三个维度的输入大小,因此我使用以下命令重塑了数据集:

        train= train.reshape((train_df.shape[0], 1, train_df.shape[1]))
        validation= validation.reshape((validation.shape[0], 1, validation.shape[1]))
        test= test.reshape((test.shape[0], 1, test.shape[1]))

因此,重构数据集的大小如下:

在此处输入图片说明

其中三个维度是(样本、时间步长、特征)。 但真正的问题是现在将数据集传递给 keras 中的时间序列生成器。 使用的生成器代码如下:

        generator = TimeseriesGenerator(train, train_target, length=1, batch_size=10)

TimeseriesGenerator 将数据集传递给 fit_generator,如下所示:

        model.fit_generator(generator, validation_data=(validation, validation_target),
                                       epochs=100, verbose=0,
                                       shuffle=False, workers=1, use_multiprocessing=True)

而我在 Keras 中的 LSTM 网络配置如下:

                model = Sequential()
                model.add(LSTM(200, input_shape=(10, 6), return_sequences=True))
                model.add(LSTM(200, input_shape=(10, 6), return_sequences=False))
                model.add(Dense(1, kernel_initializer='uniform', activation='linear')) 

第一个 LSTM 层的 input_shape 是(10,6) ,这意味着 10 个具有 6 个特征的样本/观察。 我选择 (10,6) 的 input_shape 是因为 TimeseriesGenerator 应该生成 10 个样本的 batch_size,每个样本具有 6 个特征。

但是,这会导致如下错误:

ValueError: Error when checking input: expected lstm_input to have 3 dimensions, but got array with shape (10, 1, 1, 6)

TimeseriesGenerator 生成训练集的输入大小为(10, 1, 1, 6) 生成的训练数据集有四个维度,但是,我希望 TimeseriesGenerator 生成 10 个样本的 batch_size,每个样本具有 6 个特征,即输入大小为(10,1,6)的训练数据集。

如何让 TimeseriesGenerator 生成(10,1,6)的输入大小?

错误是...不要重塑您的训练、验证和测试数据...生成器会自动将其设为 (10, 1,6)..因为当您指定批次大小为 10 时,它将需要 10 个批次.. . 长度为 1 个数据和 6 个特征的...试试吧。 它会工作..

只需在 (m, n) 生成器中提供您的训练、验证、测试数据即可自行重塑

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM