繁体   English   中英

将千层面转换为 Keras 代码(CNN -> LSTM)

[英]convert Lasagne to Keras code (CNN -> LSTM)

我想转换这个千层面代码:

et = {}
net['input'] = lasagne.layers.InputLayer((100, 1, 24, 113))
net['conv1/5x1'] = lasagne.layers.Conv2DLayer(net['input'], 64, (5, 1))
net['shuff'] = lasagne.layers.DimshuffleLayer(net['conv1/5x1'], (0, 2, 1, 3))
net['lstm1'] = lasagne.layers.LSTMLayer(net['shuff'], 128)

在 Keras 代码中。 目前我想出了这个:

multi_input = Input(shape=(1, 24, 113), name='multi_input')
y = Conv2D(64, (5, 1), activation='relu', data_format='channels_first')(multi_input)
y = LSTM(128)(y)

但我收到错误: Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4

解决方案

from keras.layers import Input, Conv2D, LSTM, Permute, Reshape

multi_input = Input(shape=(1, 24, 113), name='multi_input')
print(multi_input.shape)  # (?, 1, 24, 113)

y = Conv2D(64, (5, 1), activation='relu', data_format='channels_first')(multi_input)
print(y.shape)  # (?, 64, 20, 113)

y = Permute((2, 1, 3))(y)
print(y.shape)  # (?, 20, 64, 113)

# This line is what you missed
# ==================================================================
y = Reshape((int(y.shape[1]), int(y.shape[2]) * int(y.shape[3])))(y)
# ==================================================================
print(y.shape)  # (?, 20, 7232)

y = LSTM(128)(y)
print(y.shape)  # (?, 128)

说明

我把 Lasagne 和 Keras 的文档放在这里,所以你可以做交叉引用:

千层面

除了输入形状预期为(batch_size, sequence_length, num_inputs)之外,循环层可以与前馈层类似地使用

凯拉斯

输入形状

具有形状(batch_size, timesteps, input_dim) 3D 张量。


基本上 API 是一样的,但 Lasagne 可能会为你重塑(我需要稍后检查源代码)。 这就是您收到此错误的原因:

Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4

, 因为Conv2D之后的张量形状是(?, 64, 20, 113) of ndim=4

因此,解决方案是将其重塑为(?, 20, 7232)

编辑

与千层面源代码确认,它为你做的伎俩:

num_inputs = np.prod(input_shape[2:])

所以作为 LSTM 输入的正确张量形状是(?, 20, 64 * 113) = (?, 20, 7232)


笔记

PermutePermute是多余的,因为无论如何你都必须重塑。 我把它放在这里的原因是有一个从 Lasagne 到DimshuffleLaye的“完整翻译”,它做了DimshuffleLaye在 Lasagne 中所做的。

然而,由于我在Edit 中提到的原因,在 Lasagne 中需要DimshuffleLaye ,Lasagne LSTM 创建的新维度来自“最后两个”维度的乘法。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM