
Guiding tensorflow keras model training to achieve best Recall At Precision 0.95 for binary classification

I am hoping to get some help on the titular topic. I have a database of medical data from patients with two similar pathologies, one severe and one much less so. I need to flag most of the former (≥95%) while leaving out as many of the latter as possible.

Therefore, I want to create a binary classifier that reflects this. Looking around on the web (I am not an expert), I put together the code below, replacing the metric used in the example I found with RecallAtPrecision(0.95) in the compile call. Here is an abridged version:

import tensorflow as tf

# x_train, y_train and EPOCHS are defined elsewhere (code abridged)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(10, input_dim=x_train.shape[1], activation='relu', kernel_initializer='he_normal'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.RecallAtPrecision(0.95)])

history = model.fit(x_train, y_train, validation_split=0.33, batch_size=16, epochs=EPOCHS)

However, it simply doesn't work, as it throws the following error:

AttributeError: module 'tensorflow_core.keras.metrics' has no attribute 'RecallAtPrecision'

I am at a loss as to why that happens, since I can clearly see it in the documentation. The code works if I use Recall(), Precision() or almost any other metric. Looking around some more, I am beginning to think I am missing something fundamental. Do any of you fine ladies and gentlemen have any pointers on how to solve this problem?
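As for the error itself, the `tensorflow_core` module named in the message suggests an older TensorFlow build (that package layout was used around releases 1.15 to 2.1), which likely predates `RecallAtPrecision`; checking `tf.__version__` and upgrading is a reasonable first step. For reference, the metric reports the best recall reachable at any decision threshold whose precision is still at least 0.95. The NumPy sketch below computes that quantity on made-up scores. `recall_at_precision` is a hypothetical helper, not a TensorFlow API, and it tries every distinct score as a threshold, whereas the Keras metric evaluates a fixed grid (`num_thresholds`, 200 by default), so values can differ slightly.

```python
import numpy as np

def recall_at_precision(y_true, y_score, min_precision=0.95):
    """Best recall over all thresholds whose precision is >= min_precision.

    Hypothetical helper for illustration, not a TensorFlow API. It tries every
    distinct score as a threshold; the Keras metric uses a fixed threshold grid.
    Assumes y_true contains at least one positive.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    best = 0.0  # returned if no threshold reaches the target precision
    for t in np.unique(y_score):
        flagged = y_score >= t  # flag every case scoring at or above t
        tp = np.count_nonzero(flagged & y_true)
        fp = np.count_nonzero(flagged & ~y_true)
        fn = np.count_nonzero(~flagged & y_true)
        precision = tp / (tp + fp)  # tp + fp > 0: the case scoring t is flagged
        recall = tp / (tp + fn)
        if precision >= min_precision:
            best = max(best, recall)
    return best

# Made-up labels (1 = severe) and sigmoid scores.
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_score = [0.9, 0.8, 0.7, 0.3, 0.6, 0.2, 0.1, 0.05]
print(recall_at_precision(y_true, y_score, 0.95))  # → 0.75
```

Here the threshold 0.7 flags three severe cases and no mild ones (precision 1.0, recall 0.75); any lower threshold lets a mild case through and drops precision below 0.95.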

To calculate precision and recall, you don't need Keras. If you have your predicted and actual values as vectors of 0/1, you can count TP, FP and FN with tf.math.count_nonzero (tf.count_nonzero in TensorFlow 1.x):

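import tensorflow as tf

# predicted and actual: 0/1 tensors of the same shape and dtype
TP = tf.math.count_nonzero(predicted * actual)
FP = tf.math.count_nonzero(predicted * (actual - 1))
FN = tf.math.count_nonzero((predicted - 1) * actual)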
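The counting trick relies on signs: `actual - 1` is -1 for negatives and 0 for positives, so `predicted * (actual - 1)` is non-zero exactly at false positives, and `(predicted - 1) * actual` is non-zero exactly at false negatives. A quick check on made-up vectors, using NumPy's `np.count_nonzero` as a stand-in for the TensorFlow op so it runs without a model:

```python
import numpy as np

# Made-up 0/1 vectors: model decisions and true labels.
predicted = np.array([1, 1, 0, 1, 0, 0])
actual = np.array([1, 0, 1, 1, 0, 1])

# 1 only where both are 1 (true positive).
TP = np.count_nonzero(predicted * actual)
# -1 only where predicted == 1 and actual == 0 (false positive), else 0.
FP = np.count_nonzero(predicted * (actual - 1))
# -1 only where predicted == 0 and actual == 1 (false negative), else 0.
FN = np.count_nonzero((predicted - 1) * actual)

print(TP, FP, FN)  # → 2 1 2
```

If `predicted` holds sigmoid probabilities rather than 0/1 decisions, threshold it first (for example `(scores >= 0.5).astype(int)`); otherwise the products are non-zero almost everywhere and the counts are meaningless.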

Your metrics are now simple to calculate:

precision = TP / (TP + FP)
recall = TP / (TP + FN)
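One caveat about the goal itself: flagging at least 95% of the severe cases is a constraint on recall, and leaving out as many mild cases as possible means maximizing specificity (or precision) subject to that constraint. That is closer to `tf.keras.metrics.SpecificityAtSensitivity(0.95)` or `PrecisionAtRecall(0.95)` than to `RecallAtPrecision(0.95)`, which constrains precision instead. Using the same TP/FP counts, the NumPy sketch below picks a decision threshold on held-out scores: the highest threshold whose recall still reaches the target, together with the resulting specificity. `threshold_for_recall` and the sample data are hypothetical; in practice `y_score` would come from something like `model.predict` on a validation set.

```python
import numpy as np

def threshold_for_recall(y_true, y_score, min_recall=0.95):
    """Highest threshold whose recall is >= min_recall, with its specificity.

    Hypothetical helper; assumes both classes occur in y_true. Raising the
    threshold never increases recall or false positives, so the highest
    qualifying threshold flags the fewest negatives while meeting the target.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    n_pos = np.count_nonzero(y_true)
    n_neg = y_true.size - n_pos
    for t in np.unique(y_score)[::-1]:  # candidate thresholds, highest first
        flagged = y_score >= t
        recall = np.count_nonzero(flagged & y_true) / n_pos
        if recall >= min_recall:
            fp = np.count_nonzero(flagged & ~y_true)
            specificity = (n_neg - fp) / n_neg
            return float(t), recall, specificity
    # Not reached for min_recall <= 1: the lowest threshold flags everything.
    return None

# Made-up labels (1 = severe) and sigmoid scores.
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_score = [0.9, 0.8, 0.7, 0.3, 0.6, 0.2, 0.1, 0.05]
print(threshold_for_recall(y_true, y_score, 0.95))  # → (0.3, 1.0, 0.75)
```

With only four severe cases, 95% recall means catching all four, which forces the threshold down to 0.3 and lets one mild case through (specificity 0.75).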
