
Keras custom loss function error: "AttributeError: 'function' object has no attribute 'get_shape'"

I need to write custom loss functions in Keras that take inputs beyond the usual y_true and y_pred arguments. After reading about some workarounds, I decided to use inner functions as follows:

from keras import backend as K

lambda_prn_regr = 0.6
lambda_prn_vis = 0.2
lambda_prn_class = 0.2

epsilon = 1e-4

# Person loss
def prn_loss_cls(y_true, y_pred):
    def prn_loss_cls_fixed_num(y_true, y_pred):
        # lambda * b_ce
        return lambda_prn_class * K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
    return prn_loss_cls_fixed_num

# Regression loss
def prn_loss_regr(num_joints):
    def prn_loss_regr_fixed_num(y_true, y_pred):
        # lambda * sum(vis * (pose_pred - pose_true)^2) / sum(vis)
        return lambda_prn_regr * K.sum(y_true[:, :, :, :2*num_joints] * K.square(y_pred - y_true[:, :, :, 2*num_joints:])) / K.sum(y_true[:, :, :, :2*num_joints])
    return prn_loss_regr_fixed_num

# Visibility Loss
def prn_loss_vis(y_true, y_pred):
    def prn_loss_regr_fixed_num(y_true, y_pred):
        return lambda_prn_vis * K.mean(K.square(y_pred - y_true), axis=-1)
    return prn_loss_regr_fixed_num

These are three different loss functions: each has its own weight, and one requires an integer argument.

However, I get AttributeError: 'function' object has no attribute 'get_shape' when calling model.compile. The full error output is as follows:

Traceback (most recent call last):
  File "train_mppn.py", line 97, in <module>
    model_prn.compile(optimizer=optimizer, loss=[losses.prn_loss_cls, losses.prn_loss_regr(C.num_joints), losses.prn_loss_vis(C.num_joints)])
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 899, in compile
    sample_weight, mask)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 441, in weighted
    ndim = K.ndim(score_array)
  File "/usr/local/lib/python2.7/dist-packages/keras/backend/tensorflow_backend.py", line 439, in ndim
    dims = x.get_shape()._dims
AttributeError: 'function' object has no attribute 'get_shape'

Compile part:

model.compile(optimizer=optimizer, loss=[losses.prn_loss_cls, losses.prn_loss_regr(num_joints), losses.prn_loss_vis])

I can't find the source of the problem.

You are passing functions that don't return loss values; they return functions.

It's understandable that you do that in the num_joints case (and you're actually calling that function), but it's just weird in the other cases, especially because you're not calling them anywhere to return the inner function.
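
To see why the error appears, here is a minimal sketch in plain Python (no Keras needed). The names plain_loss and wrapped_loss are made up for illustration, and plain lists stand in for tensors. As the traceback shows, Keras calls each loss as loss(y_true, y_pred) and then asks the result for its shape; with the wrapped version the result is a function, which has no get_shape.

```python
def plain_loss(y_true, y_pred):
    # Returns a value directly: mean squared error over plain lists.
    return sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def wrapped_loss(y_true, y_pred):
    # Same structure as the question's prn_loss_cls: the outer function
    # ignores its arguments and returns the inner function itself.
    def inner(y_true, y_pred):
        return plain_loss(y_true, y_pred)
    return inner


y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 2.0, 5.0]

good = plain_loss(y_true, y_pred)    # a number
bad = wrapped_loss(y_true, y_pred)   # a function, not a number

print(type(good).__name__)           # float
print(callable(bad))                 # True
print(hasattr(bad, "get_shape"))     # False
```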

Suggestion:

# Person loss
def prn_loss_cls(y_true, y_pred):
    return lambda_prn_class * K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)


# Visibility Loss
def prn_loss_vis(y_true, y_pred):
    return lambda_prn_vis * K.mean(K.square(y_pred - y_true), axis=-1)

If I understand correctly, prn_loss_cls, prn_loss_regr, and prn_loss_vis are meant to be factories, i.e. functions that return a function, and you want to use the returned function as the loss. So you need to call these factories rather than pass them directly (for prn_loss_cls and prn_loss_vis this assumes the outer function is defined without parameters), e.g.

model.compile(optimizer=optimizer, loss=[losses.prn_loss_cls(), losses.prn_loss_regr(num_joints), losses.prn_loss_vis()])

Hope this works :)
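
One way to catch a forgotten call before compiling is to run each entry of the loss list on tiny dummy inputs and check that the result is not itself callable. The sketch below uses plain Python lists in place of tensors; find_factory_mistakes, direct_loss, and uncalled_factory are hypothetical names, not part of Keras. The same idea should carry over to small dummy tensors, though that is untested here.

```python
def find_factory_mistakes(losses, y_true, y_pred):
    """Return the names of entries that yield a function instead of a value."""
    mistakes = []
    for fn in losses:
        result = fn(y_true, y_pred)
        if callable(result):
            # The entry is a factory that was passed without being called.
            mistakes.append(fn.__name__)
    return mistakes


def direct_loss(y_true, y_pred):
    # Mean absolute error: returns a value, so it is a valid loss.
    return sum(abs(p - t) for t, p in zip(y_true, y_pred)) / len(y_true)


def uncalled_factory(y_true, y_pred):
    # Returns the inner function instead of a value.
    def inner(y_true, y_pred):
        return direct_loss(y_true, y_pred)
    return inner


found = find_factory_mistakes([direct_loss, uncalled_factory], [0.0, 1.0], [0.5, 1.0])
print(found)  # ['uncalled_factory']
```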

Do not include y_true and y_pred in the outer functions. The outer function should take only the extra variables you need; y_true and y_pred belong to the inner function alone. That is, they should be defined as:

# Person loss
def prn_loss_cls():
    def prn_loss_cls_fixed_num(y_true, y_pred):
        # lambda * b_ce
        return lambda_prn_class * K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
    return prn_loss_cls_fixed_num

# Visibility loss
def prn_loss_vis():
    def prn_loss_vis_fixed_num(y_true, y_pred):
        return lambda_prn_vis * K.mean(K.square(y_pred - y_true), axis=-1)
    return prn_loss_vis_fixed_num

The regression loss is fine as written. Then you should be able to compile your model with

model_prn.compile(optimizer=optimizer, loss=[losses.prn_loss_cls(), losses.prn_loss_regr(C.num_joints), losses.prn_loss_vis()])
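
The closure pattern can be tried without Keras. The sketch below is a plain-Python stand-in for the masked regression loss, assuming a flattened per-sample layout that mirrors the question's y_true (visibility flags in the first 2*num_joints entries, target coordinates after them). make_masked_sq_error and its weight parameter are hypothetical names chosen for illustration.

```python
def make_masked_sq_error(num_joints, weight=0.6):
    # Outer function: takes only the configuration, never y_true / y_pred.
    split = 2 * num_joints

    def loss(y_true, y_pred):
        # Inner function: has the (y_true, y_pred) signature a loss needs.
        vis = y_true[:split]       # visibility flags
        target = y_true[split:]    # target coordinates
        num = sum(v * (p - t) ** 2 for v, p, t in zip(vis, y_pred, target))
        return weight * num / sum(vis)

    return loss


# Calling the factory returns the actual loss function.
loss_fn = make_masked_sq_error(num_joints=1)

y_true = [1.0, 0.0, 2.0, 3.0]  # visibility (1, 0), targets (2, 3)
y_pred = [4.0, 9.0]

value = loss_fn(y_true, y_pred)
print(value)  # about 2.4: only the visible coordinate contributes
```

The second coordinate is masked out by its zero visibility flag, so the large error there (9 versus 3) does not affect the result.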

