[英]Input shape for CNN is incompatible
我想训练一个模型,从生理信号中预测一个人的情绪。 我有三个物理信号并将其用作输入功能。
心电图(心电图),gsr(皮肤电反应),温度(温度)
在我的数据集中,每个参与者有312条记录,每条记录中都有18000行数据。 因此,当我将它们组合到一个数据框中时,总共有566000行。
这是我的x_train
数据x_train
;
ecg gsr temp
0 0.1912 0.0000 40.10
1 0.3597 0.0000 40.26
2 0.3597 0.0000 40.20
3 0.3597 0.0000 40.20
4 0.3597 0.0000 40.33
5 0.3597 0.0000 40.03
6 0.2739 0.0039 40.13
7 0.1641 0.0031 40.20
8 0.0776 0.0025 40.20
9 0.0005 0.0020 40.26
10 -0.0375 0.0016 40.03
11 -0.0676 0.0013 40.16
12 -0.1071 0.0010 40.20
13 -0.1197 0.0047 40.20
.. ....... ...... .....
.. ....... ...... .....
.. ....... ...... .....
5616000 0.0226 0.1803 38.43
我有6个与情感相对应的课程。 我已经用数字对这些标签进行了编码。
愤怒= 0,平静= 1,厌恶= 2,恐惧= 3,幸福= 4,悲伤= 5
这是我的y_train;
emotion
0 0
1 0
2 0
3 0
4 0
. .
. .
. .
18001 1
18002 1
18003 1
. .
. .
. .
360001 2
360002 2
360003 2
. .
. .
. .
. .
5616000 5
为了提供CNN模型,我需要重塑我的火车示例。 我是这样做的;
train_x = train_x.values.reshape(5616000,3,1) #because I have 5616000 rows and 3 input features
train_y = train_y.values.reshape(5616000,1)
重塑后,我创建了CNN模型;
model = Sequential()
model.add(Conv1D(100,700,activation='relu',input_shape=(5616000,3)))
model.add(Conv1D(100,700,activation='relu'))
model.add(MaxPooling1D(4))
model.add(Conv1D(160,700,activation='relu'))
model.add(Conv1D(160,700,activation='relu'))
model.add(GlobalAveragePooling1D())
model.add(Dropout(0.5))
model.add(Dense(1,activation='softmax'))
model.compile(optimizer = sgd, loss = 'binary_crossentropy', metrics = ['acc'])
model.fit(train_x,train_y,epochs = 300, batch_size = 32, validation_split=0.33, shuffle=False)
这给了我以下错误;
ValueError:检查输入时出错:预期conv1d_96_input具有形状(5616000,3),但具有形状(3,1)的数组
无论我尝试了什么,我都无法使它起作用。 任何帮助表示赞赏,谢谢。
问题出在model.add(Conv1D(100,700,activation='relu',input_shape=(5616000,3)))
。 因为这里input_shape
为(3,1)
,所以您有3个输入input_shape
。
model.add(Conv1D(100,700,activation='relu',input_shape=(3,1)))
并且您有5616000
样本可以用来对您选择的batch_size = 32
进行model.fit
。 因此,在每次迭代中,将从5616000中抽取32个样本并进行训练。
更新1
对于您的用例,而不是使用Conv1D
,可以使用Dense
,因为您只有3个功能。 我建议这样做
train_x = train_x.values.reshape(5616000,3)
对于您的train_y
,您需要预测6个类,因此需要进行一次热编码 。 因此,对于您来说train_y
将是
train_y = keras.utils.to_categorical(train_y.values.reshape(5616000,1), num_classes=6)
然后您的模型将是这样的,
model = Sequential()
model.add(InputLayer(input_shape=(3,)))
model.add(Dense(8,activation='relu'))
model.add(Dense(10,activation='relu'))
model.add(Dense(8,activation='relu'))
model.add(Dense(6,activation='softmax'))
您的用例是多类分类,请在这里找到多类和二进制分类之间的区别。 现在模型编译将是,
model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['acc'])
和适合将是相同的
model.fit(train_x,train_y,epochs = 300, batch_size = 32, validation_split=0.33, shuffle=False)
更新2
如果您认为密集层还不够,也可以尝试这样做,因为您必须在Dense
层中添加更多的神经元或添加更多隐藏层,这样您将获得更多数量的可训练参数,这可能会有所帮助。 因此,您只有三个功能,在Conv1D
,必须根据不这样选择内核大小,即model.add(Conv1D(100,700,activation='relu',input_shape=(3,1)))
。 我认为它也应该起作用。
model=Sequential()
model.add(InputLayer(input_shape=(3,1)))
model.add(Conv1D(100, 2, activation='relu'))
model.add(Conv1D(100, 2, activation='relu'))
model.add(Conv1D(128, 1, activation='relu'))
model.add(Conv1D(128, 1, activation='relu'))
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(6, activation='softmax'))
model.summary()
Layer (type) Output Shape Param #
=================================================================
conv1d_1 (Conv1D) (None, 2, 100) 300
_________________________________________________________________
conv1d_2 (Conv1D) (None, 1, 100) 20100
_________________________________________________________________
conv1d_3 (Conv1D) (None, 1, 128) 12928
_________________________________________________________________
conv1d_4 (Conv1D) (None, 1, 128) 16512
_________________________________________________________________
dropout_1 (Dropout) (None, 1, 128) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 128) 0
_________________________________________________________________
dense_1 (Dense) (None, 6) 774
=================================================================
Total params: 50,614
Trainable params: 50,614
Non-trainable params: 0
_________________________________________________________________
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.