TensorFlow training - print multiple losses for one output

I would like to print each of the losses I have for one output separately. At the moment the training log looks like:

1/1 [==============================] - 1s 1s/sample - loss: 4.2632

The goal is to have a history like:

1/1 [==============================] - 1s 1s/sample - loss1: 2.1, loss2: 2.1632

I have one output layer out1 and two loss functions loss1 and loss2.

def loss1(y_true, y_pred):
    ...
    return ...
def loss2(y_true, y_pred):
    ...
    return ...

When I do

model.compile(...)

I can either choose to have a single loss function,

model.compile(loss=lambda y_true, y_pred: loss1(y_true, y_pred) + loss2(y_true, y_pred))

or define a loss for each output in a dictionary:

model.compile(loss={'out1': loss1, 'out2': loss2})

Since I have only one output, the dictionary isn't an option for me. Does anyone know how to print the losses separately when the model has only one output?

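Just use the metrics argument. Any function with the signature (y_true, y_pred), including your own loss functions, can be passed as a metric and is then reported separately during training: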

model.compile(optimizer='adam', loss='mae', metrics=['mse'])

Metrics are only reported, not optimized, so you still need to choose one loss to minimize (which can be the sum of loss1 and loss2).
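Applied to the question, a minimal sketch: the sum of both losses is what gets minimized, while loss1 and loss2 are also passed as metrics so each one shows up in the progress bar and in history.history. The bodies of loss1 and loss2 below are hypothetical stand-ins (mean squared error and mean absolute error), since the question elides them; combined_loss, the layer sizes, and the random data are likewise illustrative assumptions.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input


# Hypothetical stand-ins for the question's loss1/loss2 (the originals are elided):
# loss1 is mean squared error, loss2 is mean absolute error, one value per sample.
def loss1(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


def loss2(y_true, y_pred):
    return tf.reduce_mean(tf.abs(y_true - y_pred), axis=-1)


# The quantity actually minimized: the sum of both losses.
def combined_loss(y_true, y_pred):
    return loss1(y_true, y_pred) + loss2(y_true, y_pred)


# Single-output model, as in the question.
x_input = Input(shape=(200,))
x = Dense(64)(x_input)
out1 = Dense(1, name='out1')(x)
model = Model(inputs=x_input, outputs=out1)

# loss1 and loss2 are passed again as metrics, so Keras reports each one
# under its function name next to the combined loss.
model.compile(optimizer='adam', loss=combined_loss, metrics=[loss1, loss2])

train_x = np.random.rand(256, 200).astype('float32')
train_y = np.random.rand(256, 1).astype('float32')
history = model.fit(train_x, train_y, epochs=2, batch_size=32, verbose=0)
print(sorted(history.history))
```

With the default verbose=1, the progress bar shows loss (the sum) followed by loss1 and loss2, which is close to the format asked for in the question.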

One workaround is to artificially duplicate the single output into two identical outputs, give each copy one of the losses, and combine them with loss weights equal to 1. For concreteness, here is an example:

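import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Lambda

if __name__ == '__main__':
    # Random regression data: 200 features, one target.
    train_x = np.random.rand(10000, 200)
    train_y = np.random.rand(10000, 1)

    x_input = Input(shape=(200,))  # shape must be a tuple, not a bare int
    x = Dense(64)(x_input)
    x = Dense(64)(x)
    x = Dense(1)(x)

    # Two identity layers expose the same prediction under two output names.
    x1 = Lambda(lambda t: t, name='out1')(x)
    x2 = Lambda(lambda t: t, name='out2')(x)

    model = Model(inputs=x_input, outputs=[x1, x2])

    # One loss per (duplicated) output. With both weights set to 1 the total loss
    # is mse + mae, and Keras typically logs each output's loss (out1_loss, out2_loss).
    model.compile(optimizer='adam',
                  loss={'out1': 'mse', 'out2': 'mae'},
                  loss_weights={'out1': 1, 'out2': 1})

    # Every output needs a target, so the same labels are passed once per output.
    model.fit(train_x, [train_y, train_y], epochs=10)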
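Because out1 and out2 are identity copies of the same tensor, both heads produce identical predictions, and a single-output model for inference can reuse the trained layers. A small sketch, assuming TensorFlow 2.x; the data is random, the sizes are arbitrary, and the losses are given as lists here (matched to the outputs by position) rather than as a dictionary:

```python
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Lambda

# A smaller version of the two-head model above.
x_input = Input(shape=(200,))
x = Dense(1)(Dense(64)(x_input))
x1 = Lambda(lambda t: t, name='out1')(x)
x2 = Lambda(lambda t: t, name='out2')(x)
model = Model(inputs=x_input, outputs=[x1, x2])

# Lists are matched to the outputs by position: out1 -> mse, out2 -> mae.
model.compile(optimizer='adam', loss=['mse', 'mae'], loss_weights=[1, 1])

train_x = np.random.rand(128, 200).astype('float32')
train_y = np.random.rand(128, 1).astype('float32')
model.fit(train_x, [train_y, train_y], epochs=1, verbose=0)

# Both heads are identity copies of the same tensor, so they predict the same values.
pred1, pred2 = model.predict(train_x[:5], verbose=0)

# For inference, a single-output model shares the trained layers up to out1.
inference_model = Model(inputs=x_input, outputs=x1)
single_pred = inference_model.predict(train_x[:5], verbose=0)
```

The duplicated head therefore only affects training (and logging); at prediction time either output, or the single-output model, gives the same result.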
