[英]Error when checking input: expected lstm_input to have 3 dimensions, but got array with shape (5, 10)
[英]Tensorflow/keras error: ValueError: Error when checking input: expected lstm_input to have 3 dimensions, but got array with shape (4012, 42)
I have a pandas dataframe that is being made by a train_test_split called x_train with 4012 rows all the values in the dataframe are either int's or floats (after train/test split) and 42 columns (after train/test split). 我正在尝试使用 LSTM 单元训练递归神经网络,但是当 model 尝试输入数据帧时,我的程序一直给我一个错误。 我试过重塑 dataframe,这也不起作用,但也许我做错了。 如果您知道如何修复它和/或知道如何使我的 model 更有效,请告诉我。
这是我的代码:
df = pd.read_csv("data.csv")
x = df.loc[:, df.columns != 'result']
y = df.loc[:, df.columns == 'result']
y = y.astype(int)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle = False)
model = Sequential()
model.add(LSTM(128, input_shape = (4012, 42), activation = 'relu', return_sequences = True)) #This is the line where the error happens.
model.add(Dropout(0.2))
model.add(LSTM(128, activation = 'relu'))
model.add(Dropout(0.2))
model.add(Dense(32, activation = 'relu'))
model.add(Dropout(0.2))
model.add(Dense(2, activation = 'softmax'))
opt = tf.keras.optimizers.Adam(lr = 10e-3, decay=1e-5)
model.compile(loss = 'sparse_categorical_crossentropy', optimizer = opt, metrics = ['accuracy'])
model.fit(x_train, y_train, epochs = 5, validation_data = (x_test, y_test))
我得到的错误:
ValueError:检查输入时出错:预期 lstm_input 有 3 个维度,但得到了形状为 (4012、42) 的数组
有人请帮助我。
问题是您正在尝试为 model 提供二维数组,但它需要 3 维数组。 而不是重塑 Dataframe 将其转换为数组,然后根据下面的修改代码重塑。
df = pd.read_csv("data.csv")
x = df.loc[:, df.columns != 'result']
y = df.loc[:, df.columns == 'result']
y = y.astype(int)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle = False)
x_train = x_train.values.reshape((1, x_train.shape[0], x_train.shape[1]))
x_test = x_test.values.reshape((1, x_test.shape[0], x_test.shape[1]))
y_train = y_train.values.reshape((1, y_train.shape[0], y_train.shape[1]))
y_test = y_test.values.reshape((1, y_test.shape[0], y_test.shape[1]))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.