簡體   English   中英

在Tensorflow中的多類別分類中限制輸出類別

[英]Restricting output classes in multi-class classification in Tensorflow

我正在構建雙向LSTM以進行多類句子分類。 我總共有13個類別可供選擇,並且我將LSTM網絡的輸出乘以一個維度為[2*num_hidden_unit,num_classes]的矩陣,然后應用softmax獲得該句子落入13個[2*num_hidden_unit,num_classes]中的1個的概率類。

因此,如果我們將output[-1]視為網絡輸出:

W_output = tf.Variable(tf.truncated_normal([2*num_hidden_unit,num_classes])) result = tf.matmul(output[-1],W_output) + bias

我得到了[1, 13]矩陣(假設我暫時不使用批處理)。

現在,我還知道,給定句子肯定不會屬於給定類別,並且我想限制給定句子考慮的類別數量。 舉例來說,對於給定的句子,我知道它只能分為6類,因此輸出實際上應該是維度矩陣[1,6]

我想到的一個選擇是在result矩陣上放置一個掩碼,在該矩陣上,我將要保留的類對應的行乘以1,將要丟棄的類對應的行乘以0,這樣我將丟失一些信息,而不是重定向信息。

任何人都知道在這種情況下該怎么辦?

我認為,最好的選擇是,如您所描述的,使用加權交叉熵損失函數,其中“不可能的類別”的權重為0,其他可能類別的權重為1。 Tensorflow具有加權交叉熵損失函數。

另一種有趣但可能效果不佳的方法是提供您現在掌握的任何信息,這些信息可以使您的句子在某個時間點(可能即將結束)可以/不能進入網絡。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM