I'm trying to fit an LSTM model in Keras (R interface) with the following data:
- `y` is the output, with shape (100, 10)
- `x` is the input, with shape (100, 20)
```r
library(keras)

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 1, 20))
y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 1, 10))
```

```
> dim(x_train_arr)
[1] 100   1  20
> dim(y_train_arr)
[1] 100   1  10
```
Then I try to fit the LSTM model:
```r
model <- keras_model_sequential()
model %>%
  layer_lstm(units = 50,
             input_shape = c(1, 10),
             batch_size = 1) %>%
  layer_dense(units = 1)

model %>%
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x = x_train_arr,
              y = y_train_arr,
              batch_size = 1,
              epochs = 10,
              verbose = 1,
              shuffle = FALSE)
```
But I get this error:
```
Error in py_call_impl(callable, dots$args, dots$keywords) :
  ValueError: Error when checking input: expected lstm_21_input to have shape (1, 10) but got array with shape (1, 20)
```
If I change `input_shape` to `c(1, 20)`, I get:
```
Error in py_call_impl(callable, dots$args, dots$keywords) :
  ValueError: Error when checking target: expected dense_13 to have 2 dimensions, but got array with shape (100, 1, 10)
```
I have also tried other settings, but none of them work.
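Both errors are shape mismatches. `layer_lstm()` expects 3D input of shape (samples, timesteps, features), and `input_shape` must equal the last two dimensions of `x`, which here is `c(1, 20)`. With the default `return_sequences = FALSE`, the LSTM returns one vector per sample, so the output of the following dense layer (and therefore the target) is 2D: (samples, units). The sketch below only checks shapes in base R and makes no Keras calls. It assumes you want to keep one time step of 20 features; under that assumption, dropping the length-1 time axis of `y` gives a (100, 10) matrix, and `layer_dense()` would then need `units = 10`.

```r
set.seed(1)
x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 1, 20))
y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(y_train_vec), 1, 10))

# The input_shape given to layer_lstm() must equal dim(x) without the sample axis
input_shape <- dim(x_train_arr)[-1]
print(input_shape)

# Drop the length-1 time axis of the target: (100, 1, 10) -> (100, 10)
y_train_2d <- y_train_arr[, 1, ]
print(dim(y_train_2d))

# Only the dimensions change; the values are the same as the original matrix
stopifnot(identical(y_train_2d, y_train_vec))
```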
If your Keras version is older than 2.0, you need to use `model.add(TimeDistributed(Dense(1)))`.
Note that this is Python syntax; in the R keras package the corresponding wrapper is `time_distributed()`.
I figured out how to make it work:
```r
library(keras)

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
# 100 samples, 20 time steps, 1 feature per step
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 20, 1))
y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
# 2D target: one 10-value vector per sample
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 10))

model <- keras_model_sequential()
model %>%
  layer_lstm(units = 50,
             input_shape = c(20, 1),  # (timesteps, features) = dim(x_train_arr)[-1]
             batch_size = 1) %>%
  layer_dense(units = 10)             # 10 outputs to match ncol(y_train_arr)

model %>%
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x = x_train_arr,
              y = y_train_arr,
              batch_size = 1,
              epochs = 10,
              verbose = 1,
              shuffle = FALSE)
```
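This version works because `x` now has shape (100, 20, 1), so each of the 20 input values is one time step with a single feature, matching `input_shape = c(20, 1)`, and `y` is a (100, 10) matrix, matching the 2D output of `layer_dense(units = 10)`. As a sketch, the same two checks can be run in plain R before calling `fit()`, so a mismatch surfaces as an R error rather than a Python traceback. `check_lstm_shapes()` is a hypothetical helper written for illustration, not part of the keras package, and it assumes a single LSTM layer with the default `return_sequences = FALSE` followed by one dense layer.

```r
# Hypothetical helper (not a keras function): checks that x and y fit
# layer_lstm(input_shape = input_shape) followed by layer_dense(units = units),
# assuming return_sequences = FALSE so the model output is 2D.
check_lstm_shapes <- function(x, y, input_shape, units) {
  if (length(dim(x)) != 3) {
    stop("x must be 3D: (samples, timesteps, features)")
  }
  if (!identical(as.integer(dim(x)[-1]), as.integer(input_shape))) {
    stop(sprintf("input_shape c(%s) does not match dim(x)[-1] c(%s)",
                 paste(input_shape, collapse = ", "),
                 paste(dim(x)[-1], collapse = ", ")))
  }
  if (length(dim(y)) != 2) {
    stop("y must be 2D: (samples, units)")
  }
  if (dim(y)[1] != dim(x)[1] || dim(y)[2] != units) {
    stop(sprintf("y has shape (%d, %d); expected (%d, %d)",
                 dim(y)[1], dim(y)[2], dim(x)[1], units))
  }
  invisible(TRUE)
}

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 20, 1))
y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 10))

# Working configuration: passes silently
check_lstm_shapes(x_train_arr, y_train_arr, input_shape = c(20, 1), units = 10)

# A mismatched input_shape is reported before any model is built
bad <- tryCatch(
  check_lstm_shapes(x_train_arr, y_train_arr, input_shape = c(1, 10), units = 10),
  error = function(e) conditionMessage(e)
)
print(bad)
```

The same helper flags the question's second error: a (100, 1, 10) target fails the "y must be 2D" check.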