繁体   English   中英

使用三元组损失进行训练:摘要直方图中的 Nan tensorflow

[英]training using triplet loss: Nan in summary histogram tensorflow

我正在尝试使用三元组损失来训练 CNN model。 我有 8 个类(产品)的图片,每个 class 都有大约 100 张图片
The.network 架构看起来像:

input image   ->          conv1   ->   conv2    ->   conv3   ->   conv4   -> conv5 -> 28D embedding
 182x182     filters       7x7         5x5            3x3          1x1        1x1
             num_outputs    32          64            128          256         28
             activation     Relu       Relu           Relu         Relu       Relu

因此,.network给出了一个28维的embedding。 然而,在训练期间,它在随机迭代步骤中抛出以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: conv2/weights_1

我已经尝试过一些超参数但仍然没有运气,只有它抛出错误的迭代步骤发生了变化。 以下是超参数,我正在尝试:

batch size : varied if from 2 to 12
learning rate : 0.001 - 0.002
momentum: 0.9 (since batch size is small)
training iter: 2000 (it is never reaching that, before only throws an error)

任何输入都会非常有帮助。

我使用更大的批量解决了这个问题。 当批次中没有三元组时会发生错误。

发生这种情况是因为有一批没有正对。 我创建了一个用于生成兼容批次的包。 https://github.com/ma7555/kerasgen

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM