[英]training using triplet loss: Nan in summary histogram tensorflow
我正在嘗試使用三元組損失來訓練 CNN model。 我有 8 個類(產品)的圖片,每個 class 都有大約 100 張圖片
The.network 架構看起來像:
input image -> conv1 -> conv2 -> conv3 -> conv4 -> conv5 -> 28D embedding
182x182 filters 7x7 5x5 3x3 1x1 1x1
num_outputs 32 64 128 256 28
activation Relu Relu Relu Relu Relu
因此,.network給出了一個28維的embedding。 然而,在訓練期間,它在隨機迭代步驟中拋出以下錯誤:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: conv2/weights_1
我已經嘗試過一些超參數但仍然沒有運氣,只有它拋出錯誤的迭代步驟發生了變化。 以下是超參數,我正在嘗試:
batch size : varied if from 2 to 12
learning rate : 0.001 - 0.002
momentum: 0.9 (since batch size is small)
training iter: 2000 (it is never reaching that, before only throws an error)
任何輸入都會非常有幫助。
我使用更大的批量解決了這個問題。 當批次中沒有三元組時會發生錯誤。
發生這種情況是因為有一批沒有正對。 我創建了一個用於生成兼容批次的包。 https://github.com/ma7555/kerasgen
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.