
Custom loss function for predicting integer outputs?

I'm currently working on a dataset where I have to predict an integer output ranging from 1 to N. I've built a network with the MSE loss function, but I feel that MSE may not be the ideal loss to minimize when the output is an integer.

I'm also rounding my predictions to get integer outputs. Is there a way to build or optimize the model better for integer outputs?

Can anyone help with how to deal with integer outputs/targets? This is the loss function I'm using right now:

model.compile(optimizer=SGD(0.001), loss='mse')
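For reference, the rounding step mentioned above can also guard against predictions that fall outside 1..N. A minimal NumPy sketch, assuming the regression model has a single output unit so `model.predict` returns shape `(batch, 1)`; `round_to_labels` and `n_classes` are hypothetical names:

```python
import numpy as np

def round_to_labels(preds, n_classes):
    """Round continuous regression outputs to the nearest integer label in [1, n_classes]."""
    # Flatten the (batch, 1) output of model.predict to 1-D
    preds = np.asarray(preds, dtype=float).ravel()
    # np.rint rounds half to even (2.5 -> 2.0, 3.5 -> 4.0)
    rounded = np.rint(preds)
    # Clip so out-of-range outputs (e.g. 0.2 or 7.9 when N = 5) still map to a valid label
    return np.clip(rounded, 1, n_classes).astype(int)

print(round_to_labels([[0.2], [2.6], [4.4], [7.9]], n_classes=5))  # [1 3 4 5]
```

This only post-processes the output; the model is still trained as a regressor, so it does not address the choice of loss.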

You are using the wrong loss. Mean squared error is a loss for regression, while you have a classification problem (discrete outputs, not continuous ones).

So your model should have a softmax output layer with N units:

model.add(Dense(N, activation="softmax"))

And you should be using a classification loss:

model.compile(optimizer=SGD(0.001), loss='sparse_categorical_crossentropy')

This assumes your labels are integers in the [0, N-1] range. Yours run from 1 to N (off by one), so subtract 1 from them before training. To make a prediction, you should do:
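To see what the label shift and the loss do, here is a minimal NumPy re-implementation of sparse categorical cross-entropy, for illustration only (Keras computes this internally and its clipping details may differ). `to_zero_based` and `sparse_cce_np` are hypothetical names, and the probabilities are made up:

```python
import numpy as np

def to_zero_based(labels):
    """Shift labels from 1..N to 0..N-1, the range sparse_categorical_crossentropy expects."""
    return np.asarray(labels, dtype=int) - 1

def sparse_cce_np(y_true, probs, eps=1e-7):
    """Mean negative log-probability assigned to the true (integer) class."""
    probs = np.clip(np.asarray(probs, dtype=float), eps, 1.0)
    # For each row, pick the probability of the class given by its integer label
    picked = probs[np.arange(len(y_true)), y_true]
    return float(-np.mean(np.log(picked)))

y = np.array([1, 3, 2])      # original labels in 1..N, here N = 3
y0 = to_zero_based(y)        # [0 2 1]
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8],
                  [0.3, 0.4, 0.3]])
print(y0, sparse_cce_np(y0, probs))
```

Without the shift, label N would index past the last softmax unit, which is why the labels must be moved into [0, N-1] first.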

output = np.argmax(model.predict(some_data), axis=1) + 1

The +1 maps the predicted class indices, which go from 0 to N-1, back to your original labels 1 to N.
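A small NumPy illustration of the decoding step, using a made-up probability matrix in place of `model.predict(some_data)`:

```python
import numpy as np

# Made-up softmax outputs for 3 samples and N = 4 classes (each row sums to 1)
probs = np.array([[0.10, 0.60, 0.20, 0.10],
                  [0.05, 0.05, 0.10, 0.80],
                  [0.40, 0.30, 0.20, 0.10]])

# argmax gives a 0-based class index per row; +1 maps it back to labels 1..N
labels = np.argmax(probs, axis=1) + 1
print(labels)  # [2 4 1]
```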

Ordinal regression could be an appropriate approach if predicting a wrong month that is close to the true month should count as a smaller mistake than predicting a value a year earlier or later. Only you can decide that, based on the specific problem you want to solve.
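One common way to set up ordinal regression, different from the reweighted cross-entropy quoted below, is to split the N ordered labels into N-1 binary questions "is y > k?" and train N-1 sigmoid outputs with binary cross-entropy (in Keras, for example, `Dense(N - 1, activation="sigmoid")` with `loss="binary_crossentropy"`). A minimal NumPy sketch of the target encoding and decoding, with hypothetical helper names `ordinal_encode` and `ordinal_decode`:

```python
import numpy as np

def ordinal_encode(labels, n_classes):
    """Encode labels in 1..n_classes as n_classes-1 cumulative binary targets: t[k-1] = 1 if y > k."""
    labels = np.asarray(labels, dtype=int).reshape(-1, 1)
    thresholds = np.arange(1, n_classes)  # k = 1, 2, ..., N-1
    return (labels > thresholds).astype(float)

def ordinal_decode(probs, threshold=0.5):
    """Predicted label = 1 + number of 'y > k' outputs whose probability exceeds the threshold."""
    return 1 + np.sum(np.asarray(probs) > threshold, axis=1)

targets = ordinal_encode([1, 3, 5], n_classes=5)
# rows: [0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1]
print(targets)
print(ordinal_decode(targets))  # [1 3 5]
```

Decoding by counting the outputs above 0.5 always yields a label in 1..N, and a prediction that is off by one class only gets one binary target wrong, which is what makes near misses cheaper than far ones.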

I found an implementation of an appropriate loss function on GitHub (no affiliation). For completeness, here is the code copied from that repo:

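from keras import backend as K
from keras import losses

# Written against the Keras 2 backend API (keras.backend / tf.keras.backend);
# in Keras 3 the equivalent ops live in keras.ops.
# y_true must be one-hot encoded, since this wraps categorical_crossentropy.
def loss(y_true, y_pred):
    # Distance between the true and the predicted class index, scaled to [0, 1]
    weights = K.cast(
        K.abs(K.argmax(y_true, axis=1) - K.argmax(y_pred, axis=1))/(K.int_shape(y_pred)[1] - 1),
        dtype='float32'
    )
    # Scale cross-entropy by up to 2x for far-off predictions; argmax has no
    # gradient, so the weights act as per-sample constants
    return (1.0 + weights) * losses.categorical_crossentropy(y_true, y_pred)

# Usage: model.compile(optimizer=SGD(0.001), loss=loss)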
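To inspect what this loss computes, here is a NumPy mirror of it, for illustration only and not a drop-in replacement for training (Keras's `categorical_crossentropy` also renormalizes `y_pred`, which this sketch skips because the made-up rows already sum to 1). `weighted_cce_numpy` is a hypothetical name, and the predictions are invented for N = 5 classes:

```python
import numpy as np

def weighted_cce_numpy(y_true, y_pred, eps=1e-7):
    """Per-sample categorical cross-entropy scaled by 1 + normalized class-index distance."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    n_classes = y_pred.shape[1]
    # |true index - predicted index| / (N - 1), in [0, 1]
    weights = np.abs(np.argmax(y_true, axis=1) - np.argmax(y_pred, axis=1)) / (n_classes - 1)
    cce = -np.sum(y_true * np.log(y_pred), axis=1)  # plain categorical cross-entropy
    return (1.0 + weights) * cce

y_true = np.eye(5)[[0, 0]]                          # both samples belong to class index 0
y_pred = np.array([[0.40, 0.50, 0.04, 0.03, 0.03],  # wrong, but the adjacent class
                   [0.40, 0.03, 0.04, 0.03, 0.50]]) # wrong, and the farthest class
print(weighted_cce_numpy(y_true, y_pred))
```

Both rows put 0.40 on the true class, so their unweighted cross-entropy is identical; the weighting multiplies it by 1.25 for the adjacent miss and by 2 for the farthest one.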

The technical post webpages of this site follow the CC BY-SA 4.0 protocol.

 
© 2020-2024 STACKOOM.COM