Training using triplet loss: NaN in summary histogram (TensorFlow)

I am trying to train a CNN model using triplet loss. I have images of 8 classes (products), and each class has around 100 images.
The network architecture looks like this:

             input image -> conv1 -> conv2 -> conv3 -> conv4 -> conv5 -> 28-D embedding
             182x182
filters                     7x7      5x5      3x3      1x1      1x1
num_outputs                 32       64       128      256      28
activation                  ReLU     ReLU     ReLU     ReLU     ReLU
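As a quick check on the table, each conv layer has k·k·c_in·c_out + c_out trainable parameters. The sketch below assumes 3-channel (RGB) input and a bias term in every layer. The question states neither. The sketch also ignores whatever pooling or flattening turns the conv5 output into the 28-D embedding.

```python
# (kernel size, output channels) for conv1..conv5, copied from the table above
LAYERS = [(7, 32), (5, 64), (3, 128), (1, 256), (1, 28)]


def conv_param_counts(layers, in_channels=3):
    """Return weights + biases for each conv layer (assumes RGB input and biases)."""
    counts = []
    c_in = in_channels
    for k, c_out in layers:
        counts.append(k * k * c_in * c_out + c_out)  # kernel weights + one bias per filter
        c_in = c_out  # this layer's outputs feed the next layer
    return counts


counts = conv_param_counts(LAYERS)
print(counts, sum(counts))  # [4736, 51264, 73856, 33024, 7196] 170076
```

Under these assumptions, conv2 (the layer named in the error message) holds 51,200 weights plus 64 biases.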

Therefore, the network produces a 28-D embedding. During training, however, it throws the following error at a random iteration step:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: conv2/weights_1

I have played around with some of the hyperparameters, but still no luck. Only the iteration step at which the error is thrown changes. Below are the hyperparameters I'm trying out:

batch size: varied from 2 to 12
learning rate: 0.001 - 0.002
momentum: 0.9 (since the batch size is small)
training iterations: 2000 (training never gets that far; the error is always thrown first)
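The batch sizes listed above are small relative to the 8 classes. This matters for triplet loss because a batch needs at least two images of the same class to form any triplet. The sketch below computes the probability that a batch has no such positive pair. It assumes batches are drawn uniformly at random, without replacement, from exactly 8 × 100 images (the question says "around 100"). The function name is hypothetical.

```python
from fractions import Fraction


def prob_no_positive_pair(batch_size, num_classes=8, per_class=100):
    """P(no two images in a random batch share a class), sampling without replacement."""
    if batch_size > num_classes:
        return Fraction(0)  # pigeonhole: some class must appear at least twice
    total = num_classes * per_class
    prob = Fraction(1)
    for i in range(1, batch_size):
        # i images already drawn, all from different classes; the next one must
        # come from the (num_classes - i) classes not yet represented
        prob *= Fraction(total - i * per_class, total - i)
    return prob


for b in (2, 4, 8, 12):
    print(b, float(prob_no_positive_pair(b)))
```

Under this model, a batch of 2 has no positive pair with probability 700/799 (about 0.88). By the pigeonhole principle, any batch larger than 8 is guaranteed at least one positive pair.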

Any input would be really helpful.

I resolved the issue by using a larger batch size. The error occurs when there are no triplets in a batch.
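The "no triplets in a batch" explanation can be illustrated with a NumPy sketch of a batch-all triplet loss. The question does not say which triplet-loss variant or mining strategy is used, so this variant is an assumption. Averaging over the valid triplets divides 0 by 0 when a batch has none, which produces NaN. In a real training loop, that NaN would propagate through the gradients into the weights, and the summary-histogram check rejects exactly that kind of value. The hypothetical `guard` flag shows one common fix: clamping the denominator to at least 1.

```python
import numpy as np


def batch_all_triplet_loss(embeddings, labels, margin=0.2, guard=True):
    """Mean hinge loss over all valid (anchor, positive, negative) triplets in a batch."""
    # pairwise squared Euclidean distances, shape (B, B)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)

    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    pos_mask = same & ~np.eye(len(labels), dtype=bool)  # (anchor, positive), excluding self
    neg_mask = ~same                                    # (anchor, negative)

    # valid[a, p, n] is True when (a, p) is a positive pair and (a, n) a negative pair
    valid = pos_mask[:, :, None] & neg_mask[:, None, :]

    # loss[a, p, n] = d(a, p) - d(a, n) + margin, clipped at 0 and masked
    loss = np.maximum(dist[:, :, None] - dist[:, None, :] + margin, 0.0) * valid

    num_valid = valid.sum()
    if guard:
        num_valid = max(num_valid, 1)  # avoid 0 / 0 when the batch has no valid triplet
    with np.errstate(invalid="ignore"):
        return loss.sum() / num_valid


rng = np.random.default_rng(0)
emb = rng.normal(size=(4, 28))
labels_no_pairs = [0, 1, 2, 3]  # every image from a different class
print(batch_all_triplet_loss(emb, labels_no_pairs, guard=False))  # nan
print(batch_all_triplet_loss(emb, labels_no_pairs))               # 0.0
```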

This happens when a batch contains no positive pairs. I have created a package for generating compatible batches: https://github.com/ma7555/kerasgen
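A common way to build such batches is P×K sampling: pick P distinct classes, then K images from each class. With K ≥ 2 every batch contains positive pairs, and with P ≥ 2 it contains negatives. The sketch below is a generic plain-Python version of that idea, not the kerasgen API. `pk_batches` and its parameters are hypothetical.

```python
import random
from collections import defaultdict


def pk_batches(labels, p, k, num_batches, seed=None):
    """Yield lists of sample indices: p distinct classes, k distinct samples from each."""
    if p < 2 or k < 2:
        raise ValueError("need p >= 2 (for negatives) and k >= 2 (for positives)")
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # only classes with at least k samples can contribute k distinct images
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k]
    if len(eligible) < p:
        raise ValueError("not enough classes with at least k samples")
    for _ in range(num_batches):
        batch = []
        for c in rng.sample(eligible, p):
            batch.extend(rng.sample(by_class[c], k))
        yield batch


# 8 classes x 100 images, as in the question; P=4, K=3 gives batches of 12
labels = [i // 100 for i in range(800)]
for batch in pk_batches(labels, p=4, k=3, num_batches=2, seed=0):
    print(sorted(labels[i] for i in batch))
```

Fixing P and K also fixes the batch size at P·K, so the triplet count per batch never drops to zero regardless of which classes are drawn.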
