繁体   English   中英

Tensorflow 2.0:自定义 keras 公制导致 tf.function 回溯警告

[英]Tensorflow 2.0: custom keras metric caused tf.function retracing warning

当我使用以下自定义指标(keras 样式)时:

from sklearn.metrics import classification_report, f1_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):
    def __init__(self, dev_data, classifier, dataloader):
        self.best_f1_score = 0.0
        self.dev_data = dev_data
        self.classifier = classifier
        self.predictor = Predictor(classifier, dataloader)
        self.dataloader = dataloader

    def on_epoch_end(self, epoch, logs=None):
        print("start to evaluate....")
        _, preds = self.predictor(self.dev_data)
        y_trues, y_preds = [self.dataloader.label_vector(v["label"]) for v in self.dev_data], preds
        f1 = f1_score(y_trues, y_preds, average="weighted")
        print(classification_report(y_trues, y_preds,
                                    target_names=self.dataloader.vocab.labels))
        if f1 > self.best_f1_score:
            self.best_f1_score = f1
            self.classifier.save_model()
            print("best metrics, save model...")

我收到以下警告:

W1106 10:49:14.171694 4745115072 def_function.py:474] 在 0x14a3f9d90> 的最后 11 次调用中,有 6 次触发了 tf.function 回溯。 跟踪很昂贵,并且过多的跟踪可能是由于传递了 python 对象而不是张量。 此外, tf.function 具有 Experimental_relax_shapes=True 选项,可以放宽可以避免不必要的回溯的参数形状。 Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

在导入 tensorflow 后添加这一行:

tf.compat.v1.disable_eager_execution()

当回溯 TF function 时会出现此警告,因为其 arguments 的形状或 dtype(对于张量)甚至值(Python 或 np 对象或变量)发生变化。

在一般情况下,解决方法是在您传递给 Keras 或 TF 的自定义 function 的定义之前使用 @tf.function(experimental_relax_shapes=True) 。 这试图检测并避免不必要的回溯,但不能保证解决问题。

在您的情况下,我猜 Predictor class 是自定义 class,所以在 Predictor.predict() 的定义之前放置 @tf.function(experimental_relax_shapes=True)。

然后使用 @tf.function(experimental_relax_shapes=True) 可能会解决你的问题

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM