
How to cache and iterate through a Dataset of unknown size?

Although I added a .cache() step to my dataset pipeline, successive training epochs still download the data from the network storage.

I have a dataset on a network storage. I want to cache it, but not to repeat it: a training epoch must run through the whole dataset. Here is my dataset building pipeline:

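import tensorflow as tf

def build_dataset(file_pattern):  # illustrative wrapper name
    # _parse_example_batch (defined elsewhere, not shown) parses a batch of serialized records
    return tf.data.Dataset.list_files(
        file_pattern
    ).interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).shuffle(
        buffer_size=2048
    ).batch(
        batch_size=2048,
        drop_remainder=True,
    ).cache(
    ).map(
        map_func=_parse_example_batch,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).prefetch(
        buffer_size=32
    )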

If I use it as is, the dataset is downloaded again at each epoch. To avoid this, I have to add a .repeat() step to the pipeline and pass the steps_per_epoch argument to model.fit. However, I do not know the size of the complete dataset, so I cannot pass the right steps_per_epoch value.
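A minimal sketch of this repeat() + steps_per_epoch workaround, at the tf.data level only. Assumptions: TensorFlow 2.4 or later (for output_signature), a hypothetical slow_source generator standing in for the network storage, and a size (6 elements) that is known here but unknown in the real case. The take() call consumes the same number of batches that model.fit(ds, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH) would request.

```python
import tensorflow as tf

source_reads = []  # one entry each time the (hypothetical) slow source is read

def slow_source():
    # Stand-in for reading the TFRecord files from network storage.
    source_reads.append(1)
    for i in range(6):
        yield i

EPOCHS = 3
BATCH_SIZE = 2
STEPS_PER_EPOCH = 3  # 6 elements / 2 per batch: only computable if the size is known

ds = (
    tf.data.Dataset.from_generator(
        slow_source, output_signature=tf.TensorSpec(shape=(), dtype=tf.int64)
    )
    .cache()   # filled during the first complete pass over the source
    .repeat()  # infinite stream: epoch boundaries now come from steps_per_epoch
    .batch(BATCH_SIZE)
)

# Consume as many batches as fit(..., epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH) would.
batches = [b.numpy().tolist() for b in ds.take(EPOCHS * STEPS_PER_EPOCH)]
print("source reads:", len(source_reads))
print(batches)
```

The source is read once and every later repetition is served from the cache, but STEPS_PER_EPOCH has to be hard-coded, which is exactly the problem when the dataset size is unknown.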

What is the right way to cache and use a dataset of unknown size?

Thanks.


Edit

While reading some TF code, I (re)discovered make_initializable_iterator. It seems to be what I am looking for: iterating several times over the same dataset (and taking advantage of the cache after the first iteration). However, it is deprecated and no longer part of the main API in TF2.

The migration instructions say to iterate over the Dataset manually with for ... in dataset. Isn't that what the keras.Model.fit function does? Do I have to write the training loop manually to benefit from the cache?
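For reference, a minimal sketch of such a manual loop in TF2, assuming TensorFlow 2.4 or later, a toy hypothetical generator in place of the network read, and a one-unit Dense model. Each `for x, y in ds` creates a fresh iterator, which is the TF2 counterpart of re-initializing an initializable iterator; once the first pass has completed, later passes should be served from the cache.

```python
import tensorflow as tf

calls = {"generator": 0}  # how many times the (hypothetical) source is read

def generator():
    # Stand-in for the slow network read; targets follow y = 2 * x.
    calls["generator"] += 1
    for i in range(8):
        yield [float(i)], [2.0 * i]

ds = (
    tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(1,), dtype=tf.float32),
            tf.TensorSpec(shape=(1,), dtype=tf.float32),
        ),
    )
    .cache()
    .batch(4)
)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()

losses = []
for epoch in range(3):
    # A new iterator per epoch: the first one fills the cache, the others read it.
    for x, y in ds:
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    losses.append(float(loss))
    print(f"epoch {epoch}: last batch loss = {losses[-1]:.4f}, source reads = {calls['generator']}")
```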

Kind regards.

In TF2.0, you do not need .repeat(). Regarding your observation that "successive training epochs still download the data from the network storage", I think you were confused by the "filling up shuffle buffer" message. It is printed before every epoch when you use the shuffle() function, because the shuffle buffer has to be refilled each time. Try without shuffle(), just to see the difference. Also, I would suggest calling cache() after map() and before batch().
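A minimal sketch of that ordering, assuming TensorFlow 2.4 or later (for tf.data.AUTOTUNE) and records that hold a single float feature "x". FEATURES, parse_example, build_dataset and write_demo_files are hypothetical stand-ins for the question's real schema and _parse_example_batch. Parsing moves before cache() so the cache holds parsed examples, and shuffle() runs after cache() so the order still changes every epoch.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical schema: one float feature per record.
FEATURES = {"x": tf.io.FixedLenFeature([], tf.float32)}

def parse_example(serialized):
    # Per-example parsing, so it can run before cache() and batch().
    return tf.io.parse_single_example(serialized, FEATURES)

def build_dataset(file_pattern, batch_size):
    return (
        tf.data.Dataset.list_files(file_pattern)
        .interleave(tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
        .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()                    # parsed examples, kept after the first full epoch
        .shuffle(buffer_size=2048)  # reshuffled every epoch, reading from the cache
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

def write_demo_files(directory, num_files=2, per_file=5):
    # Writes small local TFRecord files so the sketch runs without network storage.
    for f in range(num_files):
        path = os.path.join(directory, f"part-{f}.tfrecord")
        with tf.io.TFRecordWriter(path) as writer:
            for i in range(per_file):
                value = float(f * per_file + i)
                example = tf.train.Example(features=tf.train.Features(feature={
                    "x": tf.train.Feature(float_list=tf.train.FloatList(value=[value])),
                }))
                writer.write(example.SerializeToString())

data_dir = tempfile.mkdtemp()
write_demo_files(data_dir)
ds = build_dataset(os.path.join(data_dir, "*.tfrecord"), batch_size=4)

for epoch in range(2):
    shapes = [tuple(batch["x"].shape) for batch in ds]
    print(f"epoch {epoch}: batch shapes {shapes}")
```

With 10 records and drop_remainder=True, each epoch yields two batches of 4; the remaining 2 records are dropped.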

EDIT

"Filling up shuffle buffer" is a message you get when using the shuffle() function. You can still shuffle() the dataset after using cache(). Look here. Also, if I understood correctly, you are feeding the dataset produced by map() to your model for training; in that case you should cache() that dataset, not the earlier one, because training is done on it. To count the number of elements in your dataset, you can use the following code:

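# `dataset` is the tf.data.Dataset to measure; this is a full pass over the data
num_elements = 0
for element in dataset:  # each element of the tf.data.Dataset
    num_elements += 1
print('Total number of elements in the dataset:', num_elements)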
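The same count can also be done inside tf.data with Dataset.reduce instead of a Python loop. A minimal sketch, assuming TensorFlow 2.4 or later (for Dataset.cardinality() and the tf.data cardinality constants); count_elements is a hypothetical helper name. cardinality() reports UNKNOWN for file-based pipelines such as the question's one, so in that case a full pass is still required. If the dataset passed in is already .cache()d, this counting pass should also fill the cache.

```python
import tensorflow as tf

def count_elements(dataset):
    """Number of elements in `dataset`, doing a full pass only when needed."""
    n = int(dataset.cardinality())
    if n == tf.data.INFINITE_CARDINALITY:
        raise ValueError("dataset is infinite (for example after .repeat())")
    if n != tf.data.UNKNOWN_CARDINALITY:
        return n  # size known statically, nothing is read
    # Unknown size (e.g. TFRecord files, filter()): count in one pass inside tf.data.
    return int(dataset.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1))

print(count_elements(tf.data.Dataset.range(10)))                                   # → 10
print(count_elements(tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)))      # → 5
```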

Dividing num_elements by your batch_size then gives you steps_per_epoch (if you counted the dataset after batch(), num_elements is already the number of batches).
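Because the question's pipeline uses batch(..., drop_remainder=True), the division should round down; without drop_remainder it would round up, since the last partial batch is kept. A small sketch (the helper name and the 10 000-element count are made up for illustration):

```python
def steps_per_epoch(num_elements, batch_size, drop_remainder):
    """Batches per epoch for a dataset of `num_elements` unbatched elements."""
    if drop_remainder:
        return num_elements // batch_size   # the last partial batch is dropped
    return -(-num_elements // batch_size)   # ceiling: the partial batch is kept

print(steps_per_epoch(10_000, 2048, drop_remainder=True))   # → 4
print(steps_per_epoch(10_000, 2048, drop_remainder=False))  # → 5
```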

Good news! The final v2.0.0 release fixes this behavior.

Here is a code snippet to highlight the different behaviors.

import time

import tensorflow as tf
import tensorflow.keras as keras

# Simple layer that just prints its inputs
class Print(keras.layers.Layer):

    def compute_output_signature(self, input_signature):
        return input_signature

    def call(self, inputs, **kwargs):
        tf.print(inputs)
        return inputs

# Generator returning an incremented value each time it is re-initialized
generator_list = [0]
def generator():
    v = generator_list[-1]
    generator_list.append(v + 1)
    tf.print("Generating samples with value {}".format(v))
    time.sleep(2)  # simulate a slow read from network storage
    for _ in range(2):
        yield (tf.constant([v]), tf.constant(v))


def main():
    model_input = keras.layers.Input(shape=(1,))
    model_output = Print()(model_input)
    model = keras.Model(inputs=model_input, outputs=model_output)
    model.compile("adam", loss="mae")

    ds = tf.data.Dataset.from_generator(
        generator, (tf.int64, tf.int64), ([1], [])
    )
    cached_ds = ds.cache()

    tf.print("Fit")
    model.fit(
        cached_ds,
        epochs=3,
        verbose=2
    )

    tf.print("For ... in ...")
    for _ in range(3):
        for x, y in cached_ds:
            model(x)

if __name__ == '__main__':
    main()

With TensorFlow 2.0.0-b1 (used on Google AI Platform), here is the output:

Fit
Epoch 1/3
Generating samples with value 0
# sleep 2s
2019-10-03 15:45:32.718522: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1483] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU.  To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
[[0]]
[[0]]
2/2 - 2s - loss: 0.0000e+00
Generating samples with value 1
# sleep 2s
Epoch 2/3
[[1]]
[[1]]
2/2 - 2s - loss: 0.0000e+00
Epoch 3/3
2019-10-03 15:45:34.774195: W tensorflow/core/kernels/data/cache_dataset_ops.cc:815] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Generating samples with value 2
# sleep 2s
[[2]]
[[2]]
2019-10-03 15:45:36.782046: W tensorflow/core/kernels/data/cache_dataset_ops.cc:815] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2/2 - 2s - loss: 0.0000e+00
For ... in ...
Generating samples with value 3
# sleep 2s
[3]
[3]
Generating samples with value 4
# sleep 2s
[4]
[4]
Generating samples with value 5
# sleep 2s
[5]
[5]

You can see that the value of the tensor is incremented at each epoch and that the sleep instruction is executed each time, so the cache is never reused. Moreover, we get the warning about the truncated iterator.

Now, with TensorFlow 2.0.0:

Fit
Epoch 1/3
WARNING:tensorflow:The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.
Generating samples with value 0
# sleep 2s
[[0]]
[[0]]
2019-10-03 15:49:59.587796: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
     [[{{node IteratorGetNext}}]]
2/2 - 2s - loss: 0.0000e+00
Epoch 2/3
[[0]]
[[0]]
2019-10-03 15:49:59.598144: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
     [[{{node IteratorGetNext}}]]
2/2 - 0s - loss: 0.0000e+00
Epoch 3/3
[[0]]
[[0]]
2019-10-03 15:49:59.605260: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
     [[{{node IteratorGetNext}}]]
For ... in ...
2/2 - 0s - loss: 0.0000e+00
[0]
[0]
[0]
[0]
[0]
[0]

And voilà! The generator function is executed only once: no more sleeps, and the tensor always has the same value. I still get some warnings about end of sequence, but I can live with them!

Kind regards.
