
Correcting input dimensions for a CNN + LSTM based classifier using Keras (Python)

I'm working on implementing a 2D (perhaps 1D) CNN + LSTM classifier for network traffic classification. The CNN essentially acts as a feature extractor and the LSTM handles the classification.

I have used the TimeDistributed layer to combine the CNN and LSTM layers (code attached). Since the input size varies dynamically, the number of data points is indicated with None. The per-sample dimensions are no_rows = 20 (number of packets considered per flow for classification) and no_cols = 7 (number of features considered for each packet).
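
For reference, my understanding is that a TimeDistributed(Conv2D) front end feeding an LSTM expects 5-D input of shape (batch_size, time_steps, no_rows, no_cols, channels), with input_shape given to the TimeDistributed wrapper itself. Below is a minimal standalone sketch of that wiring (the time_steps handling, layer sizes and class count are illustrative assumptions, not my actual setup):

import numpy as np
from keras.models import Sequential
from keras.layers import TimeDistributed, Conv2D, GlobalMaxPooling2D, LSTM, Dense

no_rows, no_cols = 20, 7     # packets per flow, features per packet

sketch = Sequential()
# input_shape goes on the TimeDistributed wrapper: (time_steps, rows, cols, channels)
sketch.add(TimeDistributed(Conv2D(32, kernel_size=(4, 2), activation='relu'),
                           input_shape=(None, no_rows, no_cols, 1)))
sketch.add(TimeDistributed(GlobalMaxPooling2D()))   # -> (batch, time_steps, 32)
sketch.add(LSTM(128))
sketch.add(Dense(3, activation='softmax'))          # 3 classes, purely illustrative

sketch.compile(loss='categorical_crossentropy', optimizer='adam')
# dummy 5-D batch: 4 samples, 5 time steps each
dummy = np.random.rand(4, 5, no_rows, no_cols, 1)
print(sketch.predict(dummy).shape)                  # (4, 3)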

Despite using the TimeDistributed wrapper, I am facing input dimension issues and am not quite sure how to resolve them.

Adding a Reshape layer was one of the many fixes I came across, but it didn't work. Kindly let me know how to build this structure and how to fix my code.

Thanks!

(Using a Linux-based AWS instance running Ubuntu 16.04, with the TensorFlow backend.)

  1. Used the Reshape layer from Keras core layers to fix the output of the CNN, but this did not resolve the issue.
  2. Had to remove the Flatten layer and replace it with a GlobalMaxPooling2D layer because of the dynamically changing input size (see the small check after this list).
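
For reference, the reason GlobalMaxPooling2D can stand in for Flatten here is that it collapses whatever spatial size the convolution produces into one value per filter, so its output length stays fixed. A tiny standalone check (the filter count is arbitrary):

from keras.models import Sequential
from keras.layers import Conv2D, GlobalMaxPooling2D

# Unlike Flatten, GlobalMaxPooling2D gives a fixed-length vector (one value
# per filter) no matter what the spatial dimensions of its input are.
probe = Sequential()
probe.add(Conv2D(64, kernel_size=(4, 2), input_shape=(None, None, 1)))
probe.add(GlobalMaxPooling2D())
print(probe.output_shape)    # (None, 64)
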
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import (TimeDistributed, Conv2D, BatchNormalization,
                          GlobalMaxPooling2D, LSTM, Dropout, Dense)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.30, random_state = 36)

model = Sequential()

# Adding CNN Model layers

model.add(TimeDistributed(Conv2D(32, kernel_size = (4 , 2), strides = 1, padding='valid', activation = 'relu', input_shape = (None,20,7,1))))
model.add(TimeDistributed(BatchNormalization()))
model.add(TimeDistributed(Conv2D(64, kernel_size = (4 , 2), strides = 1, padding='valid', activation = 'relu')))
model.add(TimeDistributed(BatchNormalization()))
#model.add(TimeDistributed(Reshape((-1,1))))
model.add(TimeDistributed(GlobalMaxPooling2D()))
#model.add(Reshape((1,1)))

# Adding LSTM layers

model.add(LSTM(128, recurrent_dropout=0.2))
model.add(Dropout(rate = 0.2))
model.add(Dense(100))
model.add(Dropout(rate = 0.4))
model.add(Dense(108))

model.add(Dense(num_classes,activation='softmax'))


# Compiling this model

model.compile(loss = 'categorical_crossentropy', optimizer='adam',metrics = ['accuracy'])
#print(model.summary())

#Training the Network

history = model.fit(X_train, Y_train, batch_size=32, epochs = 1, validation_data=(X_test, Y_test))

Running the code snippet above, I get the following error message:

"Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
ValueError: Input tensor must be of rank 3, 4 or 5 but was 2.
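
For context, this front end expects fit to receive a 5-D array of shape (samples, time_steps, no_rows, no_cols, channels), so the rank-2 error suggests the data being fed in has been flattened. A quick check-and-reshape sketch (the time_steps value here is an assumption for illustration):

import numpy as np

time_steps = 1               # assumed; however many flows are grouped per sequence
no_rows, no_cols = 20, 7

print(X.shape)               # inspect the current rank first

# e.g. if X is (num_samples, 140) or (num_samples, 20, 7),
# bring it to the 5-D layout the TimeDistributed stack expects
X = X.reshape(-1, time_steps, no_rows, no_cols, 1)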

One thing you can do is build batches with a fixed number of frames from your video source and process those. The code for doing that would be:

import cv2
import numpy as np

def get_data(video_source, batch_size, frame_to_process):
    x_data = []
    # Reading the video from the file path
    cap = cv2.VideoCapture(video_source)
    for i in range(batch_size):
        # To store the frames of one sample
        frames = []
        for j in range(frame_to_process):  # read frame_to_process frames per sample
            ret, frame = cap.read()
            if not ret:
                # print('No frames found!')
                break
            # converting the frame to grayscale
            # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            # resizing the frame to a particular input shape
            # frame = cv2.resize(frame, (30, 30), interpolation=cv2.INTER_AREA)
            frames.append(frame)
        # appending each sample to the batch
        x_data.append(frames)
    return x_data

# number of frames per sample
frame_to_process = 30
# number of samples in each batch
batch_size = 32
# build a batch of video inputs (video_source is the path to the video file)
X_data = np.array(get_data(video_source, batch_size, frame_to_process))

Tip: also, instead of using Conv2D wrapped in TimeDistributed, you can use a ConvLSTM layer (ConvLSTM2D in Keras), which can give a small performance improvement.
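
For reference, a minimal sketch of that alternative with Keras' ConvLSTM2D layer (filter count, kernel size and class count here are illustrative only):

from keras.models import Sequential
from keras.layers import ConvLSTM2D, GlobalMaxPooling2D, Dense

convlstm = Sequential()
# same 5-D input layout as before: (time_steps, rows, cols, channels)
convlstm.add(ConvLSTM2D(32, kernel_size=(4, 2), padding='valid',
                        return_sequences=False, input_shape=(None, 20, 7, 1)))
convlstm.add(GlobalMaxPooling2D())              # collapse the remaining spatial dims
convlstm.add(Dense(10, activation='softmax'))   # 10 classes, purely illustrative
convlstm.summary()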

Anyway, if you want to process frames dynamically, you can port the code to PyTorch, which uses dynamic graphs and lets you feed inputs with a variable batch size.
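
As a rough illustration of that point, a per-frame CNN feeding an LSTM in PyTorch can be run on sequences of any length without redefining the model (the layer sizes and input shapes below are assumptions for illustration):

import torch
import torch.nn as nn

class CNNLSTM(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(4, 2)), nn.ReLU(),
            nn.AdaptiveMaxPool2d(1))           # -> (N*T, 32, 1, 1)
        self.lstm = nn.LSTM(32, 128, batch_first=True)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):                      # x: (N, T, 1, H, W), T can vary
        n, t = x.shape[:2]
        feats = self.conv(x.flatten(0, 1)).view(n, t, 32)
        out, _ = self.lstm(feats)
        return self.fc(out[:, -1])             # classify from the last time step

model = CNNLSTM()
for t in (5, 12):                              # different sequence lengths, same model
    print(model(torch.randn(2, t, 1, 20, 7)).shape)   # torch.Size([2, 10])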
