
How to write a custom TensorFlow loss function with filtering?

I know we can create a custom loss function like this:

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import backend as K

def custom_loss(y_true, y_pred):
    y_pred = K.round(y_pred / 1000) * 1000  # Round to the nearest 1000
    loss = tf.keras.losses.MSE(y_true, y_pred)
    return K.sqrt(loss)

# feature_layer and alpha are defined elsewhere (not shown)
opt = tf.keras.optimizers.Adam(learning_rate=alpha)  # must exist before compile()
model = tf.keras.Sequential()
model.add(feature_layer)
model.add(layers.Dense(1, activation="relu"))
model.compile(loss='mse', optimizer=opt,
              metrics=[tf.keras.metrics.RootMeanSquaredError(), custom_loss])

However, I don't know how to apply a filter inside the custom loss function, since it seems to support only Keras backend functions.

As an example of the filter, I want to calculate the loss only for samples where y_true >= 1000.

Any suggestions? I would like to monitor the filtered custom loss during training.

Thank you

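You can do this by combining tf.where with tensorflow.keras.backend functions: set the error to 0 wherever the condition is not met, then average the squared error over the remaining entries. The example below uses y_true >= 5 as the condition.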

import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as kb

x = np.array(range(1,11))
y = 2*x
x = x.reshape(2,5)
y = y.reshape(2,5)
x = x.astype(np.float32)
y = y.astype(np.float32)
def custom_loss(y_true, y_pred):
    # Entries that meet the filter condition
    mask = y_true >= 5
    # Keep the error where the condition is met, else assign 0
    diff = tf.where(mask, y_true - y_pred, 0.0)
    sum_of_squares = kb.sum(kb.square(diff), axis=-1)
    # Count entries that met the condition; counting diff != 0 instead
    # would wrongly skip exact predictions
    value_counts = kb.sum(tf.cast(mask, sum_of_squares.dtype), axis=-1)
    # Clamp to 1 so a row with no matching entries returns 0 instead of NaN
    value_counts = tf.maximum(value_counts, 1.0)
    # Root mean squared error over the filtered entries
    return kb.sqrt(sum_of_squares / value_counts)
tf.random.set_seed(52)
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10,activation='relu'),
    tf.keras.layers.Dense(1)
])
tf.keras.backend.clear_session()
model.compile(loss='mse',metrics=[tf.keras.metrics.RootMeanSquaredError(), custom_loss])
model.fit(x,y,epochs=10)

Epoch 1/10
1/1 [==============================] - 0s 996us/step - loss: 277.5225 - root_mean_squared_error: 16.6590 - custom_loss: 16.3645
Epoch 2/10
1/1 [==============================] - 0s 997us/step - loss: 264.7605 - root_mean_squared_error: 16.2715 - custom_loss: 16.0073
Epoch 3/10
1/1 [==============================] - 0s 998us/step - loss: 255.8174 - root_mean_squared_error: 15.9943 - custom_loss: 15.7518
Epoch 4/10
1/1 [==============================] - 0s 996us/step - loss: 248.5119 - root_mean_squared_error: 15.7643 - custom_loss: 15.5396
Epoch 5/10
1/1 [==============================] - 0s 998us/step - loss: 242.1583 - root_mean_squared_error: 15.5614 - custom_loss: 15.3526
Epoch 6/10
1/1 [==============================] - 0s 998us/step - loss: 236.4406 - root_mean_squared_error: 15.3766 - custom_loss: 15.1821
Epoch 7/10
1/1 [==============================] - 0s 0s/step - loss: 231.1830 - root_mean_squared_error: 15.2047 - custom_loss: 15.0235
Epoch 8/10
1/1 [==============================] - 0s 997us/step - loss: 226.2768 - root_mean_squared_error: 15.0425 - custom_loss: 14.8739
Epoch 9/10
1/1 [==============================] - 0s 2ms/step - loss: 221.6491 - root_mean_squared_error: 14.8879 - custom_loss: 14.7312
Epoch 10/10
1/1 [==============================] - 0s 999us/step - loss: 217.2489 - root_mean_squared_error: 14.7394 - custom_loss: 14.5941
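
To sanity-check the numbers the metric reports, it can help to compute the same filtered RMSE by hand on a tiny batch. The following is only a minimal plain-Python sketch: the name filtered_rmse, the default threshold of 1000 (taken from the question), and the sample values are illustrative assumptions, not part of the original code.

```python
import math

def filtered_rmse(y_true, y_pred, threshold=1000.0):
    """RMSE over the entries where y_true >= threshold (hypothetical helper).

    Returns 0.0 when no entry meets the condition, to avoid dividing by zero.
    """
    squared_errors = [
        (t - p) ** 2
        for t, p in zip(y_true, y_pred)
        if t >= threshold  # the filter: skip entries below the threshold
    ]
    if not squared_errors:
        return 0.0
    return math.sqrt(sum(squared_errors) / len(squared_errors))

# Made-up sample values: the first entry (500) is filtered out.
y_true = [500.0, 1000.0, 2000.0, 3000.0]
y_pred = [900.0, 1300.0, 2000.0, 2600.0]

# Remaining errors are -300, 0 and 400, so the result is sqrt(250000 / 3).
print(filtered_rmse(y_true, y_pred))
```

Note that the exactly correct prediction (2000) still counts toward the denominator, so the mean is taken over three entries rather than two.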
