TensorFlow datasets usage

I'm trying to create a simple dataset of string-label pairs, but I can't get TensorFlow to keep each string paired with its label.

I'm using the Dataset.from_tensor_slices initializer and a one-shot iterator from dataset.make_one_shot_iterator():

import tensorflow as tf

strings = [
    'aaaa',
    'asdf'
]
labels = [1,0]

sess = tf.Session()
tf.global_variables_initializer()

dataset = tf.data.Dataset.from_tensor_slices((strings, labels))
dataset = dataset.repeat()
dataset = dataset.shuffle(512)
iterator = dataset.make_one_shot_iterator()


x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

I expect the label to be 1 for 'aaaa' and 0 for 'asdf', but the pairings keep coming out seemingly at random:

aaaa 0
asdf 0
aaaa 1
asdf 1
aaaa 1
aaaa 0
asdf 1
aaaa 1

What might be wrong with my code?

Also, if I remove the shuffle, I never get to the other string; the iterator only outputs:

aaaa 0
aaaa 0 
aaaa 0
...

and the labels are wrong. Does anyone know the reason behind that?
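The likely reason for the repeated aaaa 0: every sess.run() on a tensor returned by get_next() pulls a new element from the iterator, so sess.run(x_next) and sess.run(label_next) read two consecutive elements instead of one. The sketch below is a plain-Python analogy, not TensorFlow code. It assumes itertools.cycle as a stand-in for dataset.repeat() without shuffle, and buggy_step is a hypothetical helper that mimics the two separate run calls.

```python
import itertools

strings = ['aaaa', 'asdf']
labels = [1, 0]

# Stand-in (assumption) for Dataset.from_tensor_slices(...).repeat() with no
# shuffle: an endless stream of (string, label) pairs in their original order.
stream = itertools.cycle(zip(strings, labels))

def buggy_step():
    # Mimics sess.run(x_next) followed by sess.run(label_next): each call
    # consumes a whole element, so the string comes from one element and
    # the label from the next one.
    x, _ = next(stream)
    _, label = next(stream)
    return x, label

results = [buggy_step() for _ in range(3)]
print(results)  # [('aaaa', 0), ('aaaa', 0), ('aaaa', 0)]
```

Because the stream alternates aaaa/asdf, the string is always read from an 'aaaa' element and the label from the following 'asdf' element, which matches the aaaa 0 output above. With shuffle enabled, the two consecutive elements are random, so the pairings look random too.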

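This is how I use it. Call get_next() once and pass the whole (string, label) tuple to a single sess.run(), so every step returns a matching pair: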

next_element = iterator.get_next()  # a (string, label) tuple of tensors

with tf.Session() as sess:
    for _ in range(8):
        # One run() call fetches both components of the same element,
        # so the string and its label always stay together.
        print(sess.run(next_element))


(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'aaaa', 1)
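The same idea as a plain-Python analogy (not TensorFlow code): draw one element per step and unpack it, rather than drawing the string and the label separately. The helper sample_pairs is hypothetical, and picking a random pair each step is only a rough stand-in, assumed for illustration, for .repeat().shuffle(512).

```python
import random

strings = ['aaaa', 'asdf']
labels = [1, 0]
expected = dict(zip(strings, labels))

def sample_pairs(seed, n):
    # Rough stand-in (assumption) for .repeat().shuffle(...): each step yields
    # a randomly chosen (string, label) pair. Shuffling whole pairs never
    # separates a string from its label.
    rng = random.Random(seed)
    pairs = list(zip(strings, labels))
    return [rng.choice(pairs) for _ in range(n)]

# Fixed pattern: one draw per step, the analogue of sess.run(next_element).
for x, label in sample_pairs(seed=0, n=8):
    print(x, label)
```

Since each step consumes exactly one element, 'aaaa' is always printed with 1 and 'asdf' with 0, whatever order the shuffle produces.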
