简体   繁体   English

如何 model.predict 内部损失 function? (张量流,Keras)

[英]How to model.predict inside loss function? (Tensorflow, Keras)

I am trying to construct a custom loss for a regression problem with the following structure, following this answer: Keras Custom loss function to pass arguments other than y_true and y_pred我正在尝试为具有以下结构的回归问题构建自定义损失,遵循此答案: Keras 自定义损失 function 以传递 arguments 而非 y_true 和 y_pred

Now, my function is like the following:现在,我的 function 如下所示:

def CustomLoss(model,X_valid,y_valid,batch_size):
    def Loss(y_true,y_pred):
        n_samples=5
        mc_predictions = np.zeros((n_samples,256,256))
        for i in range(n_samples):
           y_p = model.predict(X_valid, verbose=1,batch_size=batch_size)
    (Other operations...) 
        return LossValue
    return Loss

When trying to execute this line y_p = model.predict(X_valid, verbose=1,batch_size=batch_size) i get the following error:尝试执行此行时y_p = model.predict(X_valid, verbose=1,batch_size=batch_size)我收到以下错误:

Method requires being in cross-replica context, use get_replica_context().merge_call()方法需要在跨副本上下文中,使用 get_replica_context().merge_call()

From what I gathered I cannot use model.predict inside loss function.根据我收集到的信息,我无法在损失 function 中使用 model.predict。 Is there a workaround or solution for this?是否有解决方法或解决方案? Please let me know if my question is clear or if you need any additional information.如果我的问题很清楚,或者您需要任何其他信息,请告诉我。 Thanks!谢谢!

Sounds like you can use model.add_loss for this.听起来您可以为此使用 model.add_loss 。 You can use this to specify the loss function inside of the model.您可以使用它来指定 model 内部的损失 function。 It also removes the need for the loss function to only take in y and y_pred.它还消除了对损失 function 仅采用 y 和 y_pred 的需要。 https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_loss \ Some psuedo-code: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_loss \ 一些伪代码:

class YourModel(tf.keras.Model): 
    ...
    def call(self, inputs): 
        unpack, any, extra, stuff = inputs
        (your network code goes here)
        loss = (other operations)
        self.add_loss(loss)
        return output

(In case you don't know, model.predict is basically just model.call but with some extra bells and whistles attached.) (如果您不知道,model.predict 基本上只是 model.call 但附加了一些额外的花里胡哨。)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM