
Error in batch size with custom loss function in Keras

I'm working on a detector with Keras, where the output y_true is a vector "y" of 500 values containing a pulse that marks the time of the detected event within 500 samples of a signal.

Ex: y=[0, 0, 0,....,0,1,1,1,1,1,1,1,1,1,1,1,0,....0,0,0]

I've used 'mse' as the loss before and it works, but I want a loss function that takes into account the distance between the middle of the pulse in y_true and the position of the maximum value in y_pred. Later I use the maximum value of y_pred to normalize it and define the pulse around it.

Since I can't work with the distance alone and keep it differentiable, I defined this custom loss function, which weights the mean squared error by the estimated distance.

import tensorflow as tf
import keras.backend as kb

def custom_loss_function(y_true, y_pred):
    # indices of the ones in y_true; t_label[5] is the middle of the 11-sample pulse
    t_label = tf.where(y_true == 1)[:, 0]
    # position of the maximum value in y_pred
    mayor = tf.reduce_max(y_pred)
    t_picking = tf.where(y_pred == mayor)[:, 0]
    # distance between the middle of the true pulse and the predicted maximum
    d = tf.cast(tf.abs(t_label[5] - t_picking), tf.float32) / 50
    # mean squared error weighted by the estimated distance
    loss = kb.mean(kb.square(y_true - y_pred)) * d

    return loss

Where t_label[5] and t_picking are the middle value of the pulse in y_true and the position of the maximum value in y_pred, respectively, and d is the distance between them.
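For reference, here is a minimal sketch of why t_label[5] is the middle of the pulse, assuming (as in the example above) that the pulse is always 11 samples wide; the vector y below is hypothetical:

import tensorflow as tf

# a 500-sample vector with an 11-sample pulse at indices 20..30
y = tf.constant([0.] * 20 + [1.] * 11 + [0.] * 469)

t_label = tf.where(y == 1)[:, 0]
print(t_label)     # tf.Tensor([20 21 ... 30], shape=(11,), dtype=int64)
print(t_label[5])  # tf.Tensor(25, ...), the middle sample of the pulse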

I compiled the model with this loss function, using the Adam optimizer and a batch size of 64. Everything works and the model compiles, but I get this error in the middle of training:

InvalidArgumentError:  Incompatible shapes: [64] vs. [2]
 [[node Adam/gradients/gradients/loss/dense_1_loss/custom_loss_function/weighted_loss/mul_grad/BroadcastGradientArgs (defined at C:\Users\Maca\anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_2220]

I've tried other custom loss functions before and didn't have this problem, but I can't see where the error is coming from.

Do you know why I am getting this error and how I can fix it?

There are two equal max values in a particular batch, so your t_picking sometimes (rarely) has two (or even more) values instead of one. When that happens, d has shape [2] instead of [1], the loss returned by your function inherits that shape, and Keras fails when it tries to broadcast it against the per-sample weighting of the batch of 64, which is exactly the Incompatible shapes: [64] vs. [2] in the traceback.
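Here is a minimal sketch of one possible fix, assuming (as the code above implicitly does) that y_true and y_pred can be treated as single 500-sample vectors: tf.argmax returns the index of the first occurrence of the maximum, so t_picking is always a single value and d is always a scalar, even when the maximum is duplicated:

import tensorflow as tf
import keras.backend as kb

def custom_loss_function(y_true, y_pred):
    t_label = tf.where(y_true == 1)[:, 0]
    # tf.argmax picks exactly one index (the first occurrence of the max),
    # unlike tf.where, which returns every index where the max occurs
    t_picking = tf.argmax(y_pred)
    d = tf.cast(tf.abs(t_label[5] - t_picking), tf.float32) / 50
    loss = kb.mean(kb.square(y_true - y_pred)) * d

    return loss

For a batched loss you would instead use tf.argmax(y_pred, axis=-1) and compute one distance per sample in the batch.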
