Training with triplet loss: NaN in summary histogram (TensorFlow)

I am trying to train a CNN model using triplet loss. I have images of 8 classes (products), with around 100 images per class.
The network architecture looks like this:

input image: 182x182

layer   filter size   num_outputs   activation
conv1   7x7           32            ReLU
conv2   5x5           64            ReLU
conv3   3x3           128           ReLU
conv4   1x1           256           ReLU
conv5   1x1           28            ReLU

output: 28-D embedding
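A minimal tf.keras sketch of the layer stack in the table above may help readers reproduce the setup. The question does not state strides, padding, pooling, input channels, or how the conv5 feature map is reduced to a single vector, so RGB input, stride 1 with "same" padding, and a final global average pool are assumptions made here. The final ReLU is kept as listed.

```python
import tensorflow as tf


def build_embedding_net(input_size=182, channels=3, embedding_dim=28):
    # Assumptions not stated in the question: RGB input (channels=3),
    # stride 1 with "same" padding, no pooling between conv layers, and a
    # global average pool that turns conv5's feature map into one vector.
    inputs = tf.keras.Input(shape=(input_size, input_size, channels))
    x = inputs
    # (filter size, num_outputs) for conv1..conv5, taken from the table
    specs = [(7, 32), (5, 64), (3, 128), (1, 256), (1, embedding_dim)]
    for i, (kernel, filters) in enumerate(specs, start=1):
        x = tf.keras.layers.Conv2D(
            filters, kernel, padding="same", activation="relu", name=f"conv{i}"
        )(x)
    # (batch, H, W, embedding_dim) -> (batch, embedding_dim)
    embeddings = tf.keras.layers.GlobalAveragePooling2D(name="embedding")(x)
    return tf.keras.Model(inputs, embeddings)
```

For example, `build_embedding_net().summary()` prints the per-layer shapes. The layer names mirror the conv1 to conv5 scopes in the question, although Keras names the weight tensor `kernel` rather than `weights`.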

The network therefore outputs a 28-D embedding. During training, however, it throws the following error at a random iteration step:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: conv2/weights_1

I have played around with some of the hyperparameters, but still no luck; only the iteration step at which the error is thrown changes. These are the hyperparameters I'm trying out:

batch size: varied from 2 to 12
learning rate: 0.001 to 0.002
momentum: 0.9 (since the batch size is small)
training iterations: 2000 (training never gets that far; the error is thrown before then)
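With 8 classes of around 100 images and the batch sizes of 2 to 12 listed above, a randomly drawn batch can easily contain no two images of the same class, so no anchor-positive pair exists. The sketch below estimates how often that happens, assuming exactly 100 images per class and batches drawn uniformly without replacement (the question does not say how batches are sampled). Because a valid triplet also needs an image of a different class, this probability is only a lower bound on how often a batch has no valid triplet.

```python
from math import prod


def prob_no_positive_pair(batch_size, num_classes=8, per_class=100):
    """Probability that a uniformly drawn batch (without replacement) contains
    no two images of the same class.

    Assumes perfectly balanced classes; the question says "around 100".
    """
    if batch_size > num_classes:
        return 0.0  # pigeonhole: some class must appear at least twice
    total = num_classes * per_class
    # The i-th draw must avoid the i classes already used:
    # (total - i * per_class) favourable images out of (total - i) remaining.
    return prod((total - i * per_class) / (total - i) for i in range(batch_size))


for b in range(2, 13):
    print(b, round(prob_no_positive_pair(b), 4))
```

Under these assumptions a batch of 2 has no positive pair with probability 700/799 (about 0.88), a batch of 4 still fails about 41% of the time, and from a batch size of 9 upward the probability is exactly 0 by the pigeonhole principle.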

Any input would be really helpful.

I resolved the issue by using a larger batch size. The error occurs when a batch contains no valid triplets.
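One plausible way a batch without valid triplets turns into NaN weights: many triplet-loss implementations average over the valid triplets as sum / count, and with count == 0 that is 0 / 0 = NaN. In TensorFlow the NaN loss then yields NaN gradients, the optimizer writes NaN into the weights, and the histogram summary for a variable such as conv2/weights reports it. The minimal NumPy sketch below reproduces the 0 / 0 step for a "batch-all" triplet loss. The batch-all averaging, squared Euclidean distance, margin of 0.2, and the `eps` guard are assumptions for illustration; the question does not say which loss variant was used.

```python
import numpy as np


def batch_all_triplet_loss(embeddings, labels, margin=0.2, eps=None):
    """Mean hinge loss over all valid (anchor, positive, negative) triplets.

    If eps is None the mean is sum / count, which is NaN when count == 0.
    Passing a small eps (e.g. 1e-16) returns 0.0 for such batches instead.
    """
    # Pairwise squared Euclidean distances, shape (B, B)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)

    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    pos = same & not_self  # (a, p): same class, different image
    neg = ~same            # (a, n): different class

    # loss[a, p, n] = d(a, p) - d(a, n) + margin, kept only for valid triplets
    loss = dist[:, :, None] - dist[:, None, :] + margin
    valid = pos[:, :, None] & neg[:, None, :]
    loss = np.maximum(loss, 0.0) * valid

    count = valid.sum()
    if eps is None:
        with np.errstate(invalid="ignore"):
            return loss.sum() / count  # 0 / 0 -> nan when count == 0
    return loss.sum() / (count + eps)


# Two images from different classes: no positive pair, hence no valid triplet
emb = np.array([[0.0, 1.0], [1.0, 0.0]])
print(batch_all_triplet_loss(emb, np.array([0, 3])))             # nan
print(batch_all_triplet_loss(emb, np.array([0, 3]), eps=1e-16))  # 0.0
```

Adding a small epsilon to the denominator, or skipping batches whose valid-triplet count is zero, avoids the NaN. A larger batch, as in the answer, makes such batches rare or impossible instead.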

This happens when a batch has no positive pairs. I have created a package for generating compatible batches: https://github.com/ma7555/kerasgen
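The kerasgen API is not reproduced here. As an independent illustration of what a "compatible batch" can mean, the sketch below uses a common P×K scheme: each batch draws P classes and K images from each class, so with K ≥ 2 every anchor has a positive and with P ≥ 2 every anchor has a negative. The function `pk_batches` and its parameters are hypothetical, not part of kerasgen or TensorFlow.

```python
import random
from collections import defaultdict


def pk_batches(labels, classes_per_batch, images_per_class, num_batches, seed=None):
    """Yield lists of sample indices: `classes_per_batch` classes with
    `images_per_class` distinct images each (hypothetical helper, not kerasgen).

    Validation errors are raised when the generator is first iterated.
    """
    if images_per_class < 2 or classes_per_batch < 2:
        raise ValueError("need >= 2 classes and >= 2 images per class")
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)
    # Only classes with enough images can supply K distinct samples
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= images_per_class]
    if len(eligible) < classes_per_batch:
        raise ValueError("not enough classes with enough images")
    for _ in range(num_batches):
        batch = []
        for c in rng.sample(eligible, classes_per_batch):
            batch.extend(rng.sample(by_class[c], images_per_class))
        yield batch


# 8 classes x 100 images; 4 classes x 3 images = batch size 12
labels = [i // 100 for i in range(800)]
first = next(pk_batches(labels, classes_per_batch=4, images_per_class=3, num_batches=1, seed=0))
print(sorted(labels[i] for i in first))
```

With 4 classes × 3 images, the batch size of 12 stays within the range tried in the question, yet every batch is guaranteed to contain valid triplets.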
