[英]Transfer learning on keras model always gives same predictions
我正在嘗試使用 keras 應用程序模塊訓練圖像分類器。 當我在驗證集上運行預測時,所有圖像都被預測為同一類。 它並不總是同一個班級,它在訓練期間會有所不同。 我正在使用帶有 ImageNet 權重的 MobileNetV2,但我也嘗試了其他具有相同結果的模型。
我已經嘗試使用 TensorFlow hub 中的模型,如本教程中所述:https ://www.tensorflow.org/beta/tutorials/images/hub_with_keras並且它工作正常,所以它不是數據集問題。
我的代碼片段:
image_size = 224
batch_size = 32
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input)
validation_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input)
train_generator = train_datagen.flow_from_directory(training_data_dir,
target_size=(image_size, image_size),
batch_size=batch_size)
validation_generator = train_datagen.flow_from_directory(validation_data_dir,
target_size=(image_size, image_size),
batch_size=batch_size)
IMG_SHAPE = (image_size, image_size, 3)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights="imagenet")
base_model.trainable = False
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(train_generator.num_classes, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.001),
loss="categorical_crossentropy",
metrics=["accuracy"])
model.summary()
batch_stats = CollectBatchStats()
epoch_stats = CollectEpochStats(model, validation_generator)
checkpoint = tf.keras.callbacks.ModelCheckpoint(...)
epochs = 10
steps_per_epoch = train_generator.n // train_generator.batch_size
validation_steps = validation_generator.n // validation_generator.batch_size
history = model.fit_generator(train_generator,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
callbacks=[batch_stats, epoch_stats, checkpoint],
workers=4,
validation_data=validation_generator,
validation_steps=validation_steps)
問題已解決:在我的代碼中,模型編譯后有以下幾行:
sess = keras_backend.get_session()
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
刪除它們后一切正常。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.