Tensorflow's random.truncated_normal returns different results with the same seed

The following lines are supposed to get the same result:

import tensorflow as tf

print(tf.random.truncated_normal(shape=[2], seed=1234))
print(tf.random.truncated_normal(shape=[2], seed=1234))

But I got:

tf.Tensor([-0.12297685 -0.76935077], shape=(2,), dtype=float32)
tf.Tensor([0.37034193 1.3367208 ], shape=(2,), dtype=float32)

Why?

TensorFlow has two kinds of seeds, the global seed and the operation-level seed, and it combines the two to derive the seed an op actually uses. This is also why you need to pass two numbers to stateless_truncated_normal, as xdurch0 describes in his answer.

tf.random.truncated_normal(shape=[2], seed=1234)  # global seed #1 & operational 1234 -> Seed A
tf.random.truncated_normal(shape=[2], seed=1234)  # global seed #2 & operational 1234 -> Seed B

There are multiple ways to tackle your problem. You can set the global seed again before each call, so twice in your example. You can work inside a @tf.function; these more or less reset the global seed and keep their own operation counters. Or you can use stateless_truncated_normal, as written in the other answer.
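
As a minimal sketch of the first option, assuming TensorFlow 2.x in eager mode: calling tf.random.set_seed before each call resets the random state, so both calls should return the same values. The global seed value 42 is an arbitrary choice for illustration, not something from the question.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, eager execution (the default).
# 42 is an arbitrary global seed chosen for this sketch.
tf.random.set_seed(42)
first = tf.random.truncated_normal(shape=[2], seed=1234)

# Setting the global seed again resets the counters kept by seeded ops,
# so the next call starts from the same point in the random sequence.
tf.random.set_seed(42)
second = tf.random.truncated_normal(shape=[2], seed=1234)

print(first)
print(second)
```

The trade-off is that set_seed affects every random op in the program, which is why the stateless variant is often the more convenient choice.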

This behavior is also described in the documentation linked in the other answer.

This seems to be intentional; see the docs here, specifically the "Examples" section.

What you need is stateless_truncated_normal:

print(tf.random.stateless_truncated_normal(shape=[2], seed=[1234, 1]))
print(tf.random.stateless_truncated_normal(shape=[2], seed=[1234, 1]))

This gives me:

tf.Tensor([1.0721238  0.10303579], shape=(2,), dtype=float32)
tf.Tensor([1.0721238  0.10303579], shape=(2,), dtype=float32)

Note: The seed needs to be two numbers here, I honestly don't know why (the docs don't say).
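
To illustrate how the two-element seed behaves, here is a small sketch, assuming TensorFlow 2.x; the second seed pair [1234, 2] is an arbitrary choice made for this example. The same pair reproduces the same values on every call, while changing one element should give a different draw.

```python
import tensorflow as tf

# Same seed pair twice: stateless ops depend only on their arguments.
a = tf.random.stateless_truncated_normal(shape=[2], seed=[1234, 1])
b = tf.random.stateless_truncated_normal(shape=[2], seed=[1234, 1])

# Hypothetical second pair that differs only in its second element.
c = tf.random.stateless_truncated_normal(shape=[2], seed=[1234, 2])

print(bool(tf.reduce_all(a == b)))  # whether the identical pair reproduced the values
print(bool(tf.reduce_all(a == c)))  # whether the different pair gave the same values
```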
