簡體   English   中英

如何正確使用 tfa.metrics.F1Score 和 image_dataset_from_directory?

[英]How to use tfa.metrics.F1Score with image_dataset_from_directory correctly?

Colab 代碼在這里

我正在關注此處的文檔以獲得多類預測的結果

當我訓練使用

#last layer
tf.keras.layers.Dense(2, activation='softmax')

model.compile(optimizer="adam",
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy(),
                       tfa.metrics.F1Score(num_classes=2, average='macro')])

我明白了

144/144 [==] - 8s 54ms/step - loss: 0.0613 - categorical_accuracy: 0.9789 - f1_score: 0.9788 - val_loss: 0.0826 - val_categorical_accuracy: 0.9725 - val_f1_score: 0.9722

當我做:

model.evaluate(val_ds)

我明白了

16/16 [==] - 0s 15ms/step - loss: 0.0826 - categorical_accuracy: 0.9725 - f1_score: 0.9722
[0.08255868405103683, 0.9725490212440491, 0.9722140431404114]

我想在官方網站上使用metric.result 當我加載下面的代碼時,我得到0.4875028這是錯誤的。 如何獲得正確的true_categories predicted_categories

metric = tfa.metrics.F1Score(num_classes=2, average='macro')

predicted_categories = model.predict(val_ds)
true_categories = tf.concat([y for x, y in val_ds], axis=0).numpy() 

metric.update_state(true_categories, predicted_categories)
result = metric.result()
print(result.numpy())

#0.4875028

這是我加載數據的方式

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    main_folder,
    validation_split=0.1,
    subset="training",
    label_mode='categorical',
    seed=123,
    image_size=(dim, dim))

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    main_folder,
    validation_split=0.1,
    subset="validation",
    label_mode='categorical',
    seed=123,
    image_size=(dim, dim))

來自: https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory

tf.keras.preprocessing.image_dataset_from_directory(
    directory, labels='inferred', label_mode='int',
    class_names=None, color_mode='rgb', batch_size=32, image_size=(256,
    256), shuffle=True, seed=None, validation_split=None, subset=None,
    interpolation='bilinear', follow_links=False
)

默認情況下, shuffleTrue ,這對您的val_ds來說是個問題,我們不想隨機播放。

正確的指標是訓練期間報告的指標; 此外,我建議您也可以手動檢索驗證數據集並在對其進行預測后檢查指標(不一定通過flow_from_directory() )。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM