I have a time series data set with 3 measurement variables and about 2000 samples. I want to classify samples into 1 of 4 categories using an RNN or 1D CNN model built with Keras in R. My problem is that I am unable to successfully reshape the data with the k_reshape() function.
I am following along with Ch. 6 of Deep Learning with R by Chollet & Allaire, but their examples are different enough from my data set that I'm now confused. I've tried to mimic the code from that chapter of the book, to no avail. Here's a link to the source code for the chapter.
```r
library(keras)

df <- data.frame()
for (i in c(1:20)) {
  time <- c(1:100)
  var1 <- runif(100)
  var2 <- runif(100)
  var3 <- runif(100)
  run <- data.frame(time, var1, var2, var3)
  run$sample <- i
  run$class <- sample(c(1:4), 1)
  df <- rbind(df, run)
}
head(df)
# time var1 var2 var3 sample class
# 1 0.4168828 0.1152874 0.0004415961 1 4
# 2 0.7872770 0.2869975 0.8809415097 1 4
# 3 0.7361959 0.5528836 0.7201276931 1 4
# 4 0.6991283 0.1019354 0.8873193581 1 4
# 5 0.8900918 0.6512922 0.3656302236 1 4
# 6 0.6262068 0.1773450 0.3722923032 1 4

k_reshape(df, shape(10, 100, 3))
# Error in py_call_impl(callable, dots$args, dots$keywords) :
# TypeError: Failed to convert object of type <class 'dict'> to Tensor. Contents: {'time': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...
```
I'm very new to reshaping arrays, but I would like to have an array with the shape (samples, time, features). I would love to hear suggestions on how to properly reshape this array, or guidance on how this data should be treated for a DL model if I'm off base on that front.
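One way to see what the target (samples, time, features) layout means is to build it directly in base R, one sample at a time. The sketch below is illustrative only: df_to_array() is a hypothetical helper (not part of keras or reticulate), and it assumes every sample has the same number of time steps.

```r
# Hypothetical helper: turn a long data frame (one row per sample x time step)
# into a (samples, time, features) array. Assumes equal-length series.
df_to_array <- function(df, sample_col, time_col, feature_cols) {
  # sort so each sample's rows are contiguous and in time order
  df <- df[order(df[[sample_col]], df[[time_col]]), ]
  counts <- table(df[[sample_col]])
  stopifnot(length(unique(as.vector(counts))) == 1)  # all samples same length
  samples <- unique(df[[sample_col]])
  n_time <- as.integer(counts[[1]])
  out <- array(NA_real_, dim = c(length(samples), n_time, length(feature_cols)))
  for (k in seq_along(samples)) {
    rows <- df[df[[sample_col]] == samples[k], feature_cols]
    out[k, , ] <- as.matrix(rows)  # n_time x n_features slice for sample k
  }
  out
}

# Usage with data shaped like the example above
set.seed(1)
df <- do.call(rbind, lapply(1:20, function(i) {
  data.frame(time = 1:100, var1 = runif(100), var2 = runif(100),
             var3 = runif(100), sample = i, class = sample(1:4, 1))
}))
x <- df_to_array(df, "sample", "time", c("var1", "var2", "var3"))
dim(x)
# [1]  20 100   3
```

Sorting by sample and time first means the helper does not depend on the rows already being in order, which a positional reshape does.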
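I found two solutions to my question. My confusion stemmed from the error message from k_reshape(), which I did not understand how to interpret: the data frame was passed to Python as a dict of columns, which cannot be converted to a tensor. Converting the data frame to a numeric matrix first (and keeping only the three feature columns) fixes that, after which either of these works:

1. Use the array_reshape() function from the reticulate package.
2. Use the k_reshape() function from keras again, but this time with the appropriate shape, c(20, 100, 3) (20 samples, 100 time steps, 3 features).

Here is the code I successfully executed: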
```r
# generate data frame
dat <- data.frame()
for (i in c(1:20)) {
  time <- c(1:100)
  var1 <- runif(100)
  var2 <- runif(100)
  var3 <- runif(100)
  run <- data.frame(time, var1, var2, var3)
  run$sample <- i
  run$class <- sample(c(1:4), 1)
  dat <- rbind(dat, run)  # append this sample's rows to dat
}
head(dat)
# time var1 var2 var3 sample class
# 1 0.4168828 0.1152874 0.0004415961 1 4
# 2 0.7872770 0.2869975 0.8809415097 1 4
# 3 0.7361959 0.5528836 0.7201276931 1 4
# 4 0.6991283 0.1019354 0.8873193581 1 4
# 5 0.8900918 0.6512922 0.3656302236 1 4
# 6 0.6262068 0.1773450 0.3722923032 1 4
dat_m <- as.matrix(dat)  # convert data frame to a numeric matrix

# solution with reticulate's array_reshape function
# (columns 2:4 are var1, var2, var3)
dat_array <- reticulate::array_reshape(x = dat_m[, c(2:4)], dim = c(20, 100, 3))
dim(dat_array)
# [1] 20 100 3
class(dat_array)
# [1] "array"

# solution with keras's k_reshape
dat_array_2 <- keras::k_reshape(x = dat_m[, c(2:4)], shape = c(20, 100, 3))
dim(dat_array_2)
# [1] 20 100 3
class(dat_array_2)
# [1] "tensorflow.tensor" "tensorflow.python.framework.ops.Tensor"
# [3] "tensorflow.python.framework.ops._TensorLike" "python.builtin.object"
```
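Why does array_reshape() give the right result where a plain array() call would not? R fills arrays in column-major order, whereas array_reshape() defaults to row-major ("C") order, like NumPy. The base-R sketch below uses a small integer matrix as a stand-in for dat_m[, 2:4] (the names m, x and x_naive are illustrative). It shows that naively setting the dimensions scrambles samples, and that building a (time, samples, features) array and then swapping the first two axes with aperm() should reproduce the row-major layout for data stacked sample by sample.

```r
# Stand-in for dat_m[, 2:4]: 20 samples x 100 time steps stacked
# sample by sample (2000 rows), 3 feature columns, filled with 1..6000
n_samples <- 20
n_time <- 100
n_features <- 3
m <- matrix(seq_len(n_samples * n_time * n_features), ncol = n_features)

# Naive reshape: array() fills column-major, so x_naive[i, step, ]
# is NOT row (i - 1) * n_time + step of m
x_naive <- array(m, dim = c(n_samples, n_time, n_features))

# Row-major equivalent for this layout: fill as (time, sample, feature),
# then swap the first two axes to get (sample, time, feature)
x <- aperm(array(m, dim = c(n_time, n_samples, n_features)), c(2, 1, 3))

i <- 3     # sample index
step <- 7  # time step
target_row <- (i - 1) * n_time + step  # row 207 of m
identical(x[i, step, ], m[target_row, ])
# [1] TRUE
identical(x_naive[i, step, ], m[target_row, ])
# [1] FALSE
```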
A note on the two results: array_reshape() returns an ordinary R array, whereas k_reshape() returns a TensorFlow tensor. Both worked for me when creating deep learning networks, but I find the array much easier to interpret.
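Whichever form the inputs take, a classifier also needs one label per sample rather than one per row. A minimal base-R sketch, assuming (as in the generated data) that class is constant within each sample, is shown below. It one-hot encodes the labels, which is the format a 4-unit softmax output layer trained with categorical_crossentropy expects. keras::to_categorical() can do the same encoding, but it expects 0-based labels, so it would be called as to_categorical(y - 1, num_classes = 4).

```r
# Small example in the same long layout as the generated data:
# 20 samples x 100 time steps, with one class (1..4) per sample
set.seed(42)
dat <- do.call(rbind, lapply(1:20, function(i) {
  data.frame(time = 1:100, var1 = runif(100), var2 = runif(100),
             var3 = runif(100), sample = i, class = sample(1:4, 1))
}))

# One label per sample, ordered by sample id so it lines up with the
# first axis of the (samples, time, features) array
per_sample <- dat[!duplicated(dat$sample), c("sample", "class")]
per_sample <- per_sample[order(per_sample$sample), ]
y <- per_sample$class  # length 20, values in 1..4

# Base-R one-hot encoding: row k of diag(4) is the indicator vector for class k
y_onehot <- diag(4)[y, ]
dim(y_onehot)
# [1] 20  4
```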