I am subclassing tensorflow.keras.Model to implement a certain model. Expected behavior: call() returns [aux1, aux2, net] during training and only net otherwise.
The code is:
class SomeModel(tensorflow.keras.Model):
    # ......
    def call(self, x, training=True):
        # ......
        return [aux1, aux2, net] if training else net
This is how I use it:
model = SomeModel(...)
model.compile(...,
              loss=keras.losses.SparseCategoricalCrossentropy(),
              loss_weights=[0.4, 0.4, 1], ...)
# ......
model.fit(data, [labels, labels, labels])
And got:
AssertionError: in converted code:
    ipython-input-33-862e679ab098:140 call *
        return [aux1, aux2, net] if training else net
    ...\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt
The problem is that the if statement is converted into the computation graph, and this of course causes the error. The full stack trace is long and not useful, so it's not included here.
So, is there any way to make TensorFlow generate a different graph depending on whether training is set or not?
Which TensorFlow version are you using? In TensorFlow 2.2 you can override the behaviour of the .fit, .predict and .evaluate methods, which would (I assume) generate different graphs for these methods and potentially work for your use case.
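A rough sketch of that idea, assuming the hooks meant here are the train_step, test_step and predict_step methods that became overridable in TensorFlow 2.2: call always returns all three outputs, and predict_step keeps only the main one. The layer sizes and head names below are placeholders, not the architecture from the question.

```python
import numpy as np
import tensorflow as tf


class SomeModel(tf.keras.Model):
    """Hypothetical stand-in; the real layers are elided in the question."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Placeholder layers, chosen only to make the sketch runnable.
        self.backbone = tf.keras.layers.Dense(32, activation="relu")
        self.aux1_head = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.aux2_head = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.main_head = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, x, training=None):
        h = self.backbone(x)
        # Always the same three-element list, so there is no branch to convert.
        return [self.aux1_head(h), self.aux2_head(h), self.main_head(h)]

    def predict_step(self, data):
        # Run the default inference step, then keep only the main output.
        aux1, aux2, net = super().predict_step(data)
        return net


model = SomeModel()
train_outputs = model(tf.zeros([4, 8]), training=True)  # [aux1, aux2, net]
predictions = model.predict(np.zeros((4, 8), dtype="float32"), verbose=0)  # net only
```

With this layout the compile(..., loss_weights=[0.4, 0.4, 1]) and fit(data, [labels, labels, labels]) calls from the question should not need to change. evaluate() would still see all three outputs unless test_step is adjusted in a similar way.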
The problem with earlier versions is that subclassed models get created by tracing the call method. This means Python conditionals become TensorFlow conditionals and face several limitations during graph creation and execution.
First, both branches (if and else) have to be defined, and for Python collections (e.g. lists), the branches have to have the same structure (e.g. number of elements). You can read about the limitations and effects of AutoGraph here and here.
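The structure rule can be seen with tf.cond, which is roughly what a tensor-valued if gets lowered to. In this sketch the inference branch pads its result with zero tensors so that both branches return a three-element list. The function name heads and the arithmetic are made up for illustration.

```python
import tensorflow as tf


@tf.function
def heads(x, training):
    net = x + 1.0
    return tf.cond(
        training,
        # Training branch: two auxiliary outputs plus the main one.
        lambda: [x * 2.0, x * 3.0, net],
        # Inference branch: same list length, dtypes and shapes, aux slots zeroed.
        lambda: [tf.zeros_like(net), tf.zeros_like(net), net],
    )


x = tf.constant([1.0, 2.0])
train_out = heads(x, tf.constant(True))
infer_out = heads(x, tf.constant(False))
```

Returning a bare net from the second lambda instead would make the two branches differ in structure, which is the situation the question ran into.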
(Also, a conditional may not get evaluated at every run if the condition is based on a Python variable and not a tensor.)
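The difference between a Python condition and a tensor condition shows up in how often a tf.function is traced. In this illustrative sketch (scale and trace_log are hypothetical names), the append to a Python list runs only while tracing, so the list length counts traces.

```python
import tensorflow as tf

trace_log = []  # Python side effect: runs only when the function is traced


@tf.function
def scale(x, training):
    trace_log.append("traced")
    if training:
        return x * 2.0
    return x


x = tf.constant([1.0, 2.0])

# Python bools: the `if` is resolved while tracing, giving one graph per value.
scale(x, True)
scale(x, False)
traces_after_python_bools = len(trace_log)

# Tensor flags: one graph containing a TensorFlow conditional, reused for both calls.
doubled = scale(x, tf.constant(True))
unchanged = scale(x, tf.constant(False))
traces_after_tensor_flags = len(trace_log)
```

With a Python bool the branch is chosen once per trace and baked into that graph. With a tensor flag both branches end up in the same graph, which is why they must agree in structure.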