簡體   English   中英

多元LSTM的keras輸入形狀

[英]keras input shape for multivariate LSTM

我正在嘗試在有兩個輸入的喀拉拉邦中擬合LSTM模型

y是形狀為(100,10)的輸出x是形狀為(100,20)的輸入

library(keras)

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 1, 20))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 1, 10))


> dim(x_train_arr)
[1] 100   1  20
> dim(y_train_arr)
[1] 100   1  10

現在我要擬合LSTM模型

model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(1,10), 
             batch_size       = 1) %>% 
  layer_dense(units = 1)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

但是我得到這個錯誤:

py_call_impl中的錯誤(可調用,dots $ args,dots $ keywords):
ValueError:檢查輸入時出錯:預期lstm_21_input具有形狀(1,10),但具有形狀(1,20)的數組

如果將輸入大小更改為c(1,20),則會得到:

py_call_impl中的錯誤(可調用,dots $ args,dots $ keywords):
ValueError:檢查目標時出錯:預期density_13具有2維,但數組的形狀為(100,1,10)

我也使用了不同的設置,但從未奏效。

如果您的Keras版本小於2.0,則需要使用model.add(TimeDistributed(Dense(1)))。

注意該語法是針對python的,您需要找到R等值。

我想出了使它工作的方法:

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 20, 1))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 10))


model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(20,1), 
             batch_size       = 1) %>% 
  layer_dense(units = 10)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM