
How to make tf.data.Dataset.from_generator yield batches with a custom generator

I want to use the tf.data API. My expected workflow looks like the following:

  • Input image is a 5D tensor with (batch_size, width, height, channels, frames)

  • First layer is a 3D convolution

I use the tf.data.Dataset.from_generator function to create the dataset. Later I make an initializable iterator.

My code would look something like this:

def custom_gen():
   img = np.random.normal(size=(width, height, channels, frames))
   yield (img, img)  # I train an autoencoder, so x == y

dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_gen)
iter = dataset.make_initializable_iterator()

sess = tf.Session()
sess.run(iter.get_next())

I would expect iter.get_next() to yield a 5D tensor with the batch size as the leading dimension. However, it does not work even when I yield the batch myself in custom_gen. I get an error when I try to initialize the dataset with my placeholder of input shape (batch_size, width, height, channels, frames).

The Dataset construction in that example is in the wrong order. It should be done like this, as also established by the official guide on Importing Data:

  1. First call a dataset-creation function or static method to establish the original source of data (e.g. the static methods from_tensor_slices, from_generator, list_files, ...).
  2. Then apply transformations by chaining adapter methods (such as batch).

Thus:

dataset = tf.data.Dataset.from_generator(custom_gen, output_types=(tf.float32, tf.float32)).batch(batch_size)
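A minimal runnable sketch of the corrected pipeline, using the TF 1.x graph API (written against tf.compat.v1 so it also runs under TensorFlow 2; the concrete shapes and the 100-example loop are illustrative assumptions, not from the question):

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph API; plain `import tensorflow as tf` on TF 1.x
tf.disable_v2_behavior()

# Assumed example dimensions
width, height, channels, frames = 16, 16, 1, 8
batch_size = 4

def custom_gen():
    # Yield one UNBATCHED example at a time; batching is the Dataset's job.
    for _ in range(100):
        img = np.random.normal(size=(width, height, channels, frames)).astype(np.float32)
        yield img, img  # autoencoder: x == y

# Source first, then transformations. from_generator needs output_types;
# output_shapes is optional but lets TF infer static shapes.
dataset = tf.data.Dataset.from_generator(
    custom_gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=((width, height, channels, frames),) * 2,
).batch(batch_size)

iterator = dataset.make_initializable_iterator()
next_x, next_y = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    x, y = sess.run((next_x, next_y))
    print(x.shape)  # 5D: (batch_size, width, height, channels, frames)
```

Because batch is applied to the dataset of single examples, get_next() now returns the expected 5D tensors.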
