
How to use model input in a loss function?

I am trying to use a custom loss function which depends on arguments (the model's inputs) that a Keras loss function does not receive by default.

The model has two inputs (mel_specs and pred_inp) and expects a labels tensor for training:

def to_keras_example(example):
    # Preparing inputs
    return (mel_specs, pred_inp), labels

# train_data is a tf.data.Dataset passed to model.fit(train_data, ...)
train_data = load_dataset(fp, 'train').map(to_keras_example).repeat()

In my loss function I need to calculate the lengths of mel_specs and pred_inp. This means my loss looks like this:

def rnnt_loss_wrapper(y_true, y_pred, mel_specs_inputs_):
    input_lengths = get_padded_length(mel_specs_inputs_[:, :, 0])
    label_lengths = get_padded_length(y_true)
    return rnnt_loss(
        acts=y_pred,
        labels=tf.cast(y_true, dtype=tf.int32),
        input_lengths=input_lengths,
        label_lengths=label_lengths
    )
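
get_padded_length is not shown in the post. For context, a minimal sketch of what such a helper could look like, assuming sequences are zero-padded so that a padded step is exactly 0 (the padding convention is an assumption):

import tensorflow as tf

def get_padded_length(t):
    # Hypothetical helper (not from the original post): counts the non-padded
    # steps per batch element of a (batch, time) tensor, assuming padded
    # positions are exactly zero.
    mask = tf.cast(tf.not_equal(t, 0), tf.int32)
    return tf.reduce_sum(mask, axis=-1)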

However, no matter which approach I choose, I run into an issue.


Option 1) Setting the loss-function in model.compile()

If I wrap the loss function so that it returns a function which takes y_true and y_pred, like this:

def rnnt_loss_wrapper(mel_specs_inputs_):
    def inner_(y_true, y_pred):
        input_lengths = get_padded_length(mel_specs_inputs_[:, :, 0])
        label_lengths = get_padded_length(y_true)
        return rnnt_loss(
            acts=y_pred,
            labels=tf.cast(y_true, dtype=tf.int32),
            input_lengths=input_lengths,
            label_lengths=label_lengths
        )
    return inner_

model = create_model(hparams)
model.compile(
    optimizer=optimizer,
    loss=rnnt_loss_wrapper(model.inputs[0])
)

Here I get a _SymbolicException after calling model.fit():

tensorflow.python.eager.core._SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [...]
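
This exception shows up in early TF 2.x releases when a compiled loss closes over a Keras symbolic tensor (here model.inputs[0]) while training runs eagerly. One workaround that is often suggested, though I have not verified it for this model, is to fall back to graph mode before building the model:

import tensorflow as tf

# Possible workaround (an assumption, not from the original post): build and
# train in graph mode so the closed-over symbolic input tensor is allowed.
tf.compat.v1.disable_eager_execution()

model = create_model(hparams)
model.compile(
    optimizer=optimizer,
    loss=rnnt_loss_wrapper(model.inputs[0])
)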

Option 2) Using model.add_loss()

The documentation of add_loss() states:

 [Adds] loss tensor(s), potentially dependent on layer inputs. [...] Arguments: losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses may also be zero-argument callables which create a loss tensor. inputs: Ignored when executing eagerly. If anything [...]

So I tried to do the following:

def rnnt_loss_wrapper(y_true, y_pred, mel_specs_inputs_):
    input_lengths = get_padded_length(mel_specs_inputs_[:, :, 0])
    label_lengths = get_padded_length(y_true)
    return rnnt_loss(
        acts=y_pred,
        labels=tf.cast(y_true, dtype=tf.int32),
        input_lengths=input_lengths,
        label_lengths=label_lengths
    )

model = create_model(hparams)
model.add_loss(
    rnnt_loss_wrapper(
        y_true=model.inputs[2],
        y_pred=model.outputs[0],
        mel_specs_inputs_=model.inputs[0],
    ),
    inputs=True
)
model.compile(
    optimizer=optimizer
)

However, calling model.fit() throws a ValueError:

ValueError: No gradients provided for any variable: [...]

Is any of the above options supposed to work?
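
For completeness: a pattern that is often recommended for losses that depend on model inputs in TF 2 is to compute the loss inside a custom layer and register it with self.add_loss. A minimal sketch, assuming the labels are fed as an additional model input and that rnnt_loss and get_padded_length behave as above:

import tensorflow as tf

class RnntLossLayer(tf.keras.layers.Layer):
    # Hypothetical layer (not part of the original post): computes the RNN-T
    # loss from the model's own tensors and registers it via self.add_loss.
    def call(self, inputs):
        mel_specs, labels, acts = inputs
        input_lengths = get_padded_length(mel_specs[:, :, 0])
        label_lengths = get_padded_length(labels)
        loss = rnnt_loss(
            acts=acts,
            labels=tf.cast(labels, dtype=tf.int32),
            input_lengths=input_lengths,
            label_lengths=label_lengths
        )
        self.add_loss(loss)
        return acts

The training model would then be compiled with loss=None, similar to the answer below.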

I have used the add_loss method as follows:

def custom_loss(y_true, y_pred, input_):
    # custom loss function
    y_estim = input_[..., 0] * y_pred
    shape = tf.cast(tf.shape(y_true)[1], dtype='float32')
    return tf.reduce_mean(1 / shape * tf.reduce_sum(tf.pow(y_true - y_estim, 2), axis=1))


from tensorflow.keras import layers
from tensorflow.keras.models import Model

mix_input = layers.Input(shape=(301, 257, 4))  # input 1
ref_input = layers.Input(shape=(301, 257, 1))  # input 2
target = layers.Input(shape=(301, 257))        # output target

# smss is the network's output tensor (its definition is omitted here)
smss_model = Model(inputs=[mix_input, ref_input], outputs=smss)  # my model that accepts the two real inputs

# this model is used only for training, taking the target as an extra input
model = Model(inputs=[mix_input, ref_input, target], outputs=smss)
model.add_loss(custom_loss(target, smss, mix_input))  # register the custom loss via add_loss
model.summary()

model.compile(loss=None, optimizer='sgd')
model.fit([mix, ref, y], epochs=1, batch_size=1, verbose=1)
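
After fitting, inference does not need the extra target input: the two-input smss_model shares its layers, and therefore its trained weights, with the training model. For example (dummy arrays with the shapes declared above, just to illustrate):

import numpy as np

mix_batch = np.zeros((1, 301, 257, 4), dtype='float32')
ref_batch = np.zeros((1, 301, 257, 1), dtype='float32')
preds = smss_model.predict([mix_batch, ref_batch])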

Even though I have used this method and it works, I am still looking for another approach that does not involve creating a separate training model.

Did using a lambda function work? (https://www.w3schools.com/python/python_lambda.asp)

loss = lambda x1, x2: rnnt_loss(x1, x2, acts, labels, input_lengths,
                                label_lengths, blank_label=0)

This way your loss function accepts the parameters x1 and x2, while rnnt_loss can also be aware of acts, labels, input_lengths, label_lengths and blank_label.
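
As a sketch of how such a closure would be used, assuming input_lengths and label_lengths are concrete values available at compile time rather than symbolic model inputs (which is exactly where Option 1 above runs into trouble):

model.compile(
    optimizer=optimizer,
    # The closure captures the extra arguments; Keras only supplies
    # y_true and y_pred at training time.
    loss=lambda y_true, y_pred: rnnt_loss(
        acts=y_pred,
        labels=tf.cast(y_true, tf.int32),
        input_lengths=input_lengths,
        label_lengths=label_lengths
    )
)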
