TensorFlow 1.15 / Keras 2.3.1: Model.train_on_batch() returns more values than there are outputs/loss functions

I am trying to train a model that has more than one output and, as a result, more than one loss function attached to it when I compile it.

I haven't done anything like this before (at least not from scratch).

Here's the code I am using to figure out how this works:

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

batch_size = 50
input_size = 10
output_size = 5  # not defined in the original snippet; any positive size works

i = Input(shape=(input_size,))
x = Dense(100)(i)
x_1 = Dense(output_size)(x)  # first output head
x_2 = Dense(output_size)(x)  # second output head

model = Model(i, [x_1, x_2])

# One loss per output, in the same order as the model's outputs
model.compile(optimizer='adam', loss=['mse', 'mse'])

# Data creation
x = np.random.random_sample([batch_size, input_size]).astype('float32')
y = np.random.random_sample([batch_size, output_size]).astype('float32')

loss = model.train_on_batch(x, [y, y])

print(loss)  # sample output [0.8311912, 0.3519104, 0.47928077]

I expected the variable loss to have two entries (one for each loss function); however, I get back three. I thought one of them might be a weighted average, but that does not seem to be the case.

Could anyone explain how passing in multiple loss functions works? Obviously, I am misunderstanding something.

I believe the three returned values are the total loss (the sum of all the losses), followed by the individual loss for each output.

For example, in the sample output you printed:

0.3519104 + 0.47928077 = 0.83119117 ≈ 0.8311912
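
A quick way to check this reading is to add the last two values and compare the result with the first, allowing for float32 rounding. The sketch below just reuses the numbers from the sample output; nothing in it calls Keras.

```python
import math

# Values copied from the sample output: [total, output-1 loss, output-2 loss]
returned = [0.8311912, 0.3519104, 0.47928077]

total, per_output = returned[0], returned[1:]

# float32 keeps only about 7 significant digits, so compare with a tolerance
# rather than exact equality.
matches = math.isclose(total, sum(per_output), rel_tol=1e-6)
print(sum(per_output), matches)
```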

Your assumption that there should be two losses is incorrect. Your model has two outputs and you specified one loss for each output, but a model has to be trained on a single scalar loss, so Keras trains it on a new loss that is the sum of the per-output losses.
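
To make the aggregation concrete, here is a minimal pure-Python sketch of what the returned list contains, assuming mean squared error for both outputs and the default weight of 1.0. The helpers `mse` and `train_on_batch_losses` are hypothetical and are not Keras internals. They only reproduce the arithmetic on flat lists. (Keras averages the squared error per sample and then over the batch; for equally sized rows that equals the mean over all elements.)

```python
def mse(y_true, y_pred):
    """Mean squared error over a flat list of values."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def train_on_batch_losses(targets, predictions):
    """Mimic the layout of the list returned for a multi-output model:
    [total_loss, loss_output_1, loss_output_2, ...]."""
    per_output = [mse(t, p) for t, p in zip(targets, predictions)]
    # With every weight equal to 1.0 the training loss is the plain sum.
    return [sum(per_output)] + per_output


targets = [[1.0, 2.0], [0.0, 1.0]]       # one target list per output
predictions = [[1.5, 2.0], [0.0, 0.0]]   # one prediction list per output
print(train_on_batch_losses(targets, predictions))  # [0.625, 0.125, 0.5]
```

The first entry is the quantity being minimised; the remaining entries are reported only so you can see how each output is doing.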

You can control how these losses are combined with the loss_weights parameter of model.compile. By default every output's loss has a weight of 1.0, so the total is a plain sum.
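
In Keras the weights are passed at compile time, for example `model.compile(optimizer='adam', loss=['mse', 'mse'], loss_weights=[1.0, 0.5])`. The plain-Python sketch below shows how such weights would change the total while leaving the per-output values as they are. `combine_losses` is a hypothetical helper, the per-output numbers are taken from the question's sample output, and any regularization losses (which Keras also adds to the total) are ignored.

```python
def combine_losses(per_output_losses, loss_weights=None):
    """Return [total, *per_output_losses], where total is the weighted sum.

    loss_weights=None stands for the default case in which every output
    has weight 1.0.
    """
    if loss_weights is None:
        loss_weights = [1.0] * len(per_output_losses)
    if len(loss_weights) != len(per_output_losses):
        raise ValueError("need exactly one weight per output")
    total = sum(w * l for w, l in zip(loss_weights, per_output_losses))
    # The per-output entries stay unweighted; only the total changes.
    return [total] + list(per_output_losses)


per_output = [0.3519104, 0.47928077]  # from the question's sample output
print(combine_losses(per_output))              # default weights: plain sum
print(combine_losses(per_output, [1.0, 0.5]))  # second output counts half
```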

So what train_on_batch returns is the total loss, the MSE of output one, and the MSE of output two. That is why you get three values.
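
Rather than relying on position, you can match the values to names: Keras models expose `model.metrics_names`, which lists the labels in the same order as the values returned by `train_on_batch`. The sketch below shows only the pairing step in plain Python; `label_losses` is a hypothetical helper, and the names are illustrative (in a real model they depend on the output layers' names, e.g. something like `dense_1_loss`).

```python
def label_losses(names, values):
    """Pair each returned value with its name, e.g. from model.metrics_names."""
    if len(names) != len(values):
        raise ValueError("names and values must have the same length")
    return dict(zip(names, values))


# Illustrative names; a real model would supply model.metrics_names.
names = ["loss", "dense_1_loss", "dense_2_loss"]
values = [0.8311912, 0.3519104, 0.47928077]  # sample output from the question
print(label_losses(names, values))
```

In the question's script this would become `label_losses(model.metrics_names, model.train_on_batch(x, [y, y]))`.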

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM