简体   繁体   English

如何找到 tf.Dataset 的 len()

[英]How to find the len() of a tf.Dataset

I have started using the tf.data.Dataset as a way to load data into keras models, as they appear to be much faster than keras' ImageDataGenerator and much more memory efficient than training on arrays.我已经开始使用tf.data.Dataset作为将数据加载到 keras 模型中的一种方式,因为它们似乎比 keras 的ImageDataGenerator快得多,并且比 ZA3CBC3F9D0CE2F2C15CZDEB1 上的训练效率更高。

One think I can't get my head around is that I can't seem to find a way to access the len() of the dataset.有人认为我无法理解的是,我似乎无法找到访问数据集len()的方法。 Keras' ImageDataGenerator has an attribute called n which I used to use for this purpose. Keras 的ImageDataGenerator有一个名为n的属性,我曾经将其用于此目的。 This makes my code very ugly, as I need to hard-code the length in various parts of the scripy (eg to find out how many iterations an epoch has).这使我的代码非常难看,因为我需要对 scripy 各个部分的长度进行硬编码(例如,找出一个 epoch 有多少次迭代)。

Any ideas I can work around this issue?有什么想法可以解决这个问题吗?

An example script:一个示例脚本:

# Generator
def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_,_) = tf.keras.mnist.load_data()

    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 225.

    y_train = tf.keras.utils.to_categorical(y_train, 10)

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)

    return ds


model = ...  # create a tf.keras model

batch_size 256
gen = make_mnist_train_generator(batch_size)

# Training
model.fit(gen, epochs=50, steps_per_epoch=60000//batch_size+1)  # Hard coded size of generator

tl;dr tl;博士

Unfortunately tf.data.Dataset is a generator and there is no inherent way of finding its size.不幸的是tf.data.Dataset是一个生成器,并没有找到它的大小的固有方法。

But...但...

Generally speaking, when you use .from_tensor_slices() you have a way of knowing its size by the argument you add in this method, in your case x_train .一般来说,当您使用.from_tensor_slices()时,您可以通过在此方法中添加的参数(在您的情况下x_train )了解其大小。 Your only issue is that you are creating it inside a function.您唯一的问题是您在 function 中创建它。

A neat hack you can do to bypass this issue is to add a __len__ attribute on your own: The easiest way I've found that you can do this is:你可以做一个巧妙的技巧来绕过这个问题,就是自己添加一个__len__属性:我发现你可以做到这一点的最简单方法是:

ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})

In your case it would look something like this:在您的情况下,它看起来像这样:

def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_,_) = tf.keras.mnist.load_data()

    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 225.

    y_train = tf.keras.utils.to_categorical(y_train, 10)

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)

    ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})

    return ds


gen = make_mnist_train_generator(batch_size)

model.fit(gen, epochs=50, steps_per_epoch=len(gen)//batch_size+1)  # Hard coded size of generator

Why do this?为什么要这样做?

I've done this in the past and its surprisingly useful.我过去做过这个,它非常有用。 There are many reasons why you'd want your generator to have a len() .您希望生成器具有len()的原因有很多。 Some examples are:一些例子是:

  • if you want to have the generator in a separate module and import it如果您想将生成器放在单独的模块中并导入它
  • if the generator is meant to be used by someone else that doesn't know what data was used to create it如果生成器打算由不知道使用什么数据创建它的其他人使用

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM