![](/img/trans.png)
[英]training using triplet loss: Nan in summary histogram tensorflow
[英]Loss nan for triplet loss Tensorflow
我正在尝试在 TensorFlow (tensorflow==1.13.1) 中训练三元组损失 model,但在第一次迭代后损失转到 NAN。 我已经从该论坛阅读了很多提示,但到目前为止没有任何效果。 较小的学习率(从 0.01 更改为 0.0001)。 还尝试了更大的批次大小,因为有人建议如果大小太小,批次可能不包含完整的三元组。 但是,我自己用这段代码构建了批次:'''
def get_triplets(self):
a = random.choice(self.anchors)
p = random.choice(self.positives[a])
n = random.choice(self.negatives[a])
return self.images[a], self.images[p], self.images[n]
def get_triplets_batch(self, n):
idxs_a, idxs_p, idxs_n = [], [], []
for _ in range(n):
a, p, n = self.get_triplets()
idxs_a.append(a)
idxs_p.append(p)
idxs_n.append(n)
return idxs_a, idxs_p, idxs_n
我还看到人们在预测中添加一个小数字来避免这个 NAN。 我不知道我是否必须为三重态损失或在哪里这样做。 训练代码是这样的
'''
_, l, summary_str = sess.run([train_step, loss, merged], feed_dict={anchor_input: batch_anchor, positive_input: batch_positive, negative_input: batch_negative})
我是 StackOverflow 的新手,所以我想包括我在 Github 中使用的文件,以防我提供的信息太少。 https://github.com/AndreasVer/Triplet-loss-tensorflow.git
我已经发布了一个 package 用于三元组生成以避免 NaN 损失。 当您的批次没有任何正对时,通常会发生这种情况。
https://github.com/ma7555/kerasgen (免责声明:我是所有者)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.