
Tensorflow: Weighted sparse softmax with cross entropy loss

I am doing image segmentation using fully convolutional neural networks (link to the paper): https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf

This can be considered pixel classification (in the end, each pixel gets a label).

I am using the tf.nn.sparse_softmax_cross_entropy_with_logits loss function.

loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=tf.squeeze(annotation, squeeze_dims=[3]),  # drop the channel dim: (batch, width, height)
    name="entropy"))

Everything is going well. However, I noticed that one class occurs in the vast majority of pixels (95%+); call it class 0. Let's say we have three other classes: 1, 2 and 3.

What would be the easiest way to put weights on the classes? Essentially, I would like class 0 to have a very low weight (say 0.1), while the other three classes keep a normal weight of 1.

I know that this function exists: https://www.tensorflow.org/api_docs/python/tf/losses/sparse_softmax_cross_entropy

It just looks to me like it does something entirely different, and I do not understand why the weights should have the same rank as the labels. In my case, the weights should be something like Tensor([0.1, 1, 1, 1]), i.e. shape (4,) and rank 1, while the labels have shape (batch_size, width, height) and thus rank 3. Am I missing something?

The equivalent in PyTorch would be

torch.nn.CrossEntropyLoss(weight=None, size_average=True, ignore_index=-100)

where weight is a torch tensor [0.1, 1, 1, 1].
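For reference, a minimal sketch of how that PyTorch call is typically used for per-pixel labels; the shapes and the [0.1, 1, 1, 1] weights are just illustrative:

import torch

# Per-class weights: class 0 is down-weighted, the other three keep weight 1.
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.1, 1.0, 1.0, 1.0]))

logits = torch.randn(2, 4, 16, 16)          # (batch, num_classes, height, width)
labels = torch.randint(0, 4, (2, 16, 16))   # (batch, height, width) with class ids 0..3
loss = criterion(logits, labels)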

Thanks!

Your guess is correct: the weights parameter in tf.losses.softmax_cross_entropy and tf.losses.sparse_softmax_cross_entropy means the weights across the batch, i.e. it makes some input examples more important than others. There is no out-of-the-box way to weight the loss across classes.

What you can do as a workaround is to pick the weights according to the current labels and use them as batch weights. This means that the weights vector will be different for each batch, but it will try to make the occasional rare classes more important. See the sample code in this question.
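Here is a minimal sketch of that workaround, assuming TF 1.x; the placeholder shapes and the [0.1, 1, 1, 1] class weights are only illustrative, and logits / annotation stand in for the tensors from the question:

import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 256, 256, 4])    # (batch, H, W, num_classes)
annotation = tf.placeholder(tf.int32, [None, 256, 256, 1])  # (batch, H, W, 1)

# Hypothetical per-class weights: class 0 is down-weighted, the rest keep weight 1.
class_weights = tf.constant([0.1, 1.0, 1.0, 1.0])

labels = tf.squeeze(annotation, squeeze_dims=[3])            # (batch, H, W)

# Look up each pixel's class weight, giving a weight tensor with the same
# shape (and rank) as `labels`.
pixel_weights = tf.gather(class_weights, labels)

# tf.losses.sparse_softmax_cross_entropy applies `weights` element-wise before
# averaging, so pixels from the rare classes contribute more to the loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                              logits=logits,
                                              weights=pixel_weights)

Since pixel_weights has the same shape as labels, it satisfies the rank requirement you mention, even though conceptually the weighting is per class.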

Note: since batches do not necessarily contain a uniform class distribution, this trick works poorly with small batch sizes and better with larger ones. When the batch size is 1, it is completely useless. So make the batches as big as possible.
