
TF-Keras - Dataset.from_generator for multi-input functional API model

I have a generator that yields three variables. The first two are the inputs to a two-input Keras model (functional API), and the third is the label. I am using tf.data.Dataset to feed my model. The code is as follows:


train_dataset = tf.data.Dataset.from_generator(generator=make_generator_train,
                                                   args=[train_x_paths, train_y_int],
                                                   output_types=(tf.tuple((tf.float16, tf.float16)), tf.int8),
                                                   output_shapes=(tf.TensorShape([2]),
                                                                  tf.TensorShape([1]))).batch(batch_size=batch_size)

I get a TypeError with this:

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.framework.tensor_shape.TensorShape'>.
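The error comes from a structure mismatch: output_types is nested as ((dtype, dtype), dtype), while output_shapes is the flat pair (shape, shape), so TensorFlow cannot line them up. tf.tuple is an op for grouping tensors, not a way to declare types; plain nested Python tuples already describe the structure. A minimal sketch of matching nested structures, assuming the generator yields ((x1, x2), y) with scalar values (toy_generator and its data are hypothetical stand-ins for make_generator_train):

```python
import tensorflow as tf

# Hypothetical stand-in for make_generator_train: yields ((x1, x2), y),
# i.e. a tuple holding the two model inputs, followed by the label.
def toy_generator():
    for i in range(4):
        yield (float(i), float(i) * 2.0), i % 2

# output_types and output_shapes mirror the same nesting:
# ((input_1, input_2), label). Each input gets its own dtype and shape.
dataset = tf.data.Dataset.from_generator(
    toy_generator,
    output_types=((tf.float16, tf.float16), tf.int8),
    output_shapes=((tf.TensorShape([]), tf.TensorShape([])), tf.TensorShape([])),
).batch(2)

(x1, x2), y = next(iter(dataset))
print(x1.shape, x2.shape, y.shape)  # (2,) (2,) (2,)
```

Note that tf.TensorShape([2]) in the original would describe one vector of length 2, not two separate input tensors.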

Try it like this:

train_dataset = tf.data.Dataset.from_generator(
    generator=make_generator_train,
    args=[train_x_paths, train_y_int],
    output_types=(tf.float16, tf.int8)
).batch(batch_size=batch_size)

Most of the time you don't need to specify output_shapes; the shapes are determined at runtime. Additionally, output_types only needs the overall dtype of each output tensor, not a dtype for each individual dimension of a tensor.
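To see what "determined at runtime" means, you can inspect the dataset's element_spec. A minimal sketch, assuming a hypothetical toy_generator that yields variable-length float vectors and integer labels: with only output_types, the static shapes stay fully unknown. Since TF 2.4, output_signature with tf.TensorSpec can replace the output_types/output_shapes pair, declaring dtype and shape together, with None marking a dimension that is only known at runtime.

```python
import tensorflow as tf

# Hypothetical generator: yields (feature_vector, label) pairs whose
# vector length is only known when each element is produced.
def toy_generator():
    for i in range(3):
        yield [float(i)] * (i + 1), i

# Only dtypes given: the static shapes are left completely unknown.
ds_types_only = tf.data.Dataset.from_generator(
    toy_generator, output_types=(tf.float16, tf.int8))
print(ds_types_only.element_spec)

# output_signature declares dtype and shape together; None marks the
# variable-length dimension, and shape () marks a scalar label.
ds_signature = tf.data.Dataset.from_generator(
    toy_generator,
    output_signature=(
        tf.TensorSpec(shape=(None,), dtype=tf.float16),
        tf.TensorSpec(shape=(), dtype=tf.int8),
    ))
print(ds_signature.element_spec)

# The concrete shapes appear only when elements are actually produced.
for x, y in ds_signature:
    print(x.shape, int(y))
```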

Solution: have the generator yield a dictionary for the inputs, keyed by the names of the model's Input layers, and yield the output (label) as is.
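A minimal sketch of that solution, assuming a two-input functional model whose Input layers are named input_a and input_b (hypothetical names; the dictionary keys must match whatever names your model's inputs use) and random NumPy vectors standing in for the data loaded from train_x_paths:

```python
import numpy as np
import tensorflow as tf

NUM_SAMPLES, FEATURES, BATCH_SIZE = 8, 4, 4

# Hypothetical generator: yields ({input_name: array, ...}, label).
# The dict keys must match the names of the model's Input layers.
def make_generator_train():
    rng = np.random.default_rng(0)
    for _ in range(NUM_SAMPLES):
        inputs = {
            "input_a": rng.random(FEATURES, dtype=np.float32),
            "input_b": rng.random(FEATURES, dtype=np.float32),
        }
        label = np.int32(rng.integers(0, 2))
        yield inputs, label

# The signature mirrors the yielded structure: (dict of inputs, label).
train_dataset = tf.data.Dataset.from_generator(
    make_generator_train,
    output_signature=(
        {
            "input_a": tf.TensorSpec(shape=(FEATURES,), dtype=tf.float32),
            "input_b": tf.TensorSpec(shape=(FEATURES,), dtype=tf.float32),
        },
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
).batch(BATCH_SIZE)

# Two-input functional model; inputs are passed as a dict whose keys
# match the dataset's dict keys.
input_a = tf.keras.Input(shape=(FEATURES,), name="input_a")
input_b = tf.keras.Input(shape=(FEATURES,), name="input_b")
merged = tf.keras.layers.Concatenate()([input_a, input_b])
output = tf.keras.layers.Dense(2, activation="softmax")(merged)
model = tf.keras.Model(
    inputs={"input_a": input_a, "input_b": input_b}, outputs=output)

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
history = model.fit(train_dataset, epochs=1, verbose=0)
```

Keying the inputs by name avoids relying on positional order, so adding or reordering inputs later does not silently feed the wrong tensor to the wrong branch.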
