簡體   English   中英

在 Keras Tensorflow 中將 TimeseriesGenerator 與多元數據集結合使用

[英]Using TimeseriesGenerator with a multivariate dataset in Keras Tensorflow

我正在嘗試對 Keras 中 TimeseriesGenerator 的輸出進行建模,該輸出將用作 LSTM 網絡的輸入,但一直面臨問題。 數據集具有以下結構:

在此處輸入圖片說明

其中特征集顯示為綠色(F1 到 F6),目標變量 (T) 顯示為紅色。 我將包含 3170 個觀測值的總數據集划分為三組:

在此處輸入圖片說明

由於 Keras 中的 LSTM 需要三個維度的輸入大小,因此我使用以下命令重塑了數據集:

        train= train.reshape((train_df.shape[0], 1, train_df.shape[1]))
        validation= validation.reshape((validation.shape[0], 1, validation.shape[1]))
        test= test.reshape((test.shape[0], 1, test.shape[1]))

因此,重構數據集的大小如下:

在此處輸入圖片說明

其中三個維度是(樣本、時間步長、特征)。 但真正的問題是現在將數據集傳遞給 keras 中的時間序列生成器。 使用的生成器代碼如下:

        generator = TimeseriesGenerator(train, train_target, length=1, batch_size=10)

TimeseriesGenerator 將數據集傳遞給 fit_generator,如下所示:

        model.fit_generator(generator, validation_data=(validation, validation_target),
                                       epochs=100, verbose=0,
                                       shuffle=False, workers=1, use_multiprocessing=True)

而我在 Keras 中的 LSTM 網絡配置如下:

                model = Sequential()
                model.add(LSTM(200, input_shape=(10, 6), return_sequences=True))
                model.add(LSTM(200, input_shape=(10, 6), return_sequences=False))
                model.add(Dense(1, kernel_initializer='uniform', activation='linear')) 

第一個 LSTM 層的 input_shape 是(10,6) ,這意味着 10 個具有 6 個特征的樣本/觀察。 我選擇 (10,6) 的 input_shape 是因為 TimeseriesGenerator 應該生成 10 個樣本的 batch_size,每個樣本具有 6 個特征。

但是,這會導致如下錯誤:

ValueError: Error when checking input: expected lstm_input to have 3 dimensions, but got array with shape (10, 1, 1, 6)

TimeseriesGenerator 生成訓練集的輸入大小為(10, 1, 1, 6) 生成的訓練數據集有四個維度,但是,我希望 TimeseriesGenerator 生成 10 個樣本的 batch_size,每個樣本具有 6 個特征,即輸入大小為(10,1,6)的訓練數據集。

如何讓 TimeseriesGenerator 生成(10,1,6)的輸入大小?

錯誤是...不要重塑您的訓練、驗證和測試數據...生成器會自動將其設為 (10, 1,6)..因為當您指定批次大小為 10 時,它將需要 10 個批次.. . 長度為 1 個數據和 6 個特征的...試試吧。 它會工作..

只需在 (m, n) 生成器中提供您的訓練、驗證、測試數據即可自行重塑

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM