簡體   English   中英

python中CNN多類圖像分類的邊界框預測

[英]Bounding box prediction on CNN multiple class image classification in python

我對4種類型的特定對象進行了訓練和測試。 我也有csv格式的裝訂框條件/關注區域坐標(x,y,w,h)。 該項目的主要目的是預測測試圖像的類別以及感興趣區域周圍的邊框,並在圖像上打印類別名稱。

我已經基於keras庫應用了CNN模型。 它對測試集的給定圖像進行分類。 為了預測給定測試圖像的邊界框坐標,我應該改變什么?

        from keras.models import Sequential
        from keras.layers import Convolution2D
        from keras.layers import MaxPooling2D
        from keras.layers import Flatten
        from keras.layers import Dense

        #CNN initializing
        classifier= Sequential()

        #convolutional layer
        classifier.add(Convolution2D(filters = 32, kernel_size=(3,3), data_format= "channels_last", input_shape=(64, 64, 3), activation="relu"))

        #Pooling
        classifier.add(MaxPooling2D(pool_size=(2,2)))

        #addition of second convolutional layer
        classifier.add(Convolution2D(filters = 32, kernel_size=(3,3), data_format= "channels_last", activation="relu"))
        classifier.add(MaxPooling2D(pool_size=(2,2)))

        #step 3 - FLatttening
        classifier.add(Flatten())

        #step 4 - Full connection layer
        classifier.add(Dense(128, input_dim = 11, activation = 'relu'))
        #output layer
        classifier.add(Dense(units = 4, activation = 'sigmoid'))

        #compiling the CNN
        classifier.compile(optimizer='adam',loss="categorical_crossentropy",metrics =["accuracy"])

        #part 2 -Fitting the CNN to the images


        from keras.preprocessing.image import ImageDataGenerator

        train_datagen = ImageDataGenerator(rescale = 1./255,
                                           shear_range = 0.2,
                                           zoom_range = 0.2,
                                           horizontal_flip = True)

        test_datagen = ImageDataGenerator(rescale = 1./255)

        training_set = train_datagen.flow_from_directory('dataset/Train',
                                                         target_size = (64, 64),
                                                         batch_size = 32,
                                                         class_mode = 'categorical')

        test_set = test_datagen.flow_from_directory('dataset/Test',
                                                    target_size = (64, 64),
                                                    batch_size = 32,
                                                    class_mode = 'categorical')

        classifier.fit_generator(training_set,
                                 steps_per_epoch =4286/32,
                                 epochs = 25,
                                 validation_data = test_set,
                                 validation_steps = 44/32)

您描述的任務是對象檢測,通常需要更復雜的CNN模型。 檢查https://github.com/fizyr/keras-retinanet以了解著名的神經網絡架構之一。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM