
Access deprecated attribute "validation_data" in tf.keras.callbacks.Callback

I decided to switch from keras to tf.keras (as recommended here). Therefore I installed TensorFlow with tf.__version__=2.0.0 and tf.keras.__version__=2.2.4-tf. In an older version of my code (using an older TensorFlow version, tf.__version__=1.xx) I used a callback to compute custom metrics on the entire validation data at the end of each epoch. The idea was taken from here. However, it seems that the "validation_data" attribute is deprecated, so the following code no longer works.

import numpy as np
from tensorflow.keras.callbacks import Callback

# mse_score is my own helper (not shown); it behaves like
# sklearn.metrics.mean_squared_error(y_true, y_pred).


class ValMetrics(Callback):

    def on_train_begin(self, logs=None):
        self.val_epoch_mse = []

    def on_epoch_end(self, epoch, logs):
        # self.validation_data is deprecated in tf.keras 2.x, so these two lines fail there
        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]

        val_epoch_mse = mse_score(val_targ, val_predict)
        self.val_epoch_mse.append(val_epoch_mse)

        # Add custom metrics to the logs, so that we can use them with
        # the EarlyStopping and CSVLogger callbacks
        logs["val_epoch_mse"] = val_epoch_mse

        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse:     {:+.6f}".format(val_epoch_mse))

My current workaround is the following: I simply pass validation_data as an argument to the ValMetrics constructor:

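class ValMetrics(Callback):

    def __init__(self, validation_data):
        # super().__init__() runs Callback.__init__; super(Callback, self) would skip it
        super().__init__()
        self.X_val, self.y_val = validation_data
        # on_epoch_end then uses self.X_val and self.y_val instead of self.validation_data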
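Filled out into a complete callback, the constructor-based workaround might look like the sketch below. It assumes TensorFlow 2.x; the toy linear model and the random data are placeholders, and the MSE is computed directly with NumPy instead of the unshown mse_score helper.

```python
import numpy as np
import tensorflow as tf


class ValMetrics(tf.keras.callbacks.Callback):
    """Computes MSE on the full validation set at the end of every epoch.

    The validation arrays are handed to the constructor instead of being
    read from the deprecated self.validation_data attribute.
    """

    def __init__(self, validation_data):
        super().__init__()
        self.X_val, self.y_val = validation_data
        self.val_epoch_mse = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        y_pred = np.asarray(self.model.predict(self.X_val, verbose=0))
        mse = float(np.mean((self.y_val.reshape(-1) - y_pred.reshape(-1)) ** 2))
        self.val_epoch_mse.append(mse)
        # Callbacks listed after this one (e.g. EarlyStopping, CSVLogger) see this key
        logs["val_epoch_mse"] = mse


# Placeholder data: the target is the sum of the three input features
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 3)).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)
x_val = rng.normal(size=(16, 3)).astype("float32")
y_val = x_val.sum(axis=1, keepdims=True)

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

cb = ValMetrics((x_val, y_val))
history = model.fit(
    x_train, y_train,
    epochs=2, batch_size=16,
    validation_data=(x_val, y_val),
    callbacks=[cb], verbose=0,
)
print(cb.val_epoch_mse)
```

Because the loss here is itself MSE, the callback's last value should match the val_loss Keras reports for the same epoch, which is a quick sanity check that the callback sees the whole validation set.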

Still, I have some questions: Is the "validation_data" attribute really deprecated, or can it be found elsewhere? Is there a better way to access the validation data at the end of each epoch than the above workaround?

Thanks a lot!

You are right that the validation_data attribute is deprecated, as stated in the TensorFlow Callbacks documentation.

The issue you are facing has been raised on GitHub; the related issues are Issue1, Issue2 and Issue3.

None of these GitHub issues has been resolved yet. Your workaround of passing the validation data as an argument to a custom callback is a good one: according to this GitHub comment, many people have found it useful.
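For MSE specifically, a callback may not be necessary at all. Metrics passed to compile() are also evaluated on the validation_data given to fit() and logged with a val_ prefix, and Keras accumulates MSE over all validation samples rather than reporting one batch, so the logged value covers the whole validation set. A minimal sketch, assuming TensorFlow 2.x, with a toy linear model and random data as placeholders:

```python
import numpy as np
import tensorflow as tf

# Placeholder data: the target is the sum of the three input features
rng = np.random.default_rng(1)
x_train = rng.normal(size=(64, 3)).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)
x_val = rng.normal(size=(20, 3)).astype("float32")
y_val = x_val.sum(axis=1, keepdims=True)

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(
    optimizer="sgd",
    loss="mse",
    # The explicit name fixes the history key to "val_mse" for the validation pass
    metrics=[tf.keras.metrics.MeanSquaredError(name="mse")],
)

# 20 validation samples with batch_size=16 give two validation batches (16 + 4);
# the metric still averages over all 20 samples.
history = model.fit(
    x_train, y_train,
    epochs=2, batch_size=16,
    validation_data=(x_val, y_val),
    verbose=0,
)
print(history.history["val_mse"])
```

A plain average of per-batch F1 scores is not the F1 score over the whole set, which is why computing such scores from a full pass over the validation data, as in the callback below, is a common approach.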

For the benefit of the Stack Overflow community, here is the workaround code, even though it is also available on GitHub:

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback


class Metrics(Callback):

    # val_data: any object that supports len() and indexing returning (x, y)
    # batches, e.g. a tf.keras.utils.Sequence.
    def __init__(self, val_data):
        super().__init__()
        self.validation_data = val_data

    def on_train_begin(self, logs=None):
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        # Collect per-batch arrays and concatenate them afterwards, so a smaller
        # final batch does not cause a shape mismatch.
        val_pred, val_true = [], []
        for i in range(len(self.validation_data)):
            x_val, y_val = self.validation_data[i]
            # .round() turns sigmoid outputs into 0/1 labels (binary classification)
            batch_pred = np.asarray(self.model.predict(x_val, verbose=0)).round()
            val_pred.append(batch_pred.reshape(-1))
            val_true.append(np.asarray(y_val).reshape(-1))

        val_pred = np.concatenate(val_pred)
        val_true = np.concatenate(val_true)

        self.val_f1s.append(f1_score(val_true, val_pred))
        self.val_precisions.append(precision_score(val_true, val_pred))
        self.val_recalls.append(recall_score(val_true, val_pred))
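Validation data delivered in batches does not always divide evenly by the batch size, so accumulating per-batch arrays and concatenating them once is more robust than writing into pre-allocated fixed-size slices. The framework-free sketch below shows that pattern. Both helpers are hypothetical and for illustration only: predict_fn stands in for self.model.predict, and binary_scores is a NumPy stand-in for sklearn's precision_score, recall_score and f1_score on 0/1 labels with a 0.5 threshold.

```python
import numpy as np


def collect_predictions(predict_fn, batches):
    """Run predict_fn on each (x, y) batch and stitch the results together.

    Nothing is pre-allocated with a fixed batch size, so a smaller last
    batch is handled without special cases.
    """
    preds, trues = [], []
    for x, y in batches:
        preds.append(np.asarray(predict_fn(x)).reshape(-1))
        trues.append(np.asarray(y).reshape(-1))
    return np.concatenate(preds), np.concatenate(trues)


def binary_scores(y_true, y_prob, threshold=0.5):
    """Precision, recall and F1 for 0/1 labels; returns 0.0 on zero division."""
    y_pred = (y_prob >= threshold).astype(int)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Example: 7 samples split into batches of 3, 3 and 1 (ragged last batch)
batches = [
    (np.array([0.9, 0.2, 0.7]), np.array([1, 0, 0])),
    (np.array([0.4, 0.8, 0.1]), np.array([1, 1, 0])),
    (np.array([0.6]), np.array([1])),
]
probs, labels = collect_predictions(lambda x: x, batches)  # identity "model"
print(binary_scores(labels, probs))  # → (0.75, 0.75, 0.75)
```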

I will keep following the GitHub issues mentioned above and will update this answer accordingly.

Hope this helps. Happy learning!
