I am following this tutorial on how to use tf.scan, and I wrote a minimal working example (see the code below). But each time the function Model._step() is called, isn't it creating another copy of the computational graph? If not, why not?
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # to avoid TF suggesting SSE4.2, AVX etc...


class Model():
    def __init__(self):
        self._inputs = tf.placeholder(shape=[None], dtype=tf.float32)
        self._predictions = self._compute_predictions()

    def _step(self, old_state, new_input):
        # ---- In here I will write a much more complex graph ----
        return old_state + new_input

    def _compute_predictions(self):
        return tf.scan(self._step, self._inputs, initializer=tf.Variable(0.0))

    @property
    def predictions(self):
        return self._predictions

    @property
    def inputs(self):
        return self._inputs


def test(sess, model):
    sess.run(tf.global_variables_initializer())
    print(sess.run(model.predictions, {model.inputs: [1.0, 2.0, 3.0, 4.0]}))


test(tf.Session(), Model())
I'm asking because this is, of course, a minimal example; in my case the graph will be much more complex.
The Model._step() method will only be called once per Model object constructed. The tf.scan() function, like the tf.while_loop() function it wraps, calls the given function only once, to build a single graph containing a loop; that same graph is then reused for every iteration of the loop. (Note that if you construct many Model objects, you will end up with as many copies of the graph as you have Model objects.)
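The trace-once behavior can be illustrated without TensorFlow. The following sketch (plain Python, with a hypothetical scan() helper standing in for tf.scan) calls the step function a single time on symbolic placeholders to record the loop body as a tiny graph, then interprets that same graph for each input; the call counter shows the step function's Python body runs only once, no matter how long the input sequence is:

```python
call_count = 0

def step(old_state, new_input):
    """Loop body; call_count tracks how many times Python actually executes it."""
    global call_count
    call_count += 1
    return ("add", old_state, new_input)  # record a graph node instead of computing

def run_node(node, env):
    """Interpret a recorded graph node against concrete values."""
    if isinstance(node, str):          # a placeholder name
        return env[node]
    op, a, b = node                    # only "add" nodes exist in this sketch
    return run_node(a, env) + run_node(b, env)

def scan(step_fn, inputs, initializer):
    graph = step_fn("state", "x")      # trace ONCE: build the loop-body graph
    state, outputs = initializer, []
    for x in inputs:                   # every iteration reuses the same graph
        state = run_node(graph, {"state": state, "x": x})
        outputs.append(state)
    return outputs

result = scan(step, [1.0, 2.0, 3.0, 4.0], 0.0)
print(result)       # [1.0, 3.0, 6.0, 10.0] -- same cumulative sum as the TF example
print(call_count)   # 1 -- step ran once, during tracing, not once per element
```

This mirrors what tf.scan does at graph-construction time: your Python function executes once to emit ops, and the session then runs those ops repeatedly.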