[英]Why does this TensorFlow code behave differently when inside a test case?
I have a function ( foo
below) which is behaving differently when it's run directly vs when it is run inside a tf.test.TestCase
. 我有一个函数(
foo
如下),当它在tf.test.TestCase
运行时直接运行时表现不同。
The code is supposed to create a dataset with elems [1..5] and shuffle it. 代码应该用elems [1..5]创建一个数据集并将其洗牌。 Then it repeats 3 times: create an iterator from the data and use that to print the 5 elements.
然后重复3次:从数据创建一个迭代器,并使用它来打印5个元素。
When run on its own it gives output where all the lists are shuffled eg: 当它自己运行时,它会输出所有列表被洗牌的输出,例如:
[4, 0, 3, 2, 1]
[0, 2, 1, 3, 4]
[2, 3, 4, 0, 1]
but when run inside a test case they are always the same, even between runs: 但是当在测试用例中运行时,它们总是相同的,即使在运行之间:
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
I imagine it's something to do with how test cases handle random seeds but I can't see anything about that in the TensorFlow docs. 我想这与测试用例如何处理随机种子有关,但我在TensorFlow文档中看不到任何相关内容。 Thanks for any help!
谢谢你的帮助!
import tensorflow as tf
def foo():
sess = tf.Session()
dataset = tf.data.Dataset.range(5)
dataset = dataset.shuffle(5, reshuffle_each_iteration=False)
for _ in range(3):
data_iter = dataset.make_one_shot_iterator()
next_item = data_iter.get_next()
with sess.as_default():
data_new = [next_item.eval() for _ in range(5)]
print(data_new)
class DatasetTest(tf.test.TestCase):
def testDataset(self):
foo()
if __name__ == '__main__':
foo()
tf.test.main()
I am running it with Python 3.6 and TensorFlow 1.4. 我使用Python 3.6和TensorFlow 1.4运行它。 No other modules should be needed.
不需要其他模块。
I think you are right; 我想你是对的;
tf.test.TestCase
is being setup to use fixed seed. tf.test.TestCase
正在设置为使用固定种子。
class TensorFlowTestCase(googletest.TestCase):
# ...
def setUp(self):
self._ClearCachedSession()
random.seed(random_seed.DEFAULT_GRAPH_SEED)
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
ops.reset_default_graph()
ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED
and DEFAULT_GRAPH_SEED = 87654321
see this line in tensorflow/tensorflow/python/framework/random_seed.py
. 和
DEFAULT_GRAPH_SEED = 87654321
在tensorflow/tensorflow/python/framework/random_seed.py
查看此行。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.