[英]Keras: reshape to connect lstm and conv
这个问题也存在于github问题中 。 我想在Keras中构建一个包含2D卷积和LSTM层的神经网络。
网络应该对MNIST进行分类。 MNIST中的训练数据是从0到9的60000个手写数字的灰度图像。每个图像是28×28像素。
我将图像分成四个部分(左/右,上/下),并按四个顺序重新排列,以获得LSTM的序列。
| | |1 | 2|
|image| -> ------- -> 4 sequences: |1|2|3|4|, |4|3|2|1|, |1|3|2|4|, |4|2|3|1|
| | |3 | 4|
其中一个小子图像的尺寸为14×14。四个序列沿宽度堆叠在一起(无论宽度或高度都无关紧要)。
这将创建一个形状为[60000,4,1,56,14]的向量,其中:
现在应该给Keras模型。 问题是改变CNN和LSTM之间的输入尺寸。 我在网上搜索并发现了这个问题: Python keras如何在卷积层之后将输入的大小更改为lstm层
该解决方案似乎是一个Reshape图层,它使图像变平,但保留了时间步长(与Flatten图层相反,它会折叠除了batch_size之外的所有内容)。
到目前为止,这是我的代码:
nb_filters=32
kernel_size=(3,3)
pool_size=(2,2)
nb_classes=10
batch_size=64
model=Sequential()
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
border_mode="valid", input_shape=[1,56,14]))
model.add(Activation("relu"))
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Reshape((56*14,)))
model.add(Dropout(0.25))
model.add(LSTM(5))
model.add(Dense(50))
model.add(Dense(nb_classes))
model.add(Activation("softmax"))
此代码创建一条错误消息:
ValueError:新数组的总大小必须保持不变
显然,Reshape图层的输入不正确。 作为替代方案,我也尝试将时间步长传递给Reshape图层:
model.add(Reshape((4,56*14)))
这感觉不对,在任何情况下,错误都保持不变。
我这样做是对的吗? Reshape图层是连接CNN和LSTM的合适工具吗?
这个问题有相当复杂的方法。 例如: https : //github.com/fchollet/keras/pull/1456一个TimeDistributed Layer,它似乎隐藏了后续层的时间步长度。
或者: https : //github.com/anayebi/keras-extra一组用于组合CNN和LSTM的特殊层。
为什么有这么复杂(至少它们对我来说似乎很复杂)的解决方案,如果一个简单的重塑成功呢?
更新 :
令人尴尬的是,我忘记了尺寸将通过汇集和(因为没有填充)卷积而改变。 kgrm建议我使用model.summary()
来检查尺寸。
model.add(Reshape((32*26*5,)))
图层之前的图层输出为(None, 32, 26, 5)
model.add(Reshape((32*26*5,)))
(None, 32, 26, 5)
,我将model.add(Reshape((32*26*5,)))
更改为: model.add(Reshape((32*26*5,)))
。
现在ValueError消失了,相反LSTM抱怨:
例外:输入0与层lstm_5不兼容:预期ndim = 3,发现ndim = 2
好像我需要通过整个网络传递时间步长维度。 我怎样才能做到这一点 ? 如果我将它添加到Convolution的input_shape,它也会抱怨: Convolution2D(nb_filters, kernel_size[0], kernel_size[1], border_mode="valid", input_shape=[4, 1, 56,14])
例外:输入0与图层卷积2d_44不兼容:预期ndim = 4,发现ndim = 5
根据Convolution2D定义,您的输入必须是4维的维度(samples, channels, rows, cols)
。 这是您收到错误的直接原因。
要解决这个问题,您必须使用TimeDistributed wrapper。 这允许您在整个时间内使用静态(非重复)层。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.