
CNN loss with multiple outputs?

I have the following model:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def get_model():
    epochs = 100
    learning_rate = 0.1
    decay_rate = learning_rate / epochs  # 0.001, used below as Adam's learning rate

    inp = keras.Input(shape=(64, 101, 1), name="inputs")
    x = layers.Conv2D(128, kernel_size=(3, 3), strides=(3, 3), padding="same")(inp)
    x = layers.Conv2D(256, kernel_size=(3, 3), strides=(3, 3), padding="same")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(150)(x)
    x = layers.Dense(150)(x)
    # Two output heads that share the same convolutional/dense trunk
    out1 = layers.Dense(40000, name="sf_vec")(x)
    out2 = layers.Dense(128, name="ls_weights")(x)

    model = keras.Model(inp, [out1, out2], name="2_out_model")

    # A single loss string is applied to every output of the model
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=decay_rate),  # if needed, set back to 0.001
                  loss="mean_squared_error")

    # plot_model needs pydot and Graphviz installed
    keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)
    model.summary()

    return model
```
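As a sanity check on the architecture, the sketch below works out by hand the feature-map sizes and Dense parameter counts that `model.summary()` should report. It assumes TensorFlow's rule that a convolution with `padding="same"` yields ceil(input / stride) per spatial dimension; `same_conv_out` is a hypothetical helper written only for this illustration.

```python
import math


def same_conv_out(size: int, stride: int) -> int:
    # With padding="same", the output size is ceil(input / stride),
    # independent of the kernel size.
    return math.ceil(size / stride)


h, w = 64, 101                                   # input height and width
h, w = same_conv_out(h, 3), same_conv_out(w, 3)  # first Conv2D  -> 22 x 34
h, w = same_conv_out(h, 3), same_conv_out(w, 3)  # second Conv2D -> 8 x 12
flat = h * w * 256                               # Flatten over 256 channels

# Dense parameter count = inputs * units + units (bias terms)
dense1_params = flat * 150 + 150
sf_vec_params = 150 * 40000 + 40000
ls_weights_params = 150 * 128 + 128

print(h, w, flat)         # 8 12 24576
print(dense1_params)      # 3686550
print(sf_vec_params)      # 6040000
print(ls_weights_params)  # 19328
```

The 40000-unit `sf_vec` head alone holds about 6 million weights, while the `ls_weights` head has fewer than 20 thousand.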


I want to train my neural network on a "mix" of the loss of the first output and the loss of the second output. I train it as follows:

```python
# Targets are listed in the same order as the model outputs: [sf_vec, ls_weights]
model.fit(x_train, [sf_train, ls_filters_train], epochs=10)
```

During training, the log shows, for example:

```
Epoch 10/10 -> loss: 0.0702 - sf_vec_loss: 0.0666 - ls_weights_loss: 0.0035
```
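The logged numbers are consistent with a plain sum: 0.0666 + 0.0035 = 0.0701, and a gap of 0.0001 to the printed 0.0702 is what rounding each value to four decimals can produce. The snippet below illustrates this with made-up unrounded values (0.06664 and 0.00354, chosen only for illustration, not taken from the actual run), formatted with four decimals as in the log.

```python
# Hypothetical unrounded per-output losses (not from the real training run)
sf_vec_loss = 0.06664
ls_weights_loss = 0.00354
total = sf_vec_loss + ls_weights_loss

# Format every value with four decimals, as in the training log
shown = {
    "loss": f"{total:.4f}",
    "sf_vec_loss": f"{sf_vec_loss:.4f}",
    "ls_weights_loss": f"{ls_weights_loss:.4f}",
}
print(shown)  # {'loss': '0.0702', 'sf_vec_loss': '0.0666', 'ls_weights_loss': '0.0035'}
```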

I'd like to know whether it's a coincidence that "loss" is roughly the sum of sf_vec_loss and ls_weights_loss, or whether Keras actually computes it that way. Also, is the network trained on "loss" only? Thanks in advance :)

Following the TensorFlow documentation for `Model.compile`:

From the `loss` argument:

If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.
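This also answers the second question: the optimizer minimizes only the total `loss`, but because that total is the sum of the per-output losses, its gradient with respect to any trainable weight θ is the sum of the per-output gradients:

```latex
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{sf\_vec}} + \mathcal{L}_{\text{ls\_weights}},
\qquad
\nabla_\theta \mathcal{L}_{\text{total}} = \nabla_\theta \mathcal{L}_{\text{sf\_vec}} + \nabla_\theta \mathcal{L}_{\text{ls\_weights}}
```

The shared Conv2D and Dense(150) layers are therefore updated by the errors of both outputs, while each output head's own weights receive gradient only from its own loss term, since the other term does not depend on them.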

Remember that you can also weight the loss contributions of the different model outputs.

From the `loss_weights` argument:

The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients.
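The plain-Python sketch below spells out that weighted sum for two outputs using mean squared error. The helpers `mse` and `total_loss`, the toy vectors, and the weight 10.0 are hypothetical and only illustrate the formula; this is not Keras' internal implementation.

```python
def mse(y_true, y_pred):
    # Mean squared error over all elements of one output
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def total_loss(targets, predictions, loss_weights):
    # Weighted sum of the per-output losses
    per_output = {name: mse(targets[name], predictions[name]) for name in targets}
    total = sum(loss_weights[name] * value for name, value in per_output.items())
    return total, per_output


# Hypothetical toy data standing in for the real 40000- and 128-element outputs
targets = {"sf_vec": [0.0, 1.0, 2.0, 3.0], "ls_weights": [1.0, 1.0]}
predictions = {"sf_vec": [0.5, 1.0, 2.0, 2.5], "ls_weights": [0.75, 1.25]}

# Weights of 1.0 reproduce the unweighted sum (the default behaviour)
total, per_output = total_loss(targets, predictions, {"sf_vec": 1.0, "ls_weights": 1.0})
# Up-weighting the smaller loss makes it matter more during optimization
weighted, _ = total_loss(targets, predictions, {"sf_vec": 1.0, "ls_weights": 10.0})

print(per_output)  # {'sf_vec': 0.125, 'ls_weights': 0.0625}
print(total)       # 0.1875
print(weighted)    # 0.75
```

In Keras itself, the same idea is expressed in `compile`, for example `loss={"sf_vec": "mean_squared_error", "ls_weights": "mean_squared_error"}` together with `loss_weights={"sf_vec": 1.0, "ls_weights": 10.0}`, where the dictionary keys are the output layer names; a list such as `loss_weights=[1.0, 10.0]` in output order also works. The value 10.0 is only an illustrative choice.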
