繁体   English   中英

多元LSTM的keras输入形状

[英]keras input shape for multivariate LSTM

我正在尝试在有两个输入的喀拉拉邦中拟合LSTM模型

y是形状为(100,10)的输出x是形状为(100,20)的输入

library(keras)

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 1, 20))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 1, 10))


> dim(x_train_arr)
[1] 100   1  20
> dim(y_train_arr)
[1] 100   1  10

现在我要拟合LSTM模型

model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(1,10), 
             batch_size       = 1) %>% 
  layer_dense(units = 1)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

但是我得到这个错误:

py_call_impl中的错误(可调用,dots $ args,dots $ keywords):
ValueError:检查输入时出错:预期lstm_21_input具有形状(1,10),但具有形状(1,20)的数组

如果将输入大小更改为c(1,20),则会得到:

py_call_impl中的错误(可调用,dots $ args,dots $ keywords):
ValueError:检查目标时出错:预期density_13具有2维,但数组的形状为(100,1,10)

我也使用了不同的设置,但从未奏效。

如果您的Keras版本小于2.0,则需要使用model.add(TimeDistributed(Dense(1)))。

注意该语法是针对python的,您需要找到R等值。

我想出了使它工作的方法:

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 20, 1))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 10))


model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(20,1), 
             batch_size       = 1) %>% 
  layer_dense(units = 10)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM