
How to use pre-trained weights for training a convolutional NN in TensorFlow?

In my experiment, I want to train a convolutional NN (CNN) on CIFAR-10 using weights pre-trained on ImageNet, and I chose ResNet50. Since CIFAR-10 is a set of 32x32x3 images while ResNet50 expects 224x224x3 inputs, I need to resize the input images in order to train on top of the ImageNet-pretrained base. I came up with the following attempt to train a simple CNN on top of it:

My current attempt (my whole implementation is in this gist):

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Conv2D, Activation, MaxPooling2D, Flatten, Dense
from tensorflow.keras import models

# ImageNet-pretrained base without its classification head
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = Conv2D(32, (3, 3))(base_model.output)
x = Activation('relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(256)(x)
x = Dense(10)(x)  # 10 CIFAR-10 classes
x = Activation('softmax')(x)
outputs = x
model = models.Model(base_model.input, outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=50, epochs=3, verbose=1, validation_data=(X_test, y_test))

But this attempt gave me a ResourceExhaustedError. I have hit this error before, and reducing batch_size removed it; now, however, even with batch_size as small as possible, I still end up with the error. I am wondering whether the way of training the CNN above is correct, or whether something else is wrong with my attempt.
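One variant I have considered (not part of my original attempt above, so treat it as a sketch) is freezing the pretrained base so that no gradients or optimizer state are kept for its roughly 23M parameters, and using global average pooling instead of Flatten, both of which should lower memory use:

```python
# Sketch: freeze the pretrained base; only the small new head is trained.
# weights=None here only to avoid the ImageNet download in this snippet --
# in practice you would pass weights='imagenet'.
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models

base_model = ResNet50(weights=None, include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # no gradients/optimizer state for base weights

x = layers.GlobalAveragePooling2D()(base_model.output)  # much cheaper than Flatten
x = layers.Dense(256, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)    # 10 CIFAR-10 classes
model = models.Model(base_model.input, outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
```

With the base frozen, only the two Dense layers (about 527k parameters) receive updates, so the optimizer state and gradient buffers shrink accordingly.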

Update:

I want to understand how to use pre-trained weights (i.e., ResNet50 trained on ImageNet) to train a convolutional NN; I am not sure how to get this done in TensorFlow. Can anyone suggest a feasible approach to get this right? Thanks!

Can anyone point out what went wrong with my attempt? What would be the correct way to train a state-of-the-art CNN model on CIFAR-10 starting from ImageNet weights? Can anyone share an efficient way of doing this in TensorFlow? Any idea? Thanks!

You might be getting this error because you are trying to allocate memory (RAM) for the whole dataset at once. For starters, you are probably using a NumPy array to store the images, and those images are then converted to tensors, so you already have 2x the memory even before creating anything else. On top of that, ResNet is a very heavy model, and you are trying to pass the whole dataset through it at once. That is why models work with batches. Try to create a generator using tf.data.Dataset (see its documentation), or use the very easy keras.preprocessing.image.ImageDataGenerator class. You can save the paths of your image files in a DataFrame column, with another column representing the class, and use .flow_from_dataframe; or you can use .flow_from_directory if your images are saved in per-class directories.
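As a sketch of the tf.data approach (with random arrays standing in for CIFAR-10, which you would load via tf.keras.datasets.cifar10.load_data()), the resizing to 224x224 happens per batch inside the pipeline, so the full resized dataset never sits in RAM:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for CIFAR-10 (shape and dtype match the real dataset).
X_train = np.random.randint(0, 256, (100, 32, 32, 3), dtype=np.uint8)
y_train = np.random.randint(0, 10, (100, 1))

def preprocess(image, label):
    # Resize on the fly; only the current batch is ever held at 224x224.
    image = tf.image.resize(tf.cast(image, tf.float32), (224, 224)) / 255.0
    return image, tf.one_hot(tf.squeeze(label), 10)

train_ds = (tf.data.Dataset.from_tensor_slices((X_train, y_train))
            .shuffle(1000)
            .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(8)
            .prefetch(tf.data.AUTOTUNE))
```

You can then pass `train_ds` directly to `model.fit(train_ds, epochs=3)` instead of the in-memory arrays.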

Check out the documentation.
