Why does this TensorFlow code behave differently when inside a test case?

I have a function (foo below) that behaves differently when it is run directly versus when it is run inside a tf.test.TestCase.

The code is supposed to create a dataset with the elements 0 through 4 and shuffle it. It then repeats the following three times: create an iterator over the dataset and use it to print the 5 elements.

When run on its own, it produces output in which each list is shuffled differently, e.g.:

[4, 0, 3, 2, 1]
[0, 2, 1, 3, 4]
[2, 3, 4, 0, 1]

but when run inside a test case they are always the same, even between runs:

[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]

I imagine it's something to do with how test cases handle random seeds but I can't see anything about that in the TensorFlow docs. Thanks for any help!


Code:

import tensorflow as tf

def foo():
    sess = tf.Session()
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.shuffle(5, reshuffle_each_iteration=False)

    for _ in range(3):
        data_iter = dataset.make_one_shot_iterator()
        next_item = data_iter.get_next()
        with sess.as_default():
            data_new = [next_item.eval() for _ in range(5)]
        print(data_new)


class DatasetTest(tf.test.TestCase):
    def testDataset(self):
        foo()

if __name__ == '__main__':
    foo()
    tf.test.main()

I am running it with Python 3.6 and TensorFlow 1.4. No other modules should be needed.

I think you are right: tf.test.TestCase is set up to use a fixed seed. Its setUp method seeds Python's random module, NumPy, and the default graph with the same constant:

class TensorFlowTestCase(googletest.TestCase):
  # ...
  def setUp(self):
    self._ClearCachedSession()
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    ops.reset_default_graph()
    ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED

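DEFAULT_GRAPH_SEED is defined as 87654321 in tensorflow/tensorflow/python/framework/random_seed.py. Since the shuffle in foo is not given a seed of its own, it presumably picks up this graph-level seed inside the test case, which would explain why every iterator yields the same order on every run.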
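
To illustrate the effect of a fixed seed without needing TensorFlow at all, here is a minimal sketch using only Python's standard random module. It is an analogy, not TensorFlow's actual seeding logic: shuffled_range is a hypothetical helper, and the only value taken from the source is the constant 87654321.

```python
import random

# Value quoted above from random_seed.py; everything else here is illustrative.
DEFAULT_GRAPH_SEED = 87654321


def shuffled_range(n, seed=None):
    """Return the integers 0..n-1 shuffled by a private RNG.

    With seed=None the RNG is seeded from OS entropy (or the clock),
    so the order generally changes from call to call.
    """
    rng = random.Random(seed)
    items = list(range(n))
    rng.shuffle(items)
    return items


# Re-seeding with the same constant before each shuffle, loosely analogous
# to what happens inside the test case: all three lists are identical.
seeded = [shuffled_range(5, DEFAULT_GRAPH_SEED) for _ in range(3)]

# No seed, loosely analogous to calling foo() directly: the lists usually differ.
unseeded = [shuffled_range(5) for _ in range(3)]

print(seeded)
print(unseeded)
```

The seeded lists stay the same across separate runs of the script as well, which mirrors the "even between runs" behaviour described in the question.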
