
Unbalanced Binary Classification in Tensorflow

I am trying to perform binary classification with TensorFlow (v1.1.0), using a single neuron at the output layer. The snippet below shows the loss function and optimizer I am currently using (inspired by the answer here).

ratio = 0.034  # minority/population ratio
learning_rate = 0.001
# weight vector (lab_feed holds the one-hot labels)
class_weight = tf.constant([[ratio, 1.0 - ratio]], name='unbalanced_ratio')
weight_per_label = tf.transpose(tf.matmul(lab_feed, tf.transpose(class_weight)), name='weights_per_label')
xent = tf.multiply(weight_per_label, tf.nn.sigmoid_cross_entropy_with_logits(labels=lab_feed, logits=output), name='loss')
loss = tf.reduce_mean(xent)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate, name='GradientDescent').minimize(loss)

My problem, however, is that after a number of epochs every instance ends up classified as the same class. Do I have to stop training early, or is there something wrong with the loss function?


You are misusing sigmoid cross-entropy as if it were softmax cross-entropy.

Sigmoid cross-entropy is suited to binary classification, and your problem is binary classification, so that part is fine. But the output of your net should then have only one channel per binary classification task. You have a single binary classification task, so your net should have exactly one output channel.

To balance a sigmoid cross-entropy, you need to weight each part of the cross-entropy separately, i.e. the part coming from the positive class and the part coming from the negative class. This cannot be done on the output of the loss, as you are doing, because that output is already the sum of the positive and negative parts.
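Written out for a single sample, the two parts look roughly like this. Here z in {0, 1} is the label and x is the logit; the weights w_+ and w_- are notation introduced only for this illustration and do not appear in the question.

```latex
% Unweighted sigmoid cross-entropy: positive part plus negative part
L(x, z) = -\,z \,\log \sigma(x) \;-\; (1 - z)\,\log\bigl(1 - \sigma(x)\bigr),
\qquad \sigma(x) = \frac{1}{1 + e^{-x}}

% Balanced version: each part gets its own weight
L_w(x, z) = -\,w_{+}\, z \,\log \sigma(x) \;-\; w_{-}\,(1 - z)\,\log\bigl(1 - \sigma(x)\bigr)
```

Setting w_- = 1 leaves a single free weight on the positive part.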

Fortunately, TensorFlow has a function that does just that: tf.nn.weighted_cross_entropy_with_logits. It is used like tf.nn.sigmoid_cross_entropy_with_logits, with an additional parameter, pos_weight, that sets the weight of the positive class.
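To make the effect of the positive-class weight concrete, here is a minimal pure-Python sketch of the per-sample formula given in the TensorFlow documentation for that function. The helper name weighted_sigmoid_xent, the toy labels and logits, and the choice pos_weight = (1 - ratio) / ratio are my own assumptions for illustration. The last of these is a common heuristic that only makes sense if the minority class is the one labelled 1.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_sigmoid_xent(label, logit, pos_weight):
    """Per-sample loss (hypothetical helper, not a TensorFlow function):

    pos_weight * label * -log(sigmoid(logit))
        + (1 - label) * -log(1 - sigmoid(logit))
    """
    positive_part = label * softplus(-logit)          # -log(sigmoid(x))
    negative_part = (1.0 - label) * softplus(logit)   # -log(1 - sigmoid(x))
    return pos_weight * positive_part + negative_part


ratio = 0.034  # minority/population ratio from the question
# Assumption: the minority class is the positive class (label 1).
pos_weight = (1.0 - ratio) / ratio

# Toy data: one logit per sample, one 0/1 label per sample (not one-hot).
labels = [1.0, 0.0, 0.0, 0.0]
logits = [0.0, 0.0, -2.0, 1.5]

losses = [weighted_sigmoid_xent(z, x, pos_weight) for z, x in zip(labels, logits)]
mean_loss = sum(losses) / len(losses)
```

In the graph from the question, this would presumably correspond to a single-column output and 0/1 labels of the same shape, with the mean of tf.nn.weighted_cross_entropy_with_logits passed to the optimizer. I have not run that against v1.1.0.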

What you are currently doing is running two binary classifiers on two different channels, sending only the negative samples to the first and only the positive samples to the second. This cannot produce anything useful.
