I am trying to use the tf.data API to accelerate my code and prevent GPU data starvation, but one thing keeps me from being comfortable with it: the ability to use the same batch when calling the training op multiple times.
Suppose I have my dataset set up as follows:
# Asker's setup (illustrative only: `...`, dataset_size, batch_size and the
# someFunctionOf / someOptimizerOf helpers are placeholders from the question).
dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
# repeat() with no count cycles through the data indefinitely.
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
# Both losses are built directly on the get_next() tensor, so (as the question
# describes) each separate sess.run of a train op pulls a new batch.
x_batch = iterator.get_next()
loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)
But now, whenever I call `train_op1`, `iterator.get_next()` is called, so when I then call `train_op2` I am training on the next batch.
From this question, I am aware that I can use a combination of `flat_map` and `repeat(n)`, where `n` is the number of times I want to repeat the same batch. However, this `n` would depend on the number of train ops that I call, which I would have to count manually. Also, I need these two train ops because they optimize different parts of my graph.
Thank you for your help!
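Try the code below. It stores a copy of the input and target batch in non-trainable variables, so they should not change when you switch between optimizers/loss ops. The cached values persist between `sess.run` calls as long as you do not pass the `is_new: True` flag; feeding `is_new: True` loads the next batch from the iterator into the cache.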
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def ds_train(batch_size, num_epochs):
    """Build a toy training dataset of five (feature, label) pairs.

    Features are the floats 1.0..5.0 and labels the ints -1..-5. Elements are
    grouped into batches of ``batch_size`` and the whole sequence is repeated
    ``num_epochs`` times.

    NOTE: ``batch`` keeps a trailing partial batch, so if ``batch_size`` does not
    divide 5 the last batch of each epoch has fewer than ``batch_size`` rows.
    """
    features = [1.0, 2.0, 3.0, 4.0, 5.0]
    labels = [-1, -2, -3, -4, -5]
    batched = tf.data.Dataset.from_tensor_slices((features, labels)).batch(batch_size)
    return batched.repeat(num_epochs)
# Toy hyper-parameters for the demo.
batch_size = 1
input_size = 1  # features per example; each batch is reshaped to [batch_size, input_size]
num_epochs = 2

with tf.variable_scope("dataset"):
    ds_t = ds_train(batch_size, num_epochs)

with tf.variable_scope("iterator"):
    # Concrete iterator over the training data; it must be initialised in the
    # session and its string handle is what gets fed to `iterator_handle`.
    iterator_t = ds_t.make_initializable_iterator()
    # Feedable iterator: which underlying iterator get_next() reads from is
    # selected at run time by the string handle fed into this placeholder.
    iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf.data.Iterator.from_string_handle(
        iterator_handle,
        iterator_t.output_types,
        iterator_t.output_shapes,
    )
def next_item():
    """Pull the next (features, labels) element from the feedable ``iterator``.

    Features are cast to float32; labels are returned with the dtype the
    dataset produced.
    """
    features, labels = iterator.get_next(name="next_element")
    features = tf.cast(features, tf.float32)
    return features, labels
# Cache holding the "current" batch. Both variables are non-trainable and start
# as zeros until a batch is loaded into them.
inputs = tf.Variable(
    tf.zeros([batch_size, input_size]),
    dtype=tf.float32,
    name="inputs",
    trainable=False,
    use_resource=True,
)
target = tf.Variable(
    tf.zeros([batch_size], tf.int32),
    dtype=tf.int32,
    name="target",
    trainable=False,
    use_resource=True,
)
# Feed True to load a new batch into the cache; the default False reuses it.
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")
def new_data(batch_size, input_size):
    """Load the next batch into the cache variables and return the cached values.

    Args:
        batch_size: leading dimension the fetched features are reshaped to.
        input_size: trailing dimension the fetched features are reshaped to.

    Returns:
        (inputs, target) read back from the cache variables after both
        assignments have run.
    """
    fresh_x, fresh_y = next_item()
    fresh_x = tf.reshape(fresh_x, shape=[batch_size, input_size])
    writes = [tf.assign(inputs, fresh_x), tf.assign(target, fresh_y)]
    # Reads created inside this context are ordered after both writes, so the
    # returned tensors reflect the freshly loaded batch.
    with tf.control_dependencies(writes):
        cached_x = tf.identity(inputs)
        cached_y = tf.identity(target)
    return cached_x, cached_y
def old_data():
    """Return the cached (inputs, target) variables without touching the iterator."""
    cached = (inputs, target)
    return cached
# Per sess.run, either load a fresh batch (is_new=True) or reuse the cached one.
# Ops built inside the tf.cond branch functions run only for the taken branch,
# so iterator.get_next() executes only when is_new is True.
# (A stray top-level next_item() call used to sit here; it built an extra
# get_next op that nothing fetched, so it has been removed.)
# NOTE: this rebinds `inputs`/`target` from the cache variables to the cond
# outputs; new_data/old_data read those globals when tf.cond calls them here,
# i.e. before the rebinding takes effect.
inputs, target = tf.cond(is_new, lambda: new_data(batch_size, input_size), old_data)
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    # NOTE(review): the handle is fed on every run, including is_new=False runs;
    # presumably the cond still needs the placeholder value -- confirm before
    # dropping it from those feeds.
    while True:
        # Load a fresh batch first, so no step ever sees the zero-initialised
        # cache. Only this run touches the iterator, so only it can hit the end
        # of the dataset.
        try:
            print(sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_new: True}))
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break
        # Reuse the same cached batch for follow-up runs (e.g. train_op1, then
        # train_op2) by leaving is_new False.
        for _ in range(2):
            print(sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_new: False}))
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.