繁体   English   中英

三元组损失的损失 nan Tensorflow

[英]Loss nan for triplet loss Tensorflow

我正在尝试在 TensorFlow (tensorflow==1.13.1) 中训练三元组损失 model,但在第一次迭代后损失转到 NAN。 我已经从该论坛阅读了很多提示,但到目前为止没有任何效果。 较小的学习率(从 0.01 更改为 0.0001)。 还尝试了更大的批次大小,因为有人建议如果大小太小,批次可能不包含完整的三元组。 但是,我自己用这段代码构建了批次:'''

def get_triplets(self):
    a = random.choice(self.anchors)
    p = random.choice(self.positives[a])
    n = random.choice(self.negatives[a])

    return self.images[a], self.images[p], self.images[n]

def get_triplets_batch(self, n):
    idxs_a, idxs_p, idxs_n = [], [], []
    for _ in range(n):
        a, p, n = self.get_triplets()
        idxs_a.append(a)
        idxs_p.append(p)
        idxs_n.append(n)
    return idxs_a, idxs_p, idxs_n

我还看到人们在预测中添加一个小数字来避免这个 NAN。 我不知道我是否必须为三重态损失或在哪里这样做。 训练代码是这样的

'''

_, l, summary_str = sess.run([train_step, loss, merged], feed_dict={anchor_input: batch_anchor, positive_input: batch_positive, negative_input: batch_negative})

我是 StackOverflow 的新手,所以我想包括我在 Github 中使用的文件,以防我提供的信息太少。 https://github.com/AndreasVer/Triplet-loss-tensorflow.git

我已经发布了一个 package 用于三元组生成以避免 NaN 损失。 当您的批次没有任何正对时,通常会发生这种情况。

https://github.com/ma7555/kerasgen (免责声明:我是所有者)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM