简体   繁体   English

在Tensorflow中的多类别分类中限制输出类别

[英]Restricting output classes in multi-class classification in Tensorflow

I am building a bidirectional LSTM to do multi-class sentence classification. 我正在构建双向LSTM以进行多类句子分类。 I have in total 13 classes to choose from and I am multiplying the output of my LSTM network to a matrix whose dimensionality is [2*num_hidden_unit,num_classes] and then apply softmax to get the probability of the sentence to fall into 1 of the 13 classes. 我总共有13个类别可供选择,并且我将LSTM网络的输出乘以一个维度为[2*num_hidden_unit,num_classes]的矩阵,然后应用softmax获得该句子落入13个[2*num_hidden_unit,num_classes]中的1个的概率类。

So if we consider output[-1] as the network output: 因此,如果我们将output[-1]视为网络输出:

W_output = tf.Variable(tf.truncated_normal([2*num_hidden_unit,num_classes])) result = tf.matmul(output[-1],W_output) + bias

and I get my [1, 13] matrix (assuming I am not working with batches for the moment). 我得到了[1, 13]矩阵(假设我暂时不使用批处理)。

Now, I also have information that a given sentence does not fall into a given class for sure and I want to restrict the number of classes considered for a given sentence. 现在,我还知道,给定句子肯定不会属于给定类别,并且我想限制给定句子考虑的类别数量。 So let's say for instance that for a given sentence, I know it can fall only in 6 classes so the output should really be a matrix of dimensionality [1,6] . 举例来说,对于给定的句子,我知道它只能分为6类,因此输出实际上应该是维度矩阵[1,6]

One option I was thinking of is to put a mask over the result matrix where I multiply the rows corresponding to the classes that I want to keep by 1 and the ones I want to discard by 0, by in this way I will just lose some of the information instead of redirecting it. 我想到的一个选择是在result矩阵上放置一个掩码,在该矩阵上,我将要保留的类对应的行乘以1,将要丢弃的类对应的行乘以0,这样我将丢失一些信息,而不是重定向信息。

Anyone has a clue on what to do in this case? 任何人都知道在这种情况下该怎么办?

I think your best bet is, as you seem to have described, using a weighted cross entropy loss function where the weights for your "impossible class" are 0 and 1 for the other possible classes. 我认为,最好的选择是,如您所描述的,使用加权交叉熵损失函数,其中“不可能的类别”的权重为0,其他可能类别的权重为1。 Tensorflow has a weighted cross entropy loss function. Tensorflow具有加权交叉熵损失函数。

Another interesting but probably less effective method is to feed whatever information you now have about what classes your sentence can/cannot fall into the network at some point (probably towards the end). 另一种有趣但可能效果不佳的方法是提供您现在掌握的任何信息,这些信息可以使您的句子在某个时间点(可能即将结束)可以/不能进入网络。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM