简体   繁体   English

为什么这个TensorFlow代码在测试用例中表现不同?

[英]Why does this TensorFlow code behave differently when inside a test case?

I have a function ( foo below) which is behaving differently when it's run directly vs when it is run inside a tf.test.TestCase . 我有一个函数( foo如下),当它在tf.test.TestCase运行时直接运行时表现不同。

The code is supposed to create a dataset with elems [1..5] and shuffle it. 代码应该用elems [1..5]创建一个数据集并将其洗牌。 Then it repeats 3 times: create an iterator from the data and use that to print the 5 elements. 然后重复3次:从数据创建一个迭代器,并使用它来打印5个元素。

When run on its own it gives output where all the lists are shuffled eg: 当它自己运行时,它会输出所有列表被洗牌的输出,例如:

[4, 0, 3, 2, 1]
[0, 2, 1, 3, 4]
[2, 3, 4, 0, 1]

but when run inside a test case they are always the same, even between runs: 但是当在测试用例中运行时,它们总是相同的,即使在运行之间:

[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]

I imagine it's something to do with how test cases handle random seeds but I can't see anything about that in the TensorFlow docs. 我想这与测试用例如何处理随机种子有关,但我在TensorFlow文档中看不到任何相关内容。 Thanks for any help! 谢谢你的帮助!


Code: 码:

import tensorflow as tf

def foo():
    sess = tf.Session()
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.shuffle(5, reshuffle_each_iteration=False)

    for _ in range(3):
        data_iter = dataset.make_one_shot_iterator()
        next_item = data_iter.get_next()
        with sess.as_default():
            data_new = [next_item.eval() for _ in range(5)]
        print(data_new)


class DatasetTest(tf.test.TestCase):
    def testDataset(self):
        foo()

if __name__ == '__main__':
    foo()
    tf.test.main()

I am running it with Python 3.6 and TensorFlow 1.4. 我使用Python 3.6和TensorFlow 1.4运行它。 No other modules should be needed. 不需要其他模块。

I think you are right; 我想你是对的; tf.test.TestCase is being setup to use fixed seed. tf.test.TestCase正在设置为使用固定种子。

class TensorFlowTestCase(googletest.TestCase):
# ...
def setUp(self):
  self._ClearCachedSession()
  random.seed(random_seed.DEFAULT_GRAPH_SEED)
  np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
  ops.reset_default_graph()
  ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED

and DEFAULT_GRAPH_SEED = 87654321 see this line in tensorflow/tensorflow/python/framework/random_seed.py . DEFAULT_GRAPH_SEED = 87654321tensorflow/tensorflow/python/framework/random_seed.py查看行。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 为什么绘图函数 plt.show() 在循环内部或外部时表现不同? - Why does the plotting function plt.show() behave differently when inside or outside a loop? 为什么这个argparse代码在Python 2和3之间表现不同? - Why does this argparse code behave differently between Python 2 and 3? Python 2-为什么“ with”在嵌入式C代码中表现不同? - python 2 - why does 'with' behave differently in embedded c code? 为什么在包装时sys.excepthook会有不同的行为? - Why does sys.excepthook behave differently when wrapped? 当类在函数中时,为什么类中的全局行为会有所不同? - why does global in a class behave differently when the class is within a function? 为什么RestrictedPython在与Python 3.6一起使用时表现不同? - Why does RestrictedPython behave differently when used with Python 3.6? 为什么这个上下文管理器与dict理解有不同的表现? - Why does this contextmanager behave differently with dict comprehensions? 为什么 numpy import 的行为不同? - Why does numpy import behave differently? 为什么 pyzmq 订阅者与 asyncio 的行为不同? - Why does pyzmq subscriber behave differently with asyncio? 为什么 groupby 操作的行为不同 - Why does groupby operations behave differently
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM