[英]Limit neural network output to subset of trained classes
是否可以將矢量傳遞給訓練好的神經網絡,因此它只能從訓練中識別的類的子集中進行選擇。 例如,我有一個訓練有識別數字和字母的網絡,但我知道我下次運行它的圖像不會包含小寫字母(例如序列號圖像)。 然后我傳遞了一個向量,告訴它不要猜任何小寫字母。 由於類是獨占的,因此網絡以softmax函數結束。 以下只是我想到的嘗試的例子,但沒有一個真正起作用。
import numpy as np
def softmax(arr):
return np.exp(arr)/np.exp(arr).sum()
#Stand ins for previous layer/NN output and vector of allowed answers.
output = np.array([ 0.15885351,0.94527385,0.33977026,-0.27237907,0.32012873,
0.44839673,-0.52375875,-0.99423903,-0.06391236,0.82529586])
restrictions = np.array([1,1,0,0,1,1,1,0,1,1])
#Ideas -----
'''First: Multilpy restricted before sending it through softmax.
I stupidly tried this one.'''
results = softmax(output*restrictions)
'''Second: Multiply the results of the softmax by the restrictions.'''
results = softmax(output)
results = results*restrictions
'''Third: Remove invalid entries before calculating the softmax.'''
result = output*restrictions
result[result != 0] = softmax(result[result != 0])
所有這些都有問題。 第一個導致無效選擇默認為:
1/np.exp(arr).sum()
因為對softmax的輸入可能是負的,這會增加給予無效選擇的概率並使答案變得更糟。 (在我嘗試之前應該調查一下。)
第二個和第三個都有類似的問題,因為他們等到答復給予應用限制之前。 例如,如果網絡正在查看字母l,但它開始確定它是數字1,那么直到最后這些方法才會更正。 因此,如果它是以0.80的概率給出1的輸出,但隨后該選項被刪除,似乎剩余的選項將重新分配,並且最高有效答案將不如80%有信心。 其余選項最終更加同質化。 我想說的一個例子:
output
Out[75]: array([ 5.39413513, 3.81445419, 3.75369546, 1.02716988, 0.39189373])
softmax(output)
Out[76]: array([ 0.70454877, 0.14516581, 0.13660832, 0.00894051, 0.00473658])
softmax(output[1:])
Out[77]: array([ 0.49133596, 0.46237183, 0.03026052, 0.01603169])
(數組被命令使其更容易。)在原始輸出中,softmax給出.70,答案是[1,0,0,0,0],但如果這是一個無效的答案,從而刪除了再分配如何分配剩下的4個概率低於50%的選項,由於使用率太低而很容易被忽略。
我已經考慮過將一個向量作為另一個輸入傳遞到網絡中,但是我不知道如何做到這一點而不需要它知道向量告訴它做什么,我認為這將增加訓練所需的時間。
編輯:我在評論中寫得太多,所以我只是在這里發布更新。 我最終嘗試將限制作為網絡的輸入。 我采用了一個熱編碼的答案並隨機添加了額外的啟用類來模擬答案密鑰,並確保正確的答案始終在密鑰中。 當密鑰具有非常少的啟用類別時,網絡嚴重依賴它並且它干擾圖像的學習特征。 當密鑰有很多啟用的類別時,它似乎完全忽略了密鑰。 這可能是一個需要優化的問題,我的網絡架構問題,或者只需要對培訓進行調整,但我從來沒有解決過這個問題。
我確實發現刪除答案和歸零幾乎是相同的,當我最終減去np.inf
而不是乘以0.我知道合奏但是在第一個回復的評論中提到我的網絡正在處理CJK字符(字母表)只是為了讓例子更容易)並且有3000多個課程。 網絡已經過於笨重,這就是為什么我想研究這種方法。 對於每個單獨的類別使用二進制網絡是我沒有想到的,但3000+網絡似乎也有問題(如果我理解你正確說的話),盡管我可能會在以后查看。
首先,我將松散地瀏覽您列出的可用選項,並添加一些有利可圖的替代方案。 構建這個答案有點難,但我希望你能得到我想要的東西:
顯然,如你所寫的那樣,可能會給出歸零條目提供更高的機會,這似乎是一開始時的錯誤方法。
替代方案:使用smallest
logit值替換不可能的值。 這個類似於softmax(output[1:])
,盡管網絡對結果更加不確定。 示例pytorch
實現:
import torch
logits = torch.Tensor([5.39413513, 3.81445419, 3.75369546, 1.02716988, 0.39189373])
minimum, _ = torch.min(logits, dim=0)
logits[0] = minimum
print(torch.nn.functional.softmax(logits))
產量:
tensor([0.0158, 0.4836, 0.4551, 0.0298, 0.0158])
是的, 當你這樣做時 ,你就是對的 。 更重要的是,這一類的實際概率實際上遠低於14%
( tensor([0.7045, 0.1452, 0.1366, 0.0089, 0.0047])
)。 通過手動更改輸出,您實際上正在破壞NN已經學習的屬性(以及它的輸出分布),使得計算的某些部分毫無意義。 這指出了這次賞金中提到的另一個問題:
我可以想象這會以多種方式解決:
創建多個神經網絡,並通過在末尾采用argmax
(或softmax
然后是`argmax)對logits進行求和來對它們進行整合。 具有不同預測的 3種不同模型的假設情況:
import torch
predicted_logits_1 = torch.Tensor([5.39413513, 3.81419, 3.7546, 1.02716988, 0.39189373])
predicted_logits_2 = torch.Tensor([3.357895, 4.0165, 4.569546, 0.02716988, -0.189373])
predicted_logits_3 = torch.Tensor([2.989513, 5.814459, 3.55369546, 3.06988, -5.89473])
combined_logits = predicted_logits_1 + predicted_logits_2 + predicted_logits_3
print(combined_logits)
print(torch.nn.functional.softmax(combined_logits))
這將在softmax
之后給出以下概率:
[0.11291057 0.7576356 0.1293983 0.00005554 0.]
(注意第一堂課現在最有可能)
您可以使用bootstrap聚合和其他集成技術來改進預測。 這種方法使分類決策表面更平滑,並修復了分類器之間的相互錯誤(假設它們的預測變化很大)。 需要更多細節來描述許多帖子(或者需要具體問題的單獨問題), 這里或這里有一些可能會讓你開始。
我仍然不會將這種方法與手動選擇輸出結合起來。
如果您可以將其分布在多個GPU上,這種方法可能會產生更好的推理時間,甚至可能產生更好的訓練時間。
基本上,你的每一類都可以存在( 1
)或不存在( 0
)。 原則上,您可以為N
類訓練N
神經網絡,每個類輸出一個無界數(logit)。 這個單一的數字告訴網絡是否認為這個例子應該歸類為它的類。
如果你確定某些類不會是結果肯定你沒有運行網絡負責此類檢測 。 在從所有網絡(或網絡子集)獲得預測之后,您選擇最高值(或者如果您使用sigmoid
激活,則為最高概率,盡管這會在計算上浪費)。
如果需要,額外的好處是所述網絡的簡單性(更容易訓練和微調)和switch-like
行為。
如果我是你,我會采用2.2中概述的方法,因為你可以輕松地節省一些推理時間,並允許你以合理的方式“選擇輸出”。
如果這種方法還不夠,你可以考慮網絡的N
集合,所以混合使用2.2和2.1 ,一些自舉或其他集成技術。 這也可以提高您的准確性。
首先問問自己:基於外部數據排除某些輸出的好處是什么。 在你的帖子中,我不明白你為什么要排除它們。
保存它們不會節省計算,因為連接(或神經元)對多個輸出有影響:你不能禁用連接/神經元。
是否真的有必要排除某些類? 如果您的網絡訓練得足夠好,它就會知道它是不是資本。
所以我的答案是:我認為你不應該在softmax 之前進行任何操作。 這會給你錯誤的結論。 所以你有以下選擇:
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.