
Combining CNN with LSTM using TensorFlow Keras

I'm using a pre-trained ResNet-50 model and want to feed the outputs of the penultimate layer to an LSTM network. Here is my sample code containing only the CNN (ResNet-50):

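from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

N = NUMBER_OF_CLASSES  # number of target classes, defined elsewhere
# img_size = (224, 224, 3), same as that of ImageNet
base_model = ResNet50(include_top=False, weights='imagenet', pooling=None)
x = base_model.output
x = GlobalAveragePooling2D()(x)  # (batch, 2048) feature vector per image
predictions = Dense(1024, activation='relu')(x)
model = Model(inputs=base_model.input, outputs=predictions)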

Next, I want to feed it to an LSTM network, as follows:

final_model = Sequential()
final_model.add((model))
final_model.add(LSTM(64, return_sequences=True, stateful=True))
final_model.add(Dense(N, activation='softmax'))

But I'm confused about how to reshape the output for the LSTM input. My original input to the CNN is (224, 224, 3). Also, should I use TimeDistributed?

Any kind of help is appreciated.

Adding an LSTM after a CNN does not make a lot of sense here, as an LSTM is mostly used for temporal/sequence information, whereas your data seems to be only spatial. However, if you would still like to use it, just use

x = Reshape((1024,1))(x)

This would convert it to a sequence of 1024 timesteps, with 1 feature each.
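To make the shapes concrete, here is a minimal sketch of that reshape feeding an LSTM. It is not the asker's full model: a plain `Input` of size 1024 stands in for the CNN's `Dense(1024)` output, and `N_CLASSES` and `reshape_model` are hypothetical names. It also leaves out `stateful=True` and `return_sequences=True` from the question, so that the final `Dense` produces one prediction per image rather than one per timestep.

```python
from tensorflow.keras.layers import Input, Dense, Reshape, LSTM
from tensorflow.keras.models import Model

N_CLASSES = 10  # hypothetical number of classes

# Stand-in for the 1024-dimensional Dense output of the CNN in the question.
features = Input(shape=(1024,))

# (batch, 1024) -> (batch, 1024 timesteps, 1 feature), the 3D shape an LSTM expects.
x = Reshape((1024, 1))(features)

# With the default return_sequences=False, only the last hidden state is returned.
x = LSTM(64)(x)

outputs = Dense(N_CLASSES, activation='softmax')(x)
reshape_model = Model(features, outputs)
```

An LSTM over 1024 timesteps is slow to train, which is another practical reason to prefer the sequence-of-frames approach when the data really is temporal.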

If you are dealing with spatio-temporal data, use TimeDistributed on the ResNet layer, and then you can use ConvLSTM2D.
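As a rough sketch of that suggestion, assuming the input is a short clip of frames: the sizes (`N_FRAMES`, `IMAGE_H`, `IMAGE_W`), `N_CLASSES`, and the pooling/classification head are illustrative assumptions, not part of the original answer. `weights=None` is used only so the sketch builds without downloading anything; in practice you would likely pass `weights='imagenet'` as in the question.

```python
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import (Input, TimeDistributed, ConvLSTM2D,
                                     GlobalAveragePooling2D, Dense)
from tensorflow.keras.models import Model

# Hypothetical sizes chosen to keep the sketch small.
N_FRAMES, IMAGE_H, IMAGE_W, N_CHANNELS = 4, 64, 64, 3
N_CLASSES = 4

frames = Input(shape=(N_FRAMES, IMAGE_H, IMAGE_W, N_CHANNELS))

# weights=None avoids a download here; use 'imagenet' for a pretrained backbone.
resnet = ResNet50(include_top=False, weights=None,
                  input_shape=(IMAGE_H, IMAGE_W, N_CHANNELS))

# Apply the same CNN to every frame, keeping the spatial feature maps.
x = TimeDistributed(resnet)(frames)          # (batch, 4, 2, 2, 2048)

# ConvLSTM2D consumes the sequence of feature maps directly, no flattening needed.
x = ConvLSTM2D(filters=64, kernel_size=(3, 3), padding='same')(x)  # (batch, 2, 2, 64)

x = GlobalAveragePooling2D()(x)              # (batch, 64)
outputs = Dense(N_CLASSES, activation='softmax')(x)
st_model = Model(frames, outputs)
```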

Example of using a pretrained network with an LSTM:

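from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, TimeDistributed, Flatten, LSTM

# config is a user-defined object holding the sequence length and image dimensions
inputs = Input(shape=(config.N_FRAMES_IN_SEQUENCE, config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))
cnn = VGG16(include_top=False, weights='imagenet', input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))
x = TimeDistributed(cnn)(inputs)   # run the CNN on every frame
x = TimeDistributed(Flatten())(x)  # one feature vector per frame
x = LSTM(256)(x)                   # summarize the frame sequence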
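The snippet above stops at the LSTM output. One way to turn it into a trainable classifier is sketched below. The constants replace the `config` object, and `N_CLASSES`, the frozen backbone, and the optimizer/loss choices are assumptions for illustration. `weights=None` keeps the sketch self-contained; swap in `weights='imagenet'` to get the pretrained features the answer refers to.

```python
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, TimeDistributed, Flatten, LSTM, Dense
from tensorflow.keras.models import Model

# Hypothetical values standing in for the config object in the answer.
N_FRAMES, IMAGE_H, IMAGE_W, N_CHANNELS = 8, 64, 64, 3
N_CLASSES = 5

inputs = Input(shape=(N_FRAMES, IMAGE_H, IMAGE_W, N_CHANNELS))

# weights=None avoids a download here; use 'imagenet' for pretrained features.
cnn = VGG16(include_top=False, weights=None,
            input_shape=(IMAGE_H, IMAGE_W, N_CHANNELS))
cnn.trainable = False  # assumption: keep the backbone frozen, train only the head

x = TimeDistributed(cnn)(inputs)     # (batch, 8, 2, 2, 512)
x = TimeDistributed(Flatten())(x)    # (batch, 8, 2048)
x = LSTM(256)(x)                     # (batch, 256)
outputs = Dense(N_CLASSES, activation='softmax')(x)

video_model = Model(inputs, outputs)
video_model.compile(optimizer='adam', loss='categorical_crossentropy',
                    metrics=['accuracy'])
```

Training data for this model would be batches of shape `(batch, N_FRAMES, IMAGE_H, IMAGE_W, N_CHANNELS)` with one-hot labels of length `N_CLASSES`.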
