简体   繁体   English

多元LSTM的keras输入形状

[英]keras input shape for multivariate LSTM

I'm trying to fit a LSTM model in keras where I have two inputs 我正在尝试在有两个输入的喀拉拉邦中拟合LSTM模型

y is the output with shape (100,10) x is the input with shape (100,20) y是形状为(100,10)的输出x是形状为(100,20)的输入

library(keras)

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 1, 20))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 1, 10))


> dim(x_train_arr)
[1] 100   1  20
> dim(y_train_arr)
[1] 100   1  10

Now I want to fit the LSTM model 现在我要拟合LSTM模型

model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(1,10), 
             batch_size       = 1) %>% 
  layer_dense(units = 1)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

But I get this error: 但是我得到这个错误:

Error in py_call_impl(callable, dots$args, dots$keywords) : py_call_impl中的错误(可调用,dots $ args,dots $ keywords):
ValueError: Error when checking input: expected lstm_21_input to have shape (1, 10) but got array with shape (1, 20) ValueError:检查输入时出错:预期lstm_21_input具有形状(1,10),但具有形状(1,20)的数组

If I change input size to c(1,20), I get: 如果将输入大小更改为c(1,20),则会得到:

Error in py_call_impl(callable, dots$args, dots$keywords) : py_call_impl中的错误(可调用,dots $ args,dots $ keywords):
ValueError: Error when checking target: expected dense_13 to have 2 dimensions, but got array with shape (100, 1, 10) ValueError:检查目标时出错:预期density_13具有2维,但数组的形状为(100,1,10)

I also played with different setting but it never works. 我也使用了不同的设置,但从未奏效。

IF your Keras version is < 2.0 you need to use model.add(TimeDistributed(Dense(1))). 如果您的Keras版本小于2.0,则需要使用model.add(TimeDistributed(Dense(1)))。

NOTE that syntax is for python, you need to find the R equivealent. 注意该语法是针对python的,您需要找到R等值。

I figured out how to make it work: 我想出了使它工作的方法:

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 20, 1))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 10))


model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(20,1), 
             batch_size       = 1) %>% 
  layer_dense(units = 10)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 正确的kerasR中LSTM的input_shape - Correct input_shape for an LSTM in kerasR ValueError:输入 0 与层 lstm_2 不兼容:预期 ndim=3,发现 ndim=4 - 多元时间序列数据 - ValueError: Input 0 is incompatible with layer lstm_2: expected ndim=3, found ndim=4 - multivariate timeseries data 输入形状(以喀拉斯计)(此损失期望目标与输出具有相同的形状) - Input shape in keras (This loss expects targets to have the same shape as the output) Keras LSTM和多输入功能:如何定义参数 - Keras LSTM and multiple input feature: how to define parameters 张量形状在keras模型的输入形状中重要吗? R编程 - Does Tensor Shape matter in the input shape of keras model? R-Programming 使用keras 1D卷积层在R(Rstudio)中设置NLP任务的输入形状,当它需要3维输入(张量)时 - Setting input shape for an NLP task in R(Rstudio) using keras 1D convolution layer, when it expects 3 dimensional input (a tensor) 提供数据框作为多元函数的输入 - Supply data frame as input to multivariate function Keras R中图像分类模型中的形状误差 - Shape error in image classification model in Keras R 如何使用Keras IN R实现简单,基本的多步LSTM? - How to implement a simple and basic multi step LSTM with Keras IN R? LSTM Keras中的生成器功能可输出一个文件的小批量 - Generator function in LSTM Keras for outputting mini batches of one files
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM