How does "tfa.losses.triplet_semihard_loss" get called?

In TensorFlow Addons the triplet loss appears in two places: as the function tfa.losses.triplet_semihard_loss and as the class tfa.losses.TripletSemiHardLoss, a subclass of tf.keras.losses.Loss that the user instantiates and that in turn calls the function. This is the relevant part of the class:

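    import tensorflow as tf
    from tensorflow_addons.losses import triplet_semihard_loss

    class TripletSemiHardLoss(tf.keras.losses.Loss):
        def __init__(self, margin=1.0, name=None):
            # The function already returns a scalar, so no extra reduction is applied.
            super(TripletSemiHardLoss, self).__init__(
                name=name, reduction=tf.keras.losses.Reduction.NONE)
            self.margin = margin

        def call(self, y_true, y_pred):
            # Runs whenever the loss instance itself is called.
            return triplet_semihard_loss(y_true, y_pred, self.margin)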

I do not understand what is going on with the call method: it returns the result of the function triplet_semihard_loss applied to y_true and y_pred, but where exactly do these tensors come from? According to the TensorFlow documentation guide, the class is instantiated in the model's compile statement:

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=tfa.losses.TripletSemiHardLoss())

and then the model is fitted as:

    history = model.fit(
        train_dataset,
        epochs=5)

Each element of train_dataset is a tuple of the input data and the corresponding integer labels, but how does the loss class know that this is the data to operate on? And is the call method also called implicitly?

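__call__ runs whenever an instance of a class is called, and tf.keras.losses.Loss.__call__ in turn calls the subclass's call method and then applies the configured reduction, so yes, call is invoked implicitly. y_true and y_pred hold the true labels and the model's predictions (for this loss, the embeddings) respectively. During model.fit(), TensorFlow (tf.keras) splits each dataset element into inputs and labels, runs the model on the inputs to obtain y_pred, and passes the labels you provide to the loss as y_true (as seen here).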
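The chain of calls can be illustrated with a stripped-down, pure-Python sketch. It is not Keras's actual code: Loss, MeanAbsLoss and ToyModel below are hypothetical stand-ins that only mimic how a Loss instance's __call__ delegates to call, and how fit unpacks each (inputs, labels) element into y_true and y_pred. Reduction, sample weights and weight updates are deliberately left out.

```python
class Loss:
    """Hypothetical stand-in for tf.keras.losses.Loss (no reduction or weights)."""

    def __call__(self, y_true, y_pred):
        # Calling the instance, e.g. loss_obj(y, y_pred), runs __call__,
        # which dispatches to the subclass's call() method.
        return self.call(y_true, y_pred)

    def call(self, y_true, y_pred):
        raise NotImplementedError


class MeanAbsLoss(Loss):
    """Hypothetical loss playing the role of TripletSemiHardLoss."""

    def call(self, y_true, y_pred):
        return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


class ToyModel:
    """Hypothetical model that predicts w * x for every input in a batch."""

    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return [self.w * xi for xi in x]

    def compile(self, loss):
        # compile() only stores the loss object; nothing is computed yet.
        self.loss = loss

    def fit(self, dataset):
        # No weight updates here; only the loss data flow is shown.
        history = []
        for x, y in dataset:  # each element is an (inputs, labels) tuple
            y_pred = self(x)  # forward pass produces y_pred
            history.append(self.loss(y, y_pred))  # labels are passed as y_true
        return history


model = ToyModel(w=2.0)
model.compile(loss=MeanAbsLoss())
print(model.fit([([1.0, 2.0], [2.0, 5.0]), ([3.0], [6.0])]))  # [0.5, 0.0]
```

As in the real workflow, the loss object never sees the dataset directly: fit hands it one batch of labels and one batch of predictions at a time.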
All tf.keras losses follow this form, i.e., a function that takes two arguments, y_true and y_pred, as seen here.
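As an example of that two-argument form, here is a simplified, pure-Python sketch of a semi-hard triplet loss, with y_true holding integer labels and y_pred holding one embedding per example. It follows the semi-hard mining idea (for each anchor-positive pair, take the closest negative that is still farther away than the positive, falling back to the farthest negative when none exists) but uses plain Euclidean distance and is not guaranteed to match tfa.losses.triplet_semihard_loss numerically. The name semihard_triplet_loss is hypothetical, and skipping anchors with no negative in the batch is an assumption of this sketch.

```python
import math
from typing import Sequence


def semihard_triplet_loss(
    y_true: Sequence[int],
    y_pred: Sequence[Sequence[float]],
    margin: float = 1.0,
) -> float:
    """Simplified semi-hard triplet loss for one batch (hypothetical helper)."""
    n = len(y_true)
    # Pairwise Euclidean distances between the embeddings in the batch.
    dist = [[math.dist(y_pred[i], y_pred[j]) for j in range(n)] for i in range(n)]

    total, num_pairs = 0.0, 0
    for a in range(n):
        negatives = [dist[a][k] for k in range(n) if y_true[k] != y_true[a]]
        if not negatives:
            continue  # assumption: anchors without any negative are skipped
        for p in range(n):
            if p == a or y_true[p] != y_true[a]:
                continue  # only anchor-positive pairs with distinct indices
            d_ap = dist[a][p]
            # Semi-hard choice: the closest negative that is farther than the
            # positive; if there is none, fall back to the farthest negative.
            farther = [d for d in negatives if d > d_ap]
            d_an = min(farther) if farther else max(negatives)
            total += max(margin + d_ap - d_an, 0.0)
            num_pairs += 1
    # Average the hinge term over all anchor-positive pairs.
    return total / num_pairs if num_pairs else 0.0


labels = [0, 0, 1]
embeddings = [[0.0], [1.0], [3.0]]
print(semihard_triplet_loss(labels, embeddings, margin=2.0))  # 0.5
```

The extra margin argument mirrors the third parameter that TripletSemiHardLoss.call forwards to triplet_semihard_loss; the labels only decide which pairs count as positives and negatives, which is why integer class labels are enough as y_true.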
