
How to use a tf.keras model in a tf.data.Dataset generator?

I would like to use a pre-trained Keras model as part of my data processing to generate training data for a second Keras model. Ideally, I would like to do this by calling the first model in the data generator for the second model.

I am using tensorflow 1.15.

A simple example of what I'm trying to do is as follows:

import numpy as np
import tensorflow as tf
# Import from the public tf.keras API rather than the private tensorflow_core package
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

batch_size = 4
data_size = 16

model_generator = Sequential([Dense(data_size)])
model_generator.compile(optimizer='adam', loss='mae')
model_generator.build((batch_size, data_size))
sess = tf.keras.backend.get_session()

def generator():
    while True:
        data = np.random.random((batch_size, data_size))
        targets = tf.random.uniform((batch_size, 1))
        data = model_generator(data, training=False)
        data = data.eval(session=sess)
        yield (data, targets)

model_train = Sequential([Dense(data_size)])
model_train.compile(optimizer='adam', loss='mae')
model_train.build((batch_size, data_size))

output_types = (tf.float64, tf.float64)
output_shapes = (tf.TensorShape((batch_size, data_size)), tf.TensorShape((batch_size, 1)))
dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=output_types,
    output_shapes=output_shapes,
)

if next(generator()) is not None:
    print("Generator works outside of model.fit()!")

model_train.fit(
    dataset,
    epochs=2,
    steps_per_epoch=2
)

The snippet above produces the following error message when .fit() is called:

2020-01-28 17:35:56.705549: W tensorflow/core/framework/op_kernel.cc:1639] Invalid argument: ValueError: Tensor("dense/kernel/Read/ReadVariableOp:0", shape=(16, 16), dtype=float32) must be from the same graph as Tensor("sequential/dense/Cast:0", shape=(4, 16), dtype=float32).
Traceback (most recent call last):

The code will run normally if the generator does not call the model_generator model. For example:

def generator():
    while True:
        data = np.random.random((batch_size, data_size))
        targets = np.random.random((batch_size, 1))
        yield (data, targets)

I believe that the fit call creates its own TensorFlow graph, which does not include the nodes needed for model_generator. Is there any way to use one model inside the generator that feeds the training of another model like this? If so, how can I modify the above example to achieve this?

I don't think it's possible to connect the graph between a tf.data.Dataset and a Keras model. Instead, you need to find some way of creating a single model out of the two.

If the output of model_generator can be used directly as the input to model_train, then the easiest way I can think of to do this is to create a Sequential model containing both models. Here's a simple example based on your snippet above. In this example, only model_train will be backpropagated through, because the weights of model_generator are frozen.

import numpy as np
import tensorflow as tf
# Import from the public tf.keras API; Model and Input were unused and are dropped
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

batch_size = 4
data_size = 16

def generator():
    while True:
        data = np.random.random((batch_size, data_size))
        targets = np.random.random((batch_size, 1))
        yield (data, targets)

model_generator = Sequential([Dense(data_size)])

# Freeze weights in your generator so they don't get updated
for layer in model_generator.layers:
    layer.trainable = False

model_train = Sequential([Dense(data_size)])

# Create a model to train which calls both model_generator and model_train
model_fit = Sequential([model_generator, model_train])
model_fit.compile(optimizer='adam', loss='mae')
model_fit.build((batch_size, data_size))
model_fit.summary()

output_types = (tf.float64, tf.float64)
output_shapes = (tf.TensorShape((batch_size, data_size)),
                 tf.TensorShape((batch_size, 1)))
dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=output_types,
    output_shapes=output_shapes,
)

model_fit.fit(
    dataset,
    epochs=2,
    steps_per_epoch=2
)
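
To check that freezing behaves as described, you can compare the weights of both sub-models before and after training. The following is only a minimal sketch, not part of the original answer: it assumes the public tf.keras import path, feeds random NumPy arrays to fit() instead of a tf.data.Dataset to keep it short, and uses targets with the same shape as the model output so that it does not depend on loss broadcasting. The names frozen_before, trained_before, frozen_unchanged and trained_changed are introduced here purely for illustration.

```python
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

batch_size = 4
data_size = 16

model_generator = Sequential([Dense(data_size)])

# Freeze the pre-trained part, as in the example above
for layer in model_generator.layers:
    layer.trainable = False

model_train = Sequential([Dense(data_size)])

# Combined model: model_generator feeds model_train
model_fit = Sequential([model_generator, model_train])
model_fit.compile(optimizer='adam', loss='mae')
model_fit.build((batch_size, data_size))

# Snapshot the weights of both sub-models before training
frozen_before = [w.copy() for w in model_generator.get_weights()]
trained_before = [w.copy() for w in model_train.get_weights()]

# Assumption: random in-memory data, with targets shaped like the model output
data = np.random.random((batch_size, data_size)).astype("float32")
targets = np.random.random((batch_size, data_size)).astype("float32")
model_fit.fit(data, targets, epochs=2, verbose=0)

# The frozen weights should be identical; the trainable ones should have moved
frozen_unchanged = all(
    np.array_equal(before, after)
    for before, after in zip(frozen_before, model_generator.get_weights())
)
trained_changed = any(
    not np.array_equal(before, after)
    for before, after in zip(trained_before, model_train.get_weights())
)
print("model_generator unchanged:", frozen_unchanged)
print("model_train updated:", trained_changed)
```

If the first flag comes out False, the trainable flags were most likely set after compile() was called, in which case the model needs to be compiled again for the change to take effect.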


 