
“LookupError: Function `__class__` does not exist.” When using tf.function

I am trying to override a keras/tf2.0 loss function with a custom one abstracted in a WebAssembly binary. Here is the relevant code.

@tf.function
def custom_loss(y_true, y_pred):
    return tf.constant(instance.exports.loss(y_true, y_pred))

and I am using it this way:

model.compile(optimizer_obj, loss=custom_loss)
# and then model.fit(...)

I am not entirely sure how the tf2.0 eager execution works, so any insights regarding that would be useful.

I don't think the instance.exports.loss function is relevant to the error. However, if you are sure everything else is alright, let me know and I will share additional details.

Here is a stacktrace and the actual error: https://pastebin.com/6YtH75ay

First of all, you need not use @tf.function in order to define a custom loss.

We can happily (if pointlessly) do something like this:

def custom_loss(y_true, y_pred):
  return tf.reduce_mean(y_pred)

model.compile(optimizer='adam',
              loss=custom_loss,
              metrics=['accuracy'])

so long as all the operations we use inside custom_loss are differentiable by TensorFlow.

So you can remove the @tf.function decorator, but I suspect you will then run into an error message something like this:

[Some trace info] - Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

This happens because TensorFlow has no way to find the gradients of a function inside a WebAssembly binary. Everything inside the loss function must be something TensorFlow can understand and compute gradients on; otherwise it cannot optimise for a lower loss value.
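A minimal sketch of why this matters, assuming TensorFlow 2.x in eager mode. The function opaque_square is a hypothetical stand-in for instance.exports.loss (its real behaviour is not shown in the question); it only illustrates that a value computed outside TensorFlow and wrapped in tf.constant is disconnected from the variables, so the gradient comes back as None.

```python
import tensorflow as tf


def opaque_square(value):
    # Hypothetical stand-in for instance.exports.loss: plain Python
    # arithmetic that TensorFlow cannot trace or differentiate.
    return value * value


x = tf.Variable(3.0)

# Case 1: the result is computed outside TensorFlow and wrapped afterwards.
with tf.GradientTape() as tape:
    y_opaque = tf.constant(opaque_square(float(x.numpy())))
grad_opaque = tape.gradient(y_opaque, x)

# Case 2: the same computation expressed with a TensorFlow op.
with tf.GradientTape() as tape:
    y_tf = tf.square(x)
grad_tf = tape.gradient(y_tf, x)

print("gradient through opaque function:", grad_opaque)
print("gradient through tf.square:", float(grad_tf))
```

In the first case the tape never records an operation linking y_opaque to x, which is the same situation the optimizer faces when the loss is produced by the WebAssembly export.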

Perhaps the best way forward would be to replicate the functionality inside instance.exports.loss using operations that TensorFlow can compute gradients on, rather than trying to reference it directly?
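As a rough illustration of that suggestion, here is a sketch that assumes, purely for the sake of example, that the WebAssembly loss computes a mean squared error. The question does not say what instance.exports.loss actually calculates, so the formula below is a placeholder to be replaced with the real one written in TensorFlow ops.

```python
import tensorflow as tf


def custom_loss(y_true, y_pred):
    # Assumption: mean squared error stands in for whatever
    # instance.exports.loss really computes.
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.square(y_true - y_pred))


# Tiny hand-made check that the loss is differentiable.
x = tf.constant([[2.0], [4.0]])
y_true = tf.constant([[1.0], [2.0]])
w = tf.Variable([[1.0]])

with tf.GradientTape() as tape:
    y_pred = tf.matmul(x, w)
    loss = custom_loss(y_true, y_pred)
grad = tape.gradient(loss, w)

print("loss:", float(loss))
print("gradient w.r.t. w:", grad.numpy())
```

Because every step is a TensorFlow op, the gradient is defined, and a function written this way can be passed as loss=custom_loss to model.compile without the @tf.function decorator.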

