
Combine CNN with LSTM

I'm looking to implement an RNN together with a CNN so that a prediction is based on two images instead of a single image, as with the CNN alone. I'm trying to modify the AlexNet model code:

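import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from tflearn.layers.normalization import local_response_normalization

def alexnet(width, height, lr, output=3):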
    network = input_data(shape=[None, width, height, 1], name='input')
    network = conv_2d(network, 96, 11, strides=4, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = conv_2d(network, 256, 5, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = conv_2d(network, 384, 3, activation='relu')
    network = conv_2d(network, 384, 3, activation='relu')
    network = conv_2d(network, 256, 3, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = fully_connected(network, 4096, activation='tanh')
    network = dropout(network, 0.5)
    network = fully_connected(network, 4096, activation='tanh')
    network = dropout(network, 0.5)
    network = fully_connected(network, output, activation='softmax')
    network = regression(network, optimizer='momentum',
                         loss='categorical_crossentropy',
                         learning_rate=lr, name='targets')

    model = tflearn.DNN(network, checkpoint_path='model_alexnet',
                        max_checkpoints=1, tensorboard_verbose=0, tensorboard_dir='log')

    return model

I have my images in a NumPy array where each element holds the pixel data of one image. I'm having trouble making the RNN use two images.

I've seen the reshape and lstm methods of tflearn, which I believe should be placed before the final fully connected layer, but I'm not sure how to specify the number of images to use.

Also, would this be easier to implement with Keras?

If I understood you correctly, you need to do the following. Let model be the network that takes a series of images as input and returns the predictions. Using the functional API, this schematically looks as follows:

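    import tensorflow as tf
    from tensorflow import keras

    def create_model():
        # number_of_images and image_shape are placeholders,
        # e.g. 2 and (height, width, 1)
        input_data = keras.Input(shape=(number_of_images, *image_shape))
        ### processing part: compute predictions from input_data ###
        model = keras.Model(input_data, predictions)
        return model

    model = create_model()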
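To feed such a model, the images first have to be grouped so that each sample has shape (number-of-images, shape-of-images). A minimal NumPy sketch, assuming the images are stored in temporal order in one array of shape (N, height, width), as described in the question, and that each sample should be a window of consecutive images; make_sequences is a hypothetical helper, not part of any library:

```python
import numpy as np


def make_sequences(images, number_of_images=2):
    """Group consecutive images into overlapping sequences (hypothetical helper).

    images: array of shape (N, height, width), assumed to be in temporal order.
    Returns shape (N - number_of_images + 1, number_of_images, height, width, 1).
    """
    images = np.asarray(images)
    n = len(images) - number_of_images + 1
    # Stack shifted slices: sequence j holds images j, j+1, ..., j+number_of_images-1.
    sequences = np.stack([images[i:i + n] for i in range(number_of_images)], axis=1)
    # Add a trailing channel axis for grayscale images.
    return sequences[..., np.newaxis]
```

With number_of_images=2 this yields the pairs (0, 1), (1, 2), and so on. How the labels are aligned with each window (for example with its last image) depends on the task.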

In the processing part, you want to obtain an encoding for each image and then analyze the encodings as a sequence using an RNN.

As the first step, you need to obtain encodings for all of the images. Let encoder be the network that encodes individual images, returning enc_dim-dimensional encodings. To obtain the encodings for all images efficiently, note that during training the model processes data of shape (batch-size, number-of-images, shape-of-images). In total there are therefore total-number-of-images = (batch-size) x (number-of-images) images. To process them in one pass, reshape input_data to shape (total-number-of-images, shape-of-images) as follows:

    input_data_reshaped = tf.reshape(input_data, (-1, *image_shape))

and pass the result through the encoder:

    image_encodings_flattened = encoder(input_data_reshaped)

This produces an output of shape (total-number-of-images, enc_dim). To process the encodings as sequences, you need to restore the batch-size dimension, which is easily done:

    image_encodings = tf.reshape(image_encodings_flattened, (-1, number_of_images, enc_dim))

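As expected, this reshapes the data to (batch-size, number-of-images, enc_dim). (Both reshapes use tf.reshape, which works on tf.keras symbolic tensors; under Keras 3, keras.ops.reshape is the equivalent.) This data can be processed directly by an RNN layer or a stack of them. For example, for a single LSTM layer,

    rnn_analyzer = tf.keras.layers.LSTM(units)  # units: size of the LSTM state

the encoding of each image sequence is obtained as follows:

    rnn_encodings = rnn_analyzer(image_encodings)

rnn_encodings, of shape (batch-size, units), can then be passed through dense layers to make the final predictions.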
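The two reshapes are safe because merging the first two axes and splitting them again (both in row-major order, as tf.reshape does) keeps every encoding attached to its own sample and position. This can be checked with plain NumPy; the sizes are arbitrary and toy_encoder is a hypothetical stand-in for the CNN (mean pixel value times a fixed vector):

```python
import numpy as np

batch_size, number_of_images, image_shape, enc_dim = 3, 2, (4, 4, 1), 5


def toy_encoder(images):
    """Hypothetical stand-in for the CNN: (total, *image_shape) -> (total, enc_dim)."""
    means = images.reshape(len(images), -1).mean(axis=1)
    return means[:, np.newaxis] * np.arange(1, enc_dim + 1)


rng = np.random.default_rng(0)
input_data = rng.random((batch_size, number_of_images, *image_shape))

# (batch, n, *image_shape) -> (batch * n, *image_shape)
input_data_reshaped = input_data.reshape((-1, *image_shape))
image_encodings_flattened = toy_encoder(input_data_reshaped)  # (batch * n, enc_dim)
# Restore the batch axis: (batch * n, enc_dim) -> (batch, n, enc_dim)
image_encodings = image_encodings_flattened.reshape((-1, number_of_images, enc_dim))

# Encoding j of sample b equals encoding the image input_data[b, j] on its own.
for b in range(batch_size):
    for j in range(number_of_images):
        single = toy_encoder(input_data[b, j][np.newaxis])[0]
        assert np.allclose(image_encodings[b, j], single)
```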

Putting these steps into the processing part of create_model gives the network you are after.
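Since the question also asks about Keras, here is a self-contained tf.keras sketch of the whole model. It is an illustration under assumptions, not the answer's exact code: the sizes (two 64x64 grayscale images per sample, 3 classes, ENC_DIM=128, 64 LSTM units) and the small CNN encoder are hypothetical stand-ins for the AlexNet stack. Instead of the two manual reshapes it uses keras.layers.TimeDistributed, which applies the same encoder, with shared weights, to every image in the sequence and computes the same result:

```python
import numpy as np
from tensorflow import keras

# Hypothetical sizes: 2 grayscale 64x64 images per sample, 3 output classes.
NUM_IMAGES = 2
IMAGE_SHAPE = (64, 64, 1)
ENC_DIM = 128
NUM_CLASSES = 3


def create_encoder():
    """Small CNN mapping one image to an ENC_DIM-dimensional encoding
    (a stand-in for the AlexNet-style stack from the question)."""
    inputs = keras.Input(shape=IMAGE_SHAPE)
    x = keras.layers.Conv2D(32, 5, strides=2, activation="relu")(inputs)
    x = keras.layers.MaxPooling2D(2)(x)
    x = keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(ENC_DIM, activation="relu")(x)
    return keras.Model(inputs, outputs, name="encoder")


def create_model():
    input_images = keras.Input(shape=(NUM_IMAGES, *IMAGE_SHAPE))
    # Same encoder applied to each image:
    # (batch, NUM_IMAGES, 64, 64, 1) -> (batch, NUM_IMAGES, ENC_DIM)
    image_encodings = keras.layers.TimeDistributed(create_encoder())(input_images)
    # LSTM reads the encodings as a sequence and returns its last state: (batch, 64)
    rnn_encodings = keras.layers.LSTM(64)(image_encodings)
    predictions = keras.layers.Dense(NUM_CLASSES, activation="softmax")(rnn_encodings)
    model = keras.Model(input_images, predictions)
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model


model = create_model()
```

Training then works with the usual model.fit(x, y), where x has shape (samples, 2, 64, 64, 1) and y is one-hot encoded with shape (samples, 3).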
