
Manually Get Next Batch or Use Identical Batch with TensorFlow Data API

I am trying to use the tf.data API to accelerate my code and prevent GPU data starvation, but one thing is stopping me from being comfortable with it: the ability to use the same batch when calling the training op multiple times.

Suppose I have my dataset set up as follows:

dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()

loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)

But now whenever I run train_op1, iterator.get_next() is called, so when I then call train_op2, I am training on the next batch.

From this question, I am aware that I can use a combination of flat_map and repeat(n), where n is the number of times I want to repeat the same batch (sketched below). But this n would depend on the number of train_ops that I call, which I would have to count manually. Also, I need these two train_ops because they optimize different parts of my graph.
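
For reference, here is a minimal, self-contained sketch of that flat_map + repeat(n) pattern (a toy from_tensor_slices dataset stands in for the text file above, and n = 2 is an assumption matching the two train_ops):

import tensorflow as tf

n = 2           # assumption: two train_ops, so each batch should be emitted twice
batch_size = 2

dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])  # toy data
dataset = dataset.batch(batch_size)
# wrap every batch in a tiny dataset that repeats it n times, then flatten
dataset = dataset.flat_map(
    lambda batch: tf.data.Dataset.from_tensors(batch).repeat(n))
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()

with tf.Session() as sess:
    for _ in range(4):
        print(sess.run(x_batch))  # [1. 2.], [1. 2.], [3. 4.], [3. 4.]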

Thank you for your help!

Try the code below. It creates copies of the input and target, so hopefully they will not change when you switch optimizer/loss op. They persist between sess.run calls as long as you do not pass the is_new: True flag.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf


def ds_train(batch_size, num_epochs):  
    ds = (tf.data.Dataset.from_tensor_slices(([1.0,2.0,3.0,4.0,5.0], [-1,-2,-3,-4,-5]))
            .batch(batch_size)
            .repeat(num_epochs)        
            )
    return ds


batch_size = 1
input_size = 1
num_epochs = 2

with tf.variable_scope("dataset"):       
    ds_t = ds_train(batch_size, num_epochs)

with tf.variable_scope("iterator"):
    iterator_t = ds_t.make_initializable_iterator()
    iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                                   iterator_t.output_types,
                                                   iterator_t.output_shapes)

    def next_item():
        # pull the next (x, y) pair from whichever iterator the handle points to
        next_elem = iterator.get_next(name="next_element")
        x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
        return x, y


# Persistent, non-trainable buffers that cache the current batch between sess.run calls.
inputs = tf.Variable(tf.zeros(shape=[batch_size, input_size]), dtype=tf.float32,
                     name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32), dtype=tf.int32,
                     name="target", trainable=False, use_resource=True)
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")

def new_data(batch_size, input_size):
    # run the data layer to generate a new batch
    next_inputs, next_target = next_item()
    next_inputs = tf.reshape(next_inputs, shape=[batch_size, input_size])
    with tf.control_dependencies([tf.assign(inputs, next_inputs), tf.assign(target, next_target)]):
        return tf.identity(inputs), tf.identity(target)

def old_data():
    # just forward the existing batch
    return inputs, target

# Fetch a fresh batch only when is_new is True; otherwise reuse the cached one.
inputs, target = tf.cond(is_new, lambda: new_data(batch_size, input_size), old_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    while True:
        try:
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: True}))
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break        
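
With batch_size = 1 and num_epochs = 2, the two is_new: False fetches in each loop iteration should return the cached batch unchanged (all zeros before the first real fetch), while only the is_new: True fetch advances the iterator; once all ten batches have been consumed, the next advance raises OutOfRangeError and the loop exits.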
