簡體   English   中英

使用三元組損失進行訓練:摘要直方圖中的 Nan tensorflow

[英]training using triplet loss: Nan in summary histogram tensorflow

我正在嘗試使用三元組損失來訓練 CNN model。 我有 8 個類(產品)的圖片,每個 class 都有大約 100 張圖片
The.network 架構看起來像:

input image   ->          conv1   ->   conv2    ->   conv3   ->   conv4   -> conv5 -> 28D embedding
 182x182     filters       7x7         5x5            3x3          1x1        1x1
             num_outputs    32          64            128          256         28
             activation     Relu       Relu           Relu         Relu       Relu

因此,.network給出了一個28維的embedding。 然而,在訓練期間,它在隨機迭代步驟中拋出以下錯誤:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: conv2/weights_1

我已經嘗試過一些超參數但仍然沒有運氣,只有它拋出錯誤的迭代步驟發生了變化。 以下是超參數,我正在嘗試:

batch size : varied if from 2 to 12
learning rate : 0.001 - 0.002
momentum: 0.9 (since batch size is small)
training iter: 2000 (it is never reaching that, before only throws an error)

任何輸入都會非常有幫助。

我使用更大的批量解決了這個問題。 當批次中沒有三元組時會發生錯誤。

發生這種情況是因為有一批沒有正對。 我創建了一個用於生成兼容批次的包。 https://github.com/ma7555/kerasgen

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM