[英]training using triplet loss: Nan in summary histogram tensorflow
我正在尝试使用三元组损失来训练 CNN model。 我有 8 个类(产品)的图片,每个 class 都有大约 100 张图片
The.network 架构看起来像:
input image -> conv1 -> conv2 -> conv3 -> conv4 -> conv5 -> 28D embedding
182x182 filters 7x7 5x5 3x3 1x1 1x1
num_outputs 32 64 128 256 28
activation Relu Relu Relu Relu Relu
因此,.network给出了一个28维的embedding。 然而,在训练期间,它在随机迭代步骤中抛出以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: conv2/weights_1
我已经尝试过一些超参数但仍然没有运气,只有它抛出错误的迭代步骤发生了变化。 以下是超参数,我正在尝试:
batch size : varied if from 2 to 12
learning rate : 0.001 - 0.002
momentum: 0.9 (since batch size is small)
training iter: 2000 (it is never reaching that, before only throws an error)
任何输入都会非常有帮助。
我使用更大的批量解决了这个问题。 当批次中没有三元组时会发生错误。
发生这种情况是因为有一批没有正对。 我创建了一个用于生成兼容批次的包。 https://github.com/ma7555/kerasgen
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.