How to use Keras' predict_on_batch in tf.data.Dataset.map()?

I would like to find a way to use Keras' predict_on_batch inside tf.data.Dataset.map() in TF 2.0.

Let's say I have a NumPy dataset

import numpy as np

n_data = 10**5
my_data    = np.random.random((n_data,10,1))
my_targets = np.random.randint(0,2,(n_data,1))

data = ({'x_input':my_data}, {'target':my_targets})

and a tf.keras model

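import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, Dense
from tensorflow.keras.models import Model

x_input = Input((None,1), name = 'x_input')
RNN     = SimpleRNN(100,  name = 'RNN')(x_input)
dense   = Dense(1, name = 'target')(RNN)

my_model = Model(inputs = [x_input], outputs = [dense])
my_model.compile(optimizer='SGD', loss = 'binary_crossentropy')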

I can create a batched dataset with

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)
prediction_dataset = dataset.map(transform_predictions)

where transform_predictions is a user-defined function that gets the predictions from predict_on_batch:

def transform_predictions(inputs, outputs):
    predictions = my_model.predict_on_batch(inputs)
    # predictions = do_transformations_here(predictions)
    return predictions

This raises an error inside predict_on_batch:

AttributeError: 'Tensor' object has no attribute 'numpy'

As far as I understand, predict_on_batch expects a NumPy array, but it is getting a Tensor object from the dataset.

One possible solution seems to be wrapping predict_on_batch in a tf.py_function, though I have not been able to get that working either.

Does anyone know how to do this?

The function passed to Dataset.map() is traced in graph mode, so it receives <class 'tensorflow.python.framework.ops.Tensor'> objects, which don't have a numpy() method.

Iterating over a Dataset yields <class 'tensorflow.python.framework.ops.EagerTensor'> objects, which do have a numpy() method.

Feeding an eager tensor to the predict() family of methods works fine.
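The difference between the two cases can be checked directly. This is a minimal sketch, assuming TensorFlow 2.x with eager execution enabled; the names inspect_batch and eager_flags are made up for this illustration and are not part of the question's code.

```python
import numpy as np
import tensorflow as tf

# Small stand-in for the question's data: 20 sequences of shape (10, 1).
dataset = tf.data.Dataset.from_tensor_slices(np.random.random((20, 10, 1))).batch(10)

eager_flags = []  # filled while Dataset.map() traces the function

def inspect_batch(batch):
    # This Python body runs while the function is traced into a graph,
    # so the batch it sees here is a symbolic tensor.
    eager_flags.append(tf.executing_eagerly())
    return batch

mapped = dataset.map(inspect_batch)

# Iterating the dataset happens eagerly and yields tensors with .numpy().
first_batch = next(iter(mapped))

print("eager inside map():", eager_flags)
print("batch shape when iterating:", first_batch.numpy().shape)
```

Anything that needs .numpy(), including predict_on_batch, therefore has to run on the tensors obtained by iterating, not inside the mapped function.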

You could try something like this:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)

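for x, y in dataset:
    # x holds eager tensors here, so either of these calls works.
    # Pass the named input explicitly:
    predictions = my_model.predict_on_batch(x['x_input'])
    # or pass the whole input dict:
    predictions = my_model.predict_on_batch(x)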
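If the predictions really have to be produced inside Dataset.map(), one alternative is to call the model object directly instead of predict_on_batch, since a Keras model can be called on symbolic tensors. The sketch below is not from the original answer and is written under assumptions: it uses a smaller n_data, casts the inputs to float32, passes the model's input and output as single tensors, and has only been reasoned about for TensorFlow 2.x, so behaviour may differ between versions.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, Dense
from tensorflow.keras.models import Model

# Smaller dataset than in the question, to keep the sketch quick to run.
n_data = 100
my_data = np.random.random((n_data, 10, 1)).astype("float32")
my_targets = np.random.randint(0, 2, (n_data, 1))
data = ({"x_input": my_data}, {"target": my_targets})

# Same architecture as in the question.
x_input = Input((None, 1), name="x_input")
rnn = SimpleRNN(100, name="RNN")(x_input)
dense = Dense(1, name="target")(rnn)
my_model = Model(inputs=x_input, outputs=dense)

def transform_predictions(inputs, outputs):
    # Calling the model (not predict_on_batch) builds graph ops,
    # so no .numpy() conversion is attempted during tracing.
    predictions = my_model(inputs["x_input"], training=False)
    # Any further transformations must also be TensorFlow ops.
    return predictions

dataset = tf.data.Dataset.from_tensor_slices(data).batch(10)
prediction_dataset = dataset.map(transform_predictions)

# Iterating yields eager tensors, which can be converted to NumPy.
all_predictions = np.concatenate([p.numpy() for p in prediction_dataset])
print(all_predictions.shape)
```

The trade-off is that everything inside transform_predictions must be expressible as TensorFlow ops; plain NumPy post-processing still belongs in the eager loop shown above.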
