
Implementing accuracy for triplet loss in keras

I want to implement an accuracy function for a triplet loss network, so that I know how the algorithm performs during training. So far I have tried something, but I'm not sure whether it can actually work, and I also have trouble implementing it in Keras. My idea was to compare the predicted anchor-positive and anchor-negative distances (in y_pred ), so that the positive distance should be low enough and the negative one large enough:

from keras import backend as K

def accuracy(_, y_pred):
    pos_threshold = 0.4
    neg_threshold = 0.6
    return K.mean(y_pred[0] < pos_threshold and y_pred[1] > neg_threshold)

The problem with this is that I couldn't figure out how to implement this "and" condition in Keras.
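For what it's worth, an element-wise "and" can be expressed with backend ops: the comparisons produce boolean tensors, which can be cast to floats and multiplied together. Below is a minimal sketch of that idea, under the assumption that y_pred has shape (batch_size, 2) with the positive distance in column 0 and the negative distance in column 1; the thresholds are simply the values from the snippet above.

from keras import backend as K

def threshold_accuracy(_, y_pred):
    # Assumed layout: y_pred[:, 0] = anchor-positive distance,
    #                 y_pred[:, 1] = anchor-negative distance.
    pos_ok = K.cast(K.less(y_pred[:, 0], 0.4), 'float32')     # 1.0 where the positive distance is small enough
    neg_ok = K.cast(K.greater(y_pred[:, 1], 0.6), 'float32')  # 1.0 where the negative distance is large enough
    return K.mean(pos_ok * neg_ok)                            # element-wise product acts as a logical AND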

Then I tried to find something on the topic of an accuracy metric for triplet loss. One way of doing it is to define accuracy as the proportion of triplets in which the predicted distance between the anchor image and the positive image is less than the distance between the anchor image and the negative image. With this I have even bigger problems implementing it in Keras.

I tried this (although I don't know whether it does what I described):

K.mean(y_pred[0] < y_pred[1])

which gives me an accuracy of around 0.5 all the time (probably just chance). So I still don't know whether the model is bad or the accuracy function is bad.

So my question is: how can I implement any reasonable accuracy function in Keras? I don't really care whether it's one of these two.

That's what I use (the condition y_pred[0] < y_pred[1]), while taking the batch dimension into account. Note that I'm not taking a mean, so that it supports sample weights.

from keras import backend as K

def triplet_accuracy(_, y_pred):
    '''
        Input:  y_pred shape is (batch_size, 2)
                [pos, neg]
        Output: shape (batch_size, 1)
                accuracy[i] = 1 if y_pred[i, 0] < y_pred[i, 1] else 0
    '''
    # diff[i] = neg[i] - pos[i]; positive when the triplet is ranked correctly
    subtraction = K.constant([-1, 1], shape=(2, 1))
    diff = K.dot(y_pred, subtraction)
    # sign(diff) is 1 for correct triplets, -1 (or 0) otherwise; clamp to {0, 1}
    accuracy = K.maximum(K.sign(diff), K.constant(0))

    return accuracy
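If it helps, here is a sketch of how such a metric could be plugged in when compiling; the names triplet_model and triplet_loss are just placeholders for whatever model and loss the network already uses, not from the original post. Keras averages the per-sample metric values itself, applying sample weights if they are provided.

# Hypothetical usage sketch: triplet_model and triplet_loss are assumed placeholders.
triplet_model.compile(optimizer='adam',
                      loss=triplet_loss,
                      metrics=[triplet_accuracy])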
