
Targets in custom loss and custom metric function in tf.Keras

I am creating a tf.keras.Model that is compiled with a custom loss and a custom metric function. I call train_on_batch on the model with x=input_batch and y=someFunction(targets).

Both the custom loss and the custom metric function have the signature methodname(y_true, y_pred). Here y_true receives someFunction(targets).

Is there any way to access the original targets inside the custom metric and custom loss functions, rather than the modified targets that are passed to train_on_batch?

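Is moving your target transform function inside the custom loss an option? Keras passes whatever you supply as y straight through as y_true, so if you transform the targets before calling train_on_batch, the loss and metric functions never see the pre-transform labels.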

import tensorflow as tf

def scale_label(y):
    # Example target transform (plays the role of someFunction)
    return y * 0.1

def build_loss_fn(label_transform):
    def my_loss_fn(y_true, y_pred):
        # y_true holds the raw targets here; apply the transform inside the loss
        new_y_true = label_transform(y_true)
        squared_difference = tf.square(new_y_true - y_pred)
        return tf.reduce_mean(squared_difference, axis=-1)  # Note the `axis=-1`
    return my_loss_fn

my_transformed_loss_fn = build_loss_fn(scale_label)

model.compile(optimizer='adam', loss=my_transformed_loss_fn)

# Now call train_on_batch with the raw targets; the loss applies the transform itself
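
For a version that runs end to end, here is a minimal sketch of the same idea with a custom metric added. The one-layer model, the random data, and the metric `raw_scale_mae` (including its factor of 10.0, chosen only because it inverts the 0.1 scaling) are illustrative assumptions, not part of the question. Because the raw targets are passed as y, both the loss and the metric receive them as y_true.

```python
import numpy as np
import tensorflow as tf

def scale_label(y):
    # Hypothetical stand-in for someFunction from the question.
    return y * 0.1

def build_loss_fn(label_transform):
    def transformed_mse(y_true, y_pred):
        # y_true is the raw target; the transform happens inside the loss.
        y_true = tf.cast(y_true, y_pred.dtype)
        return tf.reduce_mean(tf.square(label_transform(y_true) - y_pred), axis=-1)
    return transformed_mse

def raw_scale_mae(y_true, y_pred):
    # Illustrative metric: map predictions back to the raw scale
    # (10.0 undoes the 0.1 scaling) and compare with the raw targets.
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.abs(y_true - y_pred * 10.0), axis=-1)

# Toy model, assumed purely for demonstration.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam",
              loss=build_loss_fn(scale_label),
              metrics=[raw_scale_mae])

x = np.random.rand(8, 4).astype("float32")
raw_targets = np.random.rand(8, 1).astype("float32")

# Pass the untransformed targets; no someFunction(targets) at the call site.
results = model.train_on_batch(x, raw_targets)
print(results)
```

If the metric needs both the raw and the transformed targets, it can call the same transform function on y_true itself, which keeps a single source of truth for the transform.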
