简体   繁体   中英

Using keras callback to make prediction on the current batch

I am trying to use a Keras callback to make a prediction at the end of each batch, as follows:

from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
import numpy as np


class CollectOutputAndTarget(Callback):
    """Record each batch's targets, inputs and outputs, plus a fresh prediction.

    The three ``tf.Variable`` holders created in ``__init__`` start out as the
    scalar ``0.`` and are only *read* by this class. Something outside the
    class has to write the batch tensors into them; in this file that is done
    by the ``tf.assign`` ops handed to the model as ``fetches``.
    """
    def __init__(self):
        """Create the empty result lists and the holder variables."""
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.inputs = []  # collect input batches
        self.outputs = []  # collect y_pred batches
        self.preds = []  # collect `model.predict` results for each stored input batch

        # the shape of these 3 variables will change according to batch shape
        # to handle the "last batch", specify `validate_shape=False`
        self.var_y_true = tf.Variable(0., validate_shape=False)
        self.var_input = tf.Variable(0., validate_shape=False)
        self.var_y_pred = tf.Variable(0., validate_shape=False)


    def on_batch_end(self, batch, logs=None):
        """Store the current holder values and predict on the stored input batch.

        ``batch`` and ``logs`` are accepted for the callback signature but are
        not used here.
        """
        # evaluate the variables and save them into lists
        self.targets.append(K.eval(self.var_y_true))
        batch_inp = K.eval(self.var_input)
        self.inputs.append(batch_inp)
        self.outputs.append(K.eval(self.var_y_pred))
        # extra prediction on the input batch that was just read back;
        # per the traceback further down, this is the call that raises
        # InvalidArgumentError
        current_pred = self.model.predict(batch_inp)
        self.preds.append(current_pred)


# build a simple model
K.clear_session()
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(2,)), Dense(2)])
model.compile(loss='mse', optimizer='adam')

# initialize the variables and the `tf.assign` ops
# each op copies one symbolic model tensor (target, input, output) into the
# matching holder variable on the callback
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
           tf.assign(cbk.var_input, model.inputs[0], validate_shape=False),
           tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
# NOTE(review): `_function_kwargs` is a private Keras attribute. The traceback
# below (predict function asking for the 'dense_2_target' placeholder) suggests
# these fetches also end up attached to the predict function - TODO confirm
model._function_kwargs = {'fetches': fetches}  # register the assign ops as extra fetches


# fit the model and check results
# 5 samples with batch_size=3 gives one batch of 3 and a last batch of 2,
# which is why the holder variables need `validate_shape=False`
X = np.arange(10).reshape((5, 2))
Y = X*2

model.fit(X, Y, epochs=1, batch_size=3, callbacks=[cbk], shuffle=False)

I am getting the following error:


InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-114-adfad08009ad> in <module>
      3 Y = X*2
      4 
----> 5 model.fit(X, Y, epochs=1, batch_size=3, callbacks=[cbk], shuffle=False)

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1037                                         initial_epoch=initial_epoch,
   1038                                         steps_per_epoch=steps_per_epoch,
-> 1039                                         validation_steps=validation_steps)
   1040 
   1041     def evaluate(self, x=None, y=None,

/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
    202                     batch_logs[l] = o
    203 
--> 204                 callbacks.on_batch_end(batch_index, batch_logs)
    205                 if callback_model.stop_training:
    206                     break

/usr/local/lib/python3.6/dist-packages/keras/callbacks.py in on_batch_end(self, batch, logs)
    113         t_before_callbacks = time.time()
    114         for callback in self.callbacks:
--> 115             callback.on_batch_end(batch, logs)
    116         self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
    117         delta_t_median = np.median(self._delta_ts_batch_end)

<ipython-input-111-65feb418f9ec> in on_batch_end(self, batch, logs)
     19         self.inputs.append(batch_inp)
     20         self.outputs.append(K.eval(self.var_y_pred))
---> 21         current_pred = self.model.predict(batch_inp)
     22         self.preds.append(current_pred)

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in predict_on_batch(self, x)
   1272             ins = x
   1273         self._make_predict_function()
-> 1274         outputs = self.predict_function(ins)
   1275         return unpack_singleton(outputs)
   1276 

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_target' with dtype float and shape [?,?]
     [[{{node dense_2_target}}]]

I am able to get the model output at the end of the batch from the variable self.var_y_pred, which is assigned from model.outputs[0].

However, from my understanding, this prediction is computed before backpropagation at the current step. My objective is to make a prediction on the current batch using the version of the model whose weights have already been updated by training on that batch.

How can I achieve this?

The answer is "you can't".

The objects model.inputs and model.outputs are lists of "tensors", not data. Tensors are symbolic graph nodes; they hold no values until the graph is run.

The only way to get a batch prediction is to call model.predict_on_batch(input_data_as_numpy) or a similar method. In your case this means making the model predict the same batch twice, which is a significant performance drawback.

To use predicted batches during training, you need to enable eager mode and write a custom training loop.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM