繁体   English   中英

Tensorflow/keras 错误:ValueError:检查输入时出错:预期 lstm_input 具有 3 个维度,但得到的数组具有形状(4012、42)

[英]Tensorflow/keras error: ValueError: Error when checking input: expected lstm_input to have 3 dimensions, but got array with shape (4012, 42)

I have a pandas dataframe that is being made by a train_test_split called x_train with 4012 rows all the values in the dataframe are either int's or floats (after train/test split) and 42 columns (after train/test split). 我正在尝试使用 LSTM 单元训练递归神经网络,但是当 model 尝试输入数据帧时,我的程序一直给我一个错误。 我试过重塑 dataframe,这也不起作用,但也许我做错了。 如果您知道如何修复它和/或知道如何使我的 model 更有效,请告诉我。

这是我的代码:

df = pd.read_csv("data.csv")

x = df.loc[:, df.columns != 'result']
y = df.loc[:, df.columns == 'result']

y = y.astype(int)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle = False)

model = Sequential()

model.add(LSTM(128, input_shape = (4012, 42), activation = 'relu', return_sequences = True)) #This is the line where the error happens.
model.add(Dropout(0.2))

model.add(LSTM(128, activation = 'relu'))
model.add(Dropout(0.2))

model.add(Dense(32, activation = 'relu'))
model.add(Dropout(0.2))

model.add(Dense(2, activation = 'softmax'))

opt = tf.keras.optimizers.Adam(lr = 10e-3, decay=1e-5)

model.compile(loss = 'sparse_categorical_crossentropy', optimizer = opt, metrics = ['accuracy'])

model.fit(x_train, y_train, epochs = 5, validation_data = (x_test, y_test))

我得到的错误:

ValueError:检查输入时出错:预期 lstm_input 有 3 个维度,但得到了形状为 (4012、42) 的数组

有人请帮助我。

问题是您正在尝试为 model 提供二维数组,但它需要 3 维数组。 而不是重塑 Dataframe 将其转换为数组,然后根据下面的修改代码重塑。

df = pd.read_csv("data.csv")

x = df.loc[:, df.columns != 'result']
y = df.loc[:, df.columns == 'result']

y = y.astype(int)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle = False)
x_train = x_train.values.reshape((1, x_train.shape[0], x_train.shape[1]))
x_test = x_test.values.reshape((1, x_test.shape[0], x_test.shape[1]))



y_train = y_train.values.reshape((1, y_train.shape[0], y_train.shape[1]))
y_test = y_test.values.reshape((1, y_test.shape[0], y_test.shape[1]))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM