繁体   English   中英

Conv1D 层 Keras 的 input_shape

[英]input_shape of Conv1D layer Keras

我正在尝试制作一个 CNN model 用于非图像数据集的二进制分类。 我的模型/代码正在运行并产生非常好的结果(精度很高),但我无法理解 Conv1D 第一层的input_shape Conv1D

X 或输入(此处为x_train_df )的形状为 (2000, 28)。 它有 28 个特征和 2000 个样本。 Y 或标签的形状(此处为y_train_df )为 (2000, 1)。

model = Sequential()
model.add(Conv1D(filters = 64, kernel_size = 3, activation = 'relu', input_shape = (x_train_df.shape[1], 1)))
model.add(Conv1D(filters = 64, kernel_size = 3, activation = 'relu'))
model.add(MaxPooling1D(pool_size = 2))
model.add(Flatten())
model.add(Dense(100, activation = 'relu'))
model.add(Dense(1, activation = 'sigmoid'))

optimzr = Adam(learning_rate=0.005)
model.compile(loss='binary_crossentropy', optimizer=optimzr,  metrics=[[tf.keras.metrics.AUC(curve="ROC", name = 'auc')], [tf.keras.metrics.AUC(curve="PR", name = 'pr')]])

# running the fitting
model.fit(x_train_df, y_train_df, epochs = 2, batch_size = 32, validation_data = (x_val_df, y_val_df), verbose = 2)

我已将input_shape为 (28, 1) (参考this question )。

但是在Conv1D 层文档中写道,

当将此层用作 model 中的第一层时,为128 维向量的 10 个向量的序列提供一个 input_shape 参数(整数元组或无,例如 (10, 128) 。

我从中了解到的是 input_shape 的维度应该是 (2000, 1) 因为我有 2000 个一维向量。 但是将此作为input_shape显示为错误,

ValueError:层“sequential_25”的输入 0 与层不兼容:预期形状 =(无,2000,1),找到形状 =(无,28)

所以我的问题是正确的input_shape应该是什么?

让我们看看“Conv1D”是如何接受输入的。

>>> # The inputs are 128-length vectors with 10 timesteps, and the batch size
>>> # is 4.
>>> input_shape = (4, 10, 128)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv1D(
... 32, 3, activation='relu',input_shape=input_shape[1:])(x)
>>> print(y.shape)
    (4, 8, 32)

具有形状的 3+D 张量:batch_shape + (steps, input_dim)

如上所示,有 128 个特征、10 个时间步长和 4 个批量大小。因此,Conv1D 将输入作为 (batch_size,timesteps,features)。 它需要 3D 输入。 假设您为您的案例选择批量大小为 1。 你必须提供像 (1,2000,28) 这样的输入。

我也将 CNN-Conv1D 用于非图像数据集。 我有 188 个特征。 所以我对我的输入这样做: X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1],1))

#在我的初始化RNN中

model.add(重塑((188, 1), input_shape(X_train[1], 1))

所以输入形状 (None, 188, 1)

这对我有用,我得到了 97% 的训练和 91% 的用真实数据测试数据。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM