Different results for tensorflow evaluate and predict (F1-Score)

I am using TensorFlow 2.5 to evaluate a multiclass classification problem. I am using the F1 score because my dataset is highly imbalanced. The F1 metric I am using is from the tensorflow-addons package. With a binary model everything works fine, but with multiclass models the results and the training behave strangely.

During training and evaluation of the multiclass problem, the F1 score is much higher than it should be. To check whether the score was correct, I used scikit-learn's F1 score metric, and it gave a much more reasonable result. Interestingly, when I manually evaluate the predictions with the tfa F1 metric using update_state(), the score is the same as scikit-learn's. I am not sure why. Is it because evaluate() and fit() work in batches? And how could I overcome this? For evaluation it's not much of a problem, since I can just use predict(). But how can I show a valid F1 score during training?
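A note on the batching guess: as far as I know, tfa's F1Score is a stateful Keras metric that keeps running per-class true-positive, false-positive and false-negative counts across all batches and only turns them into an F1 score in result(). If that holds, batching alone should not change the number. The plain-Python sketch below (no TensorFlow; all function names and the toy labels are made up for illustration) checks that accumulating counts batch by batch gives the same macro F1 as computing it on the whole set, while averaging per-batch F1 scores does not.

```python
def class_counts(y_true, y_pred, num_classes):
    """Per-class true positives, false positives and false negatives
    for integer class labels (one predicted class per sample)."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed
    return tp, fp, fn


def macro_f1(tp, fp, fn):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    scores = []
    for a, b, c in zip(tp, fp, fn):
        denom = 2 * a + b + c
        scores.append(2 * a / denom if denom else 0.0)
    return sum(scores) / len(scores)


def make_batches(y_true, y_pred, size):
    """Split labels and predictions into consecutive batches."""
    return [(y_true[i:i + size], y_pred[i:i + size])
            for i in range(0, len(y_true), size)]


def streaming_macro_f1(batches, num_classes):
    """Accumulate counts over all batches, compute F1 once at the end."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for yt, yp in batches:
        btp, bfp, bfn = class_counts(yt, yp, num_classes)
        tp = [x + y for x, y in zip(tp, btp)]
        fp = [x + y for x, y in zip(fp, bfp)]
        fn = [x + y for x, y in zip(fn, bfn)]
    return macro_f1(tp, fp, fn)


def mean_of_batch_f1(batches, num_classes):
    """Naive alternative: compute F1 per batch and average the scores."""
    return sum(macro_f1(*class_counts(yt, yp, num_classes))
               for yt, yp in batches) / len(batches)


y_true = [0, 1, 2, 2, 1, 0, 2, 1]
y_pred = [0, 2, 2, 1, 1, 0, 0, 1]
batches = make_batches(y_true, y_pred, 3)

whole = macro_f1(*class_counts(y_true, y_pred, 3))
print(whole)                           # 28/45, about 0.6222
print(streaming_macro_f1(batches, 3))  # same value as `whole`
print(mean_of_batch_f1(batches, 3))    # 13/27, about 0.4815, different
```

So a metric that accumulates counts should be insensitive to the batch size, whereas one that averages per-batch scores would not be.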

Example F1-score definition for my 7-class problem:

tfa.metrics.F1Score(num_classes=7, average='macro', threshold=0.5)
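As I understand `threshold=0.5` in this metric, each class probability is compared against 0.5 on its own rather than taking an argmax, so a softmax row whose largest value is below 0.5 predicts no class at all; `average='macro'` then takes the unweighted mean of the per-class F1 scores. Below is a plain-Python sketch of that reading. The function names and example rows are hypothetical, and this is not the tfa source.

```python
def thresholded_macro_f1(y_true_onehot, y_prob, threshold=0.5):
    """Macro F1 where class c counts as predicted iff its probability > threshold."""
    num_classes = len(y_true_onehot[0])
    scores = []
    for c in range(num_classes):
        tp = fp = fn = 0
        for t_row, p_row in zip(y_true_onehot, y_prob):
            is_true = t_row[c] == 1
            is_pred = p_row[c] > threshold
            tp += int(is_true and is_pred)
            fp += int(is_pred and not is_true)
            fn += int(is_true and not is_pred)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


def argmax_onehot(y_prob):
    """Hard one-hot predictions from the most probable class of each row."""
    rows = []
    for row in y_prob:
        best = max(range(len(row)), key=lambda i: row[i])
        rows.append([1 if i == best else 0 for i in range(len(row))])
    return rows


y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
y_prob = [[0.70, 0.20, 0.10],
          [0.40, 0.35, 0.25],  # no class above 0.5 -> nothing predicted
          [0.10, 0.30, 0.60],
          [0.20, 0.60, 0.20]]

print(thresholded_macro_f1(y_true, y_prob))                 # 8/9, about 0.889
print(thresholded_macro_f1(y_true, argmax_onehot(y_prob)))  # 7/9, about 0.778
```

The same rows score differently under thresholding and under argmax, which is worth keeping in mind when comparing this metric against one computed from argmax class labels.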

Training

model.fit(ds.train_ds,validation_data=ds.val_ds,epochs=EPOCHS)
F1: 0.4163

Evaluation results

model.evaluate(ds.test_ds)
F1: 0.44059306383132935

Prediction

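import tensorflow_addons as tfa
y_pred = model.predict(ds.test_ds)  # class probabilities, shape (n_samples, 7)
metric = tfa.metrics.F1Score(num_classes=7, average='macro', threshold=0.5)
metric.update_state(y_true, y_pred)  # y_true: one-hot test labels, shape (n_samples, 7)
result = metric.result()
result.numpy()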
F1: 0.1444352

Scikit-learn evaluation

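from sklearn.metrics import f1_score
# f1_score needs hard predictions, not probabilities: binarize at 0.5 to match threshold=0.5 above
print(f1_score(y_true, (y_pred > 0.5).astype(int), average='macro'))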
F1: 0.1444351874222774

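The problem was that the test dataset was reshuffled after each full iteration, so the labels I had collected in a separate pass no longer lined up with the predictions. Disabling shuffling for the test set led to consistent scores across all evaluation methods.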
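To see why the reshuffling matters: `tf.data.Dataset.shuffle` reshuffles on every full pass by default (`reshuffle_each_iteration=True`). If `y_true` is collected in one pass over `ds.test_ds` and `model.predict` makes another, row i of the labels no longer belongs to row i of the predictions. The plain-Python simulation below uses hypothetical data, a stand-in "perfect" model and a hypothetical `make_pass` helper instead of TensorFlow, and uses accuracy instead of F1 to stay short. It shows that even a perfect model then scores around chance level, while a fixed order, or taking labels and predictions from the same pass, gives the correct result.

```python
import random

# Hypothetical 7-class test set of (features, label) pairs.
examples = [(i, i % 7) for i in range(700)]


def perfect_model(x):
    """Stand-in model that always predicts the correct class."""
    return x % 7


def make_pass(data, shuffle, rng):
    """One full iteration over the data. With shuffle=True every call
    returns a new order, like shuffle(..., reshuffle_each_iteration=True)."""
    order = list(data)
    if shuffle:
        rng.shuffle(order)
    return order


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def two_pass_accuracy(shuffle, seed=0):
    """Labels from one pass, predictions from a second pass."""
    rng = random.Random(seed)
    y_true = [y for _, y in make_pass(examples, shuffle, rng)]
    y_pred = [perfect_model(x) for x, _ in make_pass(examples, shuffle, rng)]
    return accuracy(y_true, y_pred)


def single_pass_accuracy(seed=0):
    """Labels and predictions taken from the same (shuffled) pass."""
    rng = random.Random(seed)
    one_pass = make_pass(examples, True, rng)
    y_true = [y for _, y in one_pass]
    y_pred = [perfect_model(x) for x, _ in one_pass]
    return accuracy(y_true, y_pred)


print(two_pass_accuracy(shuffle=True))   # roughly 1/7, far below 1.0
print(two_pass_accuracy(shuffle=False))  # 1.0
print(single_pass_accuracy())            # 1.0
```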

I simply added an additional parameter to my dataset tuning function:

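def __configureperformance__(self, ds, shuffle=True):
    ds = ds.cache()
    if shuffle:
        # shuffle() reshuffles on every pass by default, so call this with
        # shuffle=False for the validation/test sets to keep their order fixed
        ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(self.batch_size)
    ds = ds.prefetch(buffer_size=self.AUTOTUNE)
    return ds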
