Using the tf.data API in training doesn't fetch the whole dataset

I use tf.data to fetch images, labels, and edges during GPU training, but I find the Dataset API does not load all of the data. My code:

def get_dataset(filenames, shuffle_buffer, repeat_times, batch_size):
    dataset = tf.data.TFRecordDataset([filenames])
    dataset = dataset.map(tfrecord_preprocess)
    if repeat_times is None:
        dataset = dataset.repeat()
    else:
        dataset = dataset.repeat(repeat_times)
    dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)
    return dataset

def tfrecord_preprocess(example):
    feature = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
               "label": tf.FixedLenFeature((), tf.string, default_value=""),
               "edge": tf.FixedLenFeature((), tf.string, default_value="")}
    parsed_feature = tf.parse_single_example(example, feature)

    image = tf.decode_raw(parsed_feature["image"], out_type=tf.uint8)
    label = tf.decode_raw(parsed_feature["label"], out_type=tf.uint8)
    edge = tf.decode_raw(parsed_feature["edge"], out_type=tf.uint8)

    image = tf.cast(tf.reshape(image, shape=[1, 128, 128]), tf.float32) 
    label = tf.cast(tf.reshape(label, shape=[1, 128, 128]), tf.float32)
    edge = tf.cast(tf.reshape(edge, shape=[128, 128]), tf.float32)
    return image, label, edge
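For reference, a record matching this parse spec could be produced like this (a minimal writer sketch; the file name and the all-zero dummy arrays are placeholders, not the real data):

import numpy as np
import tensorflow as tf

def _bytes_feature(value):
    # Wrap raw bytes in a tf.train.Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# One record per (image, label, edge) triple of uint8 arrays shaped
# (128, 128), stored as raw bytes so tf.decode_raw can recover them.
with tf.python_io.TFRecordWriter("train_output.tfrecords") as writer:
    image = np.zeros((128, 128), dtype=np.uint8)  # placeholder data
    label = np.zeros((128, 128), dtype=np.uint8)
    edge = np.zeros((128, 128), dtype=np.uint8)
    example = tf.train.Example(features=tf.train.Features(feature={
        "image": _bytes_feature(image.tostring()),
        "label": _bytes_feature(label.tostring()),
        "edge": _bytes_feature(edge.tostring()),
    }))
    writer.write(example.SerializeToString())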

I wrote a simple script to test the API:

dataset = get_dataset(filenames, shuffle_buffer, repeat_times, batchsize)
#shuffle=1000, repeat_times=2, batchsize=13
iterator = dataset.make_one_shot_iterator()
images, labels, edges = iterator.get_next()
count = 0
with tf.Session() as sess:
    for _ in xrange(40):
        try:
            edges_value = sess.run(edges)
            count = count+len(edges_value)
            print count
        except tf.errors.OutOfRangeError:
            break

The dataset has 260 records, so after repeat (×2) and batching (batch size 13) there should be 260 × 2 / 13 = 40 batches. It works.

However, when I use similar code for training, the total number of records is less than 260: only 140 (tracked via the variable count). Does anyone know how to solve this problem? Please help me.

I am using tensorflow-gpu 1.4.

My training code is:

shuffle_buffer = params["shuffle_buffer"] #1000
repeat_times = params["repeat_times"] #1
batch_size = params["batch_size"] #26
num_classes = params["num_classes"] #2

dataset = model.get_dataset(filenames, shuffle_buffer, repeat_times, batch_size)
iterator = dataset.make_one_shot_iterator()
with tf.device('/gpu:1'): 
    global_step = tf.train.get_or_create_global_step()    
    learning_rate = tf.train.exponential_decay(params["learning_rate"], 
                                           global_step, 100, 0.99)    
    optimizer = tf.train.AdamOptimizer(learning_rate)

    images, labels, edges = iterator.get_next()
    _, probs = model.inference(features=images, training=True)
    loss, reg = model.get_loss(probs, labels, edges, num_classes)
    _, acc_mean, _ = model.get_acc(probs, labels)

    train_op = optimizer.minimize(loss, global_step=global_step)

    variables_average = tf.train.ExponentialMovingAverage(0.99, global_step)
    var_list = tf.trainable_variables(scope='.*(kernel|bias)')
    variables_average_op = variables_average.apply(var_list)    

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        train_all_op = tf.group(train_op, variables_average_op)

tf.summary.scalar("loss", loss)
tf.summary.scalar("reg", reg)
tf.summary.scalar("acc_mean", acc_mean)

merged = tf.summary.merge_all()
saver = tf.train.Saver(max_to_keep=5)

config = tf.ConfigProto(log_device_placement=True,
                        allow_soft_placement=True)
config.gpu_options.allow_growth = True
count = 0
with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run() 
    writer = tf.summary.FileWriter('./train', sess.graph)
    for _ in xrange(10):
        try:
            edges_value = sess.run(edges)
            count = count+len(edges_value)
            _, step, summary = sess.run([train_all_op, global_step, merged])
            writer.add_summary(summary, step)
            if step % 5 == 0:
                loss_value = sess.run(loss)
                print loss_value
                acc_mean_value = sess.run(acc_mean)
                print acc_mean_value
                saver.save(sess, params["save_dir"], step)
        except tf.errors.OutOfRangeError:
            print "end of data"
            break
    print count
    print "the final step is %d" % step
    loss_value = sess.run(loss)
    print loss_value
    acc_mean_value = sess.run(acc_mean)
    print acc_mean_value
    saver.save(sess, params["save_dir"], step)    
    writer.close()

Finally I got this output in the terminal:

end of data
130
the final step is 5

To test the code I set the repeat times to 1.

But when I run this test code:

def test():
    dataset = get_dataset("train_output.tfrecords", 1000, 1, 26)
    iterator = dataset.make_one_shot_iterator()
    images, labels, edges = iterator.get_next()

    count = 0
    with tf.Session() as sess:
        for i in xrange(10):
            try:
                images_value, labels_value, edges_value = sess.run([images, labels, edges])
                count = count+len(edges_value)
            except tf.errors.OutOfRangeError:
                print "end of data"

        print count
        print i

test()

The terminal shows:

260
9

The problem is that sess.run(edges) causes this part of the graph to execute again: images, labels, edges = iterator.get_next(). Every sess.run of a tensor that depends on get_next() pulls a fresh batch, so each additional call (sess.run([train_all_op, global_step, merged]), sess.run(loss), sess.run(acc_mean)) consumes a batch that is never added to your counter.
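A minimal stand-alone sketch of the effect (hypothetical toy dataset; only the consumption pattern matters): two separate sess.run calls pull two different batches, while fetching several tensors in a single sess.run pulls only one.

import tensorflow as tf

dataset = tf.data.Dataset.range(6).batch(2)  # three batches: [0 1], [2 3], [4 5]
batch = dataset.make_one_shot_iterator().get_next()
doubled = batch * 2  # another tensor that depends on get_next()

with tf.Session() as sess:
    print sess.run(batch)             # [0 1]  -> first batch consumed
    print sess.run(doubled)           # [4 6]  -> a second batch consumed
    print sess.run([batch, doubled])  # both fetches share one batch: [4 5] and [8 10]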

To count the edges, keep a counter inside the with tf.device('/gpu:1') block. You can even graph it in TensorBoard using tf.summary.scalar, similar to how you do with loss.

Declare a counter variable (initialized to 0 and excluded from training):

edges_count = tf.Variable(0, name='edges_count', trainable=False, dtype=tf.int32)

images, labels, edges = iterator.get_next()
edges_count_update_op = tf.assign_add(edges_count, tf.shape(edges)[0])  # len() does not work on symbolic tensors

Then add edges_count_update_op to your train_op group.
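Put together, the relevant part of the training graph would look roughly like this (a sketch reusing the names from the question; the model-specific pieces are unchanged):

with tf.device('/gpu:1'):
    images, labels, edges = iterator.get_next()

    # In-graph counter, incremented by the batch size on every training step.
    edges_count = tf.Variable(0, name='edges_count', trainable=False, dtype=tf.int32)
    edges_count_update_op = tf.assign_add(edges_count, tf.shape(edges)[0])

    train_op = optimizer.minimize(loss, global_step=global_step)
    with tf.control_dependencies(update_ops):
        train_all_op = tf.group(train_op, variables_average_op, edges_count_update_op)

tf.summary.scalar("edges_count", edges_count)

At the end of training, read the total with sess.run(edges_count) instead of fetching edges separately, so no extra batches are consumed.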
