
TF2 - tf.function and class variables breaking

I'm using TensorFlow 2.3, and I'm having problems with tf.function inside classes when I store the value in the class instead of returning it.

For example:

import tensorflow as tf

class Test:
    def __init__(self):
        self.count = tf.convert_to_tensor(0)
        
    @tf.function
    def incr(self):
        self.count += 1
        return self.count
        
t = Test()
count = t.incr()
count == t.count

raises the following error:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: add:0

When I check the values of count and t.count, I see:

  • count: <tf.Tensor: shape=(), dtype=int32, numpy=1>
  • t.count: <tf.Tensor 'add:0' shape=() dtype=int32>

How can I fix this so that t.count stores the same value?

You can use a tf.Variable instead and update it in place with assign_add:

import tensorflow as tf

class Test:
    def __init__(self):
        # self.count = tf.convert_to_tensor(0)
        # A tf.Variable holds mutable state that persists across calls
        self.count = tf.Variable(0)

    @tf.function
    def incr(self):
        # assign_add updates the variable in place instead of rebinding self.count
        self.count.assign_add(1)
        return self.count

t = Test()
count = t.incr()
print(count)  # tf.Tensor(1, shape=(), dtype=int32)
count = t.count
print(count)  # <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=1>
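If you would rather keep count as a plain tensor, a different approach from the tf.Variable one (a sketch, not part of the original answer; next_count is a hypothetical helper name) is to keep the tf.function pure and do the assignment eagerly, outside the traced code, so only concrete eager tensors are ever stored on the instance:

```python
import tensorflow as tf

@tf.function
def next_count(count):
    # Pure graph function: it only returns a new value and modifies no Python state
    return count + 1

class Test:
    def __init__(self):
        self.count = tf.constant(0)

    def incr(self):
        # incr itself runs eagerly, so the returned value is a concrete
        # eager tensor that is safe to store on the instance
        self.count = next_count(self.count)
        return self.count

t = Test()
count = t.incr()
print(bool(count == t.count))  # True
```

The trade-off is that each call produces a new tensor instead of mutating shared state, and next_count is traced once per distinct dtype/shape of its argument rather than once per call.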

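A tensor is immutable, so self.count += 1 does not modify the existing tensor; it creates a new tensor and rebinds self.count to it. Inside a tf.function, that line runs while the function is being traced into a graph (a structure of tf.Operation and tf.Tensor objects), so the new tensor is a symbolic graph tensor rather than an eager one. It has no concrete value you can read; it only refers to the output of the add operation inside the graph, which is why you see t.count: <tf.Tensor 'add:0' shape=() dtype=int32>. Using that graph tensor outside the function, as in count == t.count, is what raises the TypeError above.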
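Both effects can be seen in isolation with a minimal sketch (the names f and seen are illustrative only): rebinding with += leaves other references to the old tensor untouched, and the Python body of a tf.function runs during tracing, when tf.executing_eagerly() returns False and intermediate values are symbolic graph tensors.

```python
import tensorflow as tf

a = tf.constant(0)
b = a          # b refers to the same tensor object as a
a += 1         # creates a new tensor and rebinds a; b still holds the old one
print(b.numpy(), a.numpy())  # 0 1

seen = {}

@tf.function
def f(x):
    # This Python code runs while tf.function traces f into a graph
    seen["eager"] = tf.executing_eagerly()
    y = x + 1
    seen["y"] = y  # leaks a symbolic graph tensor, like t.count in the question
    return y

out = f(tf.constant(0))
print(out)           # the returned value is a concrete eager tensor
print(seen["eager"]) # False: tracing does not run eagerly
print(seen["y"])     # a graph tensor with no concrete value
```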

import tensorflow as tf

count = tf.constant(0)
print(count)  # count is an eager tensor: tf.Tensor(0, shape=(), dtype=int32)

@tf.function
def incr():
    global count  # rebind the global count tensor inside the graph context
    count += 1
    return count

c = incr()
print(count)  # count is now a graph tensor: Tensor("add:0", shape=(), dtype=int32)
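If the state really needs to live at module level, the same fix as in the class version applies. A minimal sketch, assuming a global tf.Variable is acceptable in your code: mutate it with assign_add so the name count is never rebound to a graph tensor.

```python
import tensorflow as tf

count = tf.Variable(0)  # the state lives in a variable, not in a rebound tensor

@tf.function
def incr():
    # assign_add mutates the variable; the global name count is never reassigned
    count.assign_add(1)
    return count.read_value()

incr()
incr()
print(count.numpy())  # 2
```

Because the update is a stateful op in the graph rather than a Python side effect during tracing, it runs on every call, not only on the first (tracing) call.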
