
TensorFlow argmax equivalent for multilabel classification

I want to evaluate a TensorFlow classification model.

To compute the accuracy, I have the following code:

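import tensorflow as tf  # TensorFlow 1.x API (tf.metrics)
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
# tf.metrics.accuracy returns an (accuracy, update_op) pair
accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions)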

It works well for single-label classification, but now I want to do multilabel classification, where each label is an array of integers instead of a single integer.

Here is an example label, [0, 1, 1, 0, 1, 0], as stored in label_ids, and an example prediction, [0.1, 0.8, 0.9, 0.1, 0.6, 0.2], from the logits tensor.

What function should I use instead of argmax? (Each label is an array of 6 integers, each either 0 or 1.)

If needed, we can assume a threshold of 0.5.

It is probably better to do this kind of post-processing evaluation outside of TensorFlow, where it is more natural to try several different thresholds.
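A minimal NumPy sketch of that approach, assuming the model's sigmoid outputs and the 0/1 labels for the evaluation set have already been collected into arrays (here just the single example from the question). The helper name per_label_accuracy is hypothetical; it scores each (sample, label) entry separately:

```python
import numpy as np

# Example values from the question: one sample, six labels.
# In practice these arrays would hold the whole evaluation set,
# with shape [num_samples, num_labels].
labels = np.array([[0, 1, 1, 0, 1, 0]])
probs = np.array([[0.1, 0.8, 0.9, 0.1, 0.6, 0.2]])

def per_label_accuracy(labels, probs, threshold):
    """Hypothetical helper: fraction of (sample, label) entries predicted correctly."""
    preds = (probs > threshold).astype(labels.dtype)
    return float(np.mean(preds == labels))

# Try several candidate thresholds on the same outputs.
for threshold in (0.3, 0.5, 0.7):
    print(threshold, per_label_accuracy(labels, probs, threshold))
```

On this single example, thresholds 0.3 and 0.5 both give 1.0, while 0.7 drops to 5/6 because the fifth label's score of 0.6 falls below it.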

If you want to do it in TensorFlow, you can use:

predictions = tf.math.greater(logits, tf.constant(0.5))

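This returns a boolean tensor with the same shape as logits, with True for every entry greater than 0.5. You can then calculate accuracy as before; note that tf.metrics.accuracy compares the two tensors element-wise, so the result is the fraction of individual labels predicted correctly (per-label accuracy), not the fraction of samples whose labels are all correct. This approach suits cases where several labels can be true for the same sample at once.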
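To make the difference between per-label (Hamming) accuracy and the stricter exact-match (subset) accuracy concrete, here is a sketch that mirrors the TensorFlow ops in NumPy. The first row is the question's example; the second row is made up for illustration. In TensorFlow, tf.reduce_all over the label axis would play the role of .all(axis=1). The sketch also assumes the scores are probabilities, as in the question's example; if the model emits raw logits, apply tf.sigmoid first, or equivalently compare the raw logits with 0, since sigmoid(x) > 0.5 exactly when x > 0.

```python
import numpy as np

# Two samples with six labels each. The first row is the question's example;
# the second row is hypothetical, added to show a partly wrong prediction.
labels = np.array([[0, 1, 1, 0, 1, 0],
                   [1, 0, 0, 1, 0, 0]])
probs = np.array([[0.1, 0.8, 0.9, 0.1, 0.6, 0.2],
                  [0.7, 0.4, 0.1, 0.3, 0.2, 0.1]])

# NumPy counterpart of tf.math.greater(logits, 0.5): True where score > 0.5.
preds = probs > 0.5
# Element-wise comparison with the 0/1 labels (viewed as booleans).
correct = preds == labels.astype(bool)

# Per-label (Hamming) accuracy: mean over every (sample, label) entry.
per_label_acc = float(correct.mean())
# Exact-match (subset) accuracy: a sample counts only if all its labels match.
exact_match_acc = float(correct.all(axis=1).mean())

print(per_label_acc, exact_match_acc)
```

Here the per-label accuracy is 11/12 while the exact-match accuracy is 0.5, because the second sample misses one of its two positive labels.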

Use the code below to calculate accuracy in multiclass classification (one-hot labels, one class per sample):

tf.argmax returns the index of the largest value along axis 1, for both y_pred and y_true (the actual labels).

tf.equal then compares those indices element-wise, returning True where the predicted class matches the true class and False otherwise.

Cast the booleans to floats (0.0 or 1.0) and use tf.reduce_mean to get the accuracy, i.e. the fraction of matches.

correct_mask = tf.equal(tf.argmax(y_pred,1), tf.argmax(y_true,1))
accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))
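The argmax approach above assumes exactly one correct class per sample, which is why it does not carry over to the multilabel case in the question. A small NumPy sketch using the question's example shows the difference: argmax keeps only the single highest score, while a 0.5 threshold recovers all three positive labels.

```python
import numpy as np

# Example label and prediction from the question.
label = np.array([0, 1, 1, 0, 1, 0])
pred = np.array([0.1, 0.8, 0.9, 0.1, 0.6, 0.2])

# argmax keeps a single class index: the position of the highest score.
top_class = int(np.argmax(pred))
argmax_as_labels = np.zeros_like(label)
argmax_as_labels[top_class] = 1

# Thresholding keeps every label whose score exceeds 0.5.
threshold_as_labels = (pred > 0.5).astype(label.dtype)

print(top_class, argmax_as_labels, threshold_as_labels)
```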


Example with data:

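import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

y_pred = np.array([[0.1,0.5,0.4], [0.2,0.6,0.2], [0.9,0.05,0.05]])
y_true = np.array([[0,1,0],[0,0,1],[1,0,0]])

# Compare the index of the highest score with the index of the true class
correct_mask = tf.equal(tf.argmax(y_pred,1), tf.argmax(y_true,1))
# Fraction of rows whose predicted class matches the true class
accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))

with tf.Session() as sess:
  # print(sess.run([correct_mask]))
  print(sess.run([accuracy]))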

Output:

[0.6666667]
