简体   繁体   English

在 keras 中实现三元组损失的准确性

[英]Implementing accuracy for triplet loss in keras

I want to implement an accuracy function for a triplet loss network so that I know, how does the algorithm works during the training.我想为三元组损失网络实现一个准确度函数,以便我知道算法在训练过程中是如何工作的。 So far I have tried something, but I'm not sure whether it actually can work and also I have troubles implementing it in keras.到目前为止,我已经尝试了一些东西,但我不确定它是否真的可以工作,而且我在 keras 中实现它也遇到了麻烦。 My idea was to compare the predicted anchor-positive and anchor-negative distances (in y_pred ), so that the positive distance should be low enough and the negative one large enough:我的想法是比较预测的锚正距离和锚负距离(在y_pred ),以便正距离应该足够低,负距离应该足够大:

def accuracy(_, y_pred):
    pos_treshold = 0.4
    neg_treshold = 0.6
    return K.mean(y_pred[0] < pos_treshold and y_pred[1] > neg_treshold)

The problem with this is that I couldn't figure out how to implement this and condition in keras.问题在于我无法弄清楚如何在 keras 中实现这一点and条件。

Then I tried to find something on this topic of accuracy for triplet loss.然后我试图找到关于三元组损失准确性这个主题的东西。 One way of doing it is to define the accuracy as a proportion of the number of triplets in which the predicted distance between the anchor image and the positive image is less than the one between the anchor image and the negative image.一种方法是将准确率定义为锚图像和正图像之间的预测距离小于锚图像和负图像之间的预测距离的三元组数量的比例。 With this I have even bigger problems in implementing it in keras.有了这个,我在 keras 中实现它时遇到了更大的问题。

I tried this (although I don't know whether it does what I described):我试过这个(虽然我不知道它是否符合我的描述):

K.mean(y_pred[0] < y_pred[1])

which gives me accuracy around 0.5 all the time (probably some random stuff).这让我的准确度一直在 0.5 左右(可能是一些随机的东西)。 So still I don't know whether the model is bad or the accuracy function is bad.所以还是不知道是模型不好还是精度函数不好。

So my question is how to implement any reasonable accuracy function in keras?所以我的问题是如何在 keras 中实现任何合理的准确度函数? Whether it would be one of these two I don't really care.是否会是这两者之一,我并不在乎。

That's what I use (condition y_pred[0] < y_pred[1]), while taking into account the batch dimension.这就是我使用的(条件 y_pred[0] < y_pred[1]),同时考虑到批次维度。 Note that I'm not using a mean , so that it would support sample-weight.请注意,我没有使用mean ,因此它将支持样本权重。

def triplet_accuracy(_, y_pred):
    '''
        Input:  y_pred shape is (batch_size, 2)
                [pos, neg]
        Output: shape (batch_size, 1)
                loss[i] = 1 if y_pred[i, 0] < y_pred[i, 1] else 0
    '''

    subtraction = K.constant([-1, 1], shape=(2, 1))
    diff =  K.dot(y_pred, subtraction)
    loss = K.maximum(K.sign(diff), K.constant(0))

    return loss

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM