
How to use the iterator 'make_initializable_iterator' of TensorFlow within an 'input_fn'?

I want to train my model with a tf.estimator.Estimator and load my data with the Dataset API. Because my data (for example, MNIST) is an array (tensor), I try to load it with 'tf.data.Dataset.from_tensor_slices'. But I don't know how to initialize 'make_initializable_iterator' within an 'input_fn'.

I can train successfully with 'make_one_shot_iterator', but the data loads slowly before training starts. The article "Higher-Level APIs in TensorFlow" is a good example of using 'make_initializable_iterator' within an 'input_fn', but it needs to return an 'iterator_initializer_hook' from 'input_fn' to another function. Is there a better or more elegant way?

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API
    from tensorflow.examples.tutorials.mnist import input_data

    def input_fn():
        mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
        images = mnist_data.train.images.reshape([-1, 28, 28, 1])
        labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

        # Build dataset iterator
        dataset = tf.data.Dataset.from_tensor_slices((images, labels))
        dataset = dataset.repeat(None)  # Infinite iterations
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(100)
        iterator = dataset.make_one_shot_iterator()
        next_example = iterator.get_next()
        # TODO: set a run hook to initialize the iterator when
        # switching to make_initializable_iterator()
        return next_example
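For comparison, the hook-based pattern the question refers to can be sketched as follows. This is a minimal sketch, not the article's exact code: the names `IteratorInitializerHook`, `get_input_fn` and `iterator_initializer_func` are illustrative, and it assumes the `tf.compat.v1` API (TensorFlow 1.13+, or TF 2.x in graph mode). Feeding the arrays through placeholders keeps them out of the graph definition; calling `from_tensor_slices` on plain NumPy arrays embeds them as constants, which is a likely cause of the slow start.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API: TF 1.13+, or TF 2.x inside a tf.Graph


class IteratorInitializerHook(tf1.train.SessionRunHook):
    """Illustrative hook: runs a stored initializer once the session exists."""

    def __init__(self):
        super().__init__()
        self.iterator_initializer_func = None  # set later by input_fn

    def after_create_session(self, session, coord):
        # Called after the session is created and before the first step,
        # so the iterator is initialized before any batch is requested.
        self.iterator_initializer_func(session)


def get_input_fn(images, labels, batch_size=100):
    """Returns (input_fn, hook); the hook must be passed to Estimator.train."""
    hook = IteratorInitializerHook()

    def input_fn():
        # Placeholders keep the NumPy arrays out of the GraphDef.
        images_ph = tf1.placeholder(images.dtype, images.shape)
        labels_ph = tf1.placeholder(labels.dtype, labels.shape)
        dataset = tf.data.Dataset.from_tensor_slices((images_ph, labels_ph))
        dataset = dataset.repeat(None)  # infinite iterations
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(batch_size)
        iterator = tf1.data.make_initializable_iterator(dataset)
        # Feed the real arrays when the hook runs the initializer.
        hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer,
            feed_dict={images_ph: images, labels_ph: labels})
        return iterator.get_next()

    return input_fn, hook
```

With an Estimator, both pieces travel together, e.g. `train_input_fn, train_hook = get_input_fn(images, labels)` followed by `estimator.train(input_fn=train_input_fn, hooks=[train_hook], steps=1000)` (the step count is arbitrary). Having to carry the hook alongside `input_fn` is exactly the extra plumbing the question would like to avoid.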

In TensorFlow version 1.5 and later, tf.estimator.Estimator automatically creates and initializes an initializable iterator when you return a tf.data.Dataset from your input_fn. This lets you write the following code without having to worry about initialization or hooks:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.5+
    from tensorflow.examples.tutorials.mnist import input_data

    def input_fn():
        mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
        images = mnist_data.train.images.reshape([-1, 28, 28, 1])
        labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

        # Build dataset; the Estimator creates and initializes the iterator.
        dataset = tf.data.Dataset.from_tensor_slices((images, labels))
        dataset = dataset.repeat(None)  # Infinite iterations
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(100)
        return dataset

Another option: inside your code, add this:

    # `iterator` is the initializable iterator created in your input pipeline
    self.hooks.append(utils_hooks.DatasetHook(iterator))

In run_loop.py, before the call into your function, add this:

    for hook in dataset_hooks:
        sess.run(hook.iterator().initializer)
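The snippets above rely on a `DatasetHook` from a `utils_hooks` module that isn't shown, so the following is a hypothetical minimal stand-in: `DatasetHook` only stores an iterator and exposes it through `iterator()`, and `initialize_datasets` (also hypothetical) wraps the `run_loop.py` loop. It assumes the `tf.compat.v1` graph-mode API.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API: TF 1.13+, or TF 2.x inside a tf.Graph


class DatasetHook(tf1.train.SessionRunHook):
    """Hypothetical stand-in for utils_hooks.DatasetHook: it just remembers
    an initializable iterator so the run loop can initialize it later."""

    def __init__(self, iterator):
        super().__init__()
        self._iterator = iterator

    def iterator(self):
        return self._iterator


def initialize_datasets(sess, dataset_hooks):
    # Same loop as in run_loop.py: initialize every registered iterator
    # before the training function starts pulling batches.
    for hook in dataset_hooks:
        sess.run(hook.iterator().initializer)


if __name__ == "__main__":
    with tf.Graph().as_default():
        iterator = tf1.data.make_initializable_iterator(
            tf.data.Dataset.range(5).batch(2))
        dataset_hooks = [DatasetHook(iterator)]
        next_batch = iterator.get_next()
        with tf1.Session() as sess:
            initialize_datasets(sess, dataset_hooks)
            print(sess.run(next_batch))  # prints [0 1]
```

Subclassing `SessionRunHook` keeps the object harmless if it also ends up in a list of hooks handed to a monitored session, since every hook method defaults to a no-op; the actual initialization happens only in the explicit loop.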

Then it should work.
