
Is this function creating a new TensorFlow graph each time?

I am following this tutorial on how to use tf.scan, and I wrote a minimal working example (see the code below). Each time the method Model._step() is called, isn't it adding another copy of the computational graph? If not, why not?

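import os
# Safest to set this before importing TensorFlow so the C++ logger sees it
# (hides the "SSE4.2, AVX etc." CPU-instruction messages).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)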

class Model():
    def __init__(self):
        self._inputs = tf.placeholder(shape=[None], dtype=tf.float32)
        self._predictions = self._compute_predictions()

    def _step(self, old_state, new_input):
        # ---- In here I will write a much more complex graph ----
        return old_state + new_input

    def _compute_predictions(self):
        # tf.scan calls self._step once here, while the graph is being built.
        return tf.scan(self._step, self._inputs, initializer=tf.Variable(0.0))

    @property
    def predictions(self):
        return self._predictions

    @property
    def inputs(self):
        return self._inputs

def test(sess, model):
    sess.run(tf.global_variables_initializer())
    print(sess.run(model.predictions, {model.inputs: [1.0, 2.0, 3.0, 4.0]}))

test(tf.Session(), Model())
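To make clear what the graph above computes (as opposed to how it is built), here is a plain-Python sketch of the output semantics of tf.scan: each output is the step function applied to the previous output and the next element, starting from the initializer. The helper `scan` below is a hypothetical illustration written for this answer, not TensorFlow's implementation, and it runs eagerly instead of building a graph.

```python
def scan(fn, elems, initializer):
    """Hypothetical eager analogue of tf.scan's output semantics."""
    state = initializer
    outputs = []
    for x in elems:
        state = fn(state, x)   # same rule as Model._step(old_state, new_input)
        outputs.append(state)  # scan keeps every intermediate state
    return outputs


def step(old_state, new_input):
    return old_state + new_input


print(scan(step, [1.0, 2.0, 3.0, 4.0], initializer=0.0))  # [1.0, 3.0, 6.0, 10.0]
```

So with this step function the question's code computes a running sum; in the TensorFlow version sess.run returns it as a NumPy float32 array rather than a Python list.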

I'm asking because this is only a minimal example; in my real case the graph will be much more complex.

The Model._step() method is called only once per Model object constructed. tf.scan(), like the tf.while_loop() it is built on, calls the function it is given only once, to build a graph that contains a loop; that same loop-body subgraph is then executed on every iteration each time you call sess.run().
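The "trace once, run many times" behaviour can be illustrated without TensorFlow. The sketch below is a toy define-then-run graph, NOT TensorFlow's API: the names `Node`, `evaluate` and `ToyScan` are hypothetical. Calling the step function on symbolic placeholders records an `add` node instead of computing a number, and running the scan re-evaluates that single recorded body for every element. A counter shows the Python step function is called exactly once, no matter how many elements or runs there are.

```python
class Node:
    """A node in a toy computation graph (hypothetical, not tf.Tensor)."""
    def __init__(self, op, *inputs):
        self.op = op          # 'placeholder' or 'add'
        self.inputs = inputs

    def __add__(self, other):
        # Using + on nodes records an 'add' node rather than computing a value.
        return Node('add', self, other)


def evaluate(node, env):
    """Compute a node's value, given values for placeholder nodes in env."""
    if node.op == 'placeholder':
        return env[node]
    if node.op == 'add':
        a, b = (evaluate(i, env) for i in node.inputs)
        return a + b
    raise ValueError(f"unknown op {node.op!r}")


class ToyScan:
    def __init__(self, fn, initializer):
        # Build the loop body ONCE by calling fn on symbolic placeholders.
        self.state_in = Node('placeholder')
        self.elem_in = Node('placeholder')
        self.body = fn(self.state_in, self.elem_in)
        self.initializer = initializer

    def run(self, elems):
        # Execute the same body graph once per element; fn is not called here.
        state, outputs = self.initializer, []
        for x in elems:
            state = evaluate(self.body, {self.state_in: state, self.elem_in: x})
            outputs.append(state)
        return outputs


calls = 0

def step(old_state, new_input):
    global calls
    calls += 1  # counts how often the Python step function is invoked
    return old_state + new_input


toy = ToyScan(step, 0.0)
print(toy.run([1.0, 2.0, 3.0, 4.0]))  # [1.0, 3.0, 6.0, 10.0]
print(toy.run([10.0, 20.0]))          # [10.0, 30.0]
print(calls)                          # 1
```

The real tf.scan records TensorFlow ops inside a while-loop construct rather than toy nodes, but the same principle applies: the Python function only describes the body once, and the runtime does the iterating.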

(Note that if you construct many Model objects, each construction adds its own copy of the scan subgraph to the default graph, so you end up with as many copies of the graph as you have Model objects.)
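The growth described in that note can be sketched with another toy model. Here `ToyGraph` is a hypothetical stand-in for TensorFlow's default graph (just a list of op names) and `ToyModel` mimics the question's Model: every construction traces `_step` once and appends its own ops to the one shared graph. This is an illustration under those assumptions, not real TensorFlow code.

```python
class ToyGraph:
    """Hypothetical stand-in for TF's default graph: just a list of op names."""
    def __init__(self):
        self.ops = []

    def add_op(self, name):
        self.ops.append(name)
        return name


class ToyModel:
    """Mimics the question's Model: construction traces _step exactly once."""
    trace_count = 0  # total _step calls across all ToyModel instances

    def __init__(self, graph, name):
        self.graph = graph
        self.name = name
        self.inputs = graph.add_op(f"{name}/inputs")
        self.predictions = self._compute_predictions()

    def _step(self, old_state, new_input):
        ToyModel.trace_count += 1
        # Record one "add" op for the loop body instead of computing a number.
        return self.graph.add_op(f"{self.name}/scan/body/add")

    def _compute_predictions(self):
        # Like tf.scan, call the step function once to build the loop body,
        # then record an op standing for the loop's stacked output.
        self._step(f"{self.name}/scan/state", self.inputs)
        return self.graph.add_op(f"{self.name}/scan/output")


graph = ToyGraph()
models = [ToyModel(graph, f"model_{i}") for i in range(3)]
print(ToyModel.trace_count)  # 3: one trace per model
print(len(graph.ops))        # 9: three ops per model, all in one shared graph
```

In actual TensorFlow 1.x code, the usual ways to avoid this growth are to build one Model and reuse it across many sess.run calls, or to give each model its own graph with `with tf.Graph().as_default():`, or to clear the default graph with `tf.reset_default_graph()` between experiments.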


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM