
Tensorflow & Keras: Different output shape depends on training or inferring

I am subclassing tensorflow.keras.Model to implement a certain model. Expected behavior:

  1. Training (fitting) time: returns a list of tensors, including the final output and the auxiliary outputs;
  2. Inference (predicting) time: returns a single output tensor.

And the code is:

import tensorflow
from tensorflow import keras

class SomeModel(tensorflow.keras.Model):
    # ......
    def call(self, x, training=True):
        # ......
        return [aux1, aux2, net] if training else net

This is how I use it:

model = SomeModel(...)
model.compile(...,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    loss_weights=[0.4, 0.4, 1],...)
# ......
model.fit(data, [labels, labels, labels])

And I got:

AssertionError: in converted code:

    ipython-input-33-862e679ab098:140 call  *
        return [aux1, aux2, net] if training else net
    ...\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt

So the problem seems to be that the if statement is converted into a conditional in the computation graph, which causes this error. The full stack trace is long and not very useful, so I have not included it here.

So, is there any way to make TensorFlow generate a different graph depending on whether the model is training or not?

Which TensorFlow version are you using? In TensorFlow 2.2 and later you can customize what .fit, .predict and .evaluate do by overriding the model's train_step, predict_step and test_step methods. Each of these is wrapped in its own tf.function, so training and prediction get separate graphs, which could work for your use case.
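A minimal sketch of that idea, assuming TensorFlow 2.2 or later: call always returns all three outputs, so there is no branch on training at all, and predict_step is overridden to keep only the final output. The layer sizes, num_classes=3 and the random data below are hypothetical stand-ins for the elided parts of the question's model.

```python
import numpy as np
import tensorflow as tf


class SomeModel(tf.keras.Model):
    """Hypothetical stand-in for the question's model: two auxiliary heads plus a main head."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.block1 = tf.keras.layers.Dense(16, activation="relu")
        self.block2 = tf.keras.layers.Dense(16, activation="relu")
        self.aux1_head = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.aux2_head = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.main_head = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, x, training=None):
        h1 = self.block1(x)
        h2 = self.block2(h1)
        # Same output structure in every mode, so no conditional on `training` is traced.
        return [self.aux1_head(h1), self.aux2_head(h2), self.main_head(h2)]

    def predict_step(self, data):
        # Let the default implementation unpack `data` and run call(training=False),
        # then drop the two auxiliary outputs.
        outputs = super().predict_step(data)
        return outputs[-1]


# Hypothetical toy data: 32 samples, 8 features, 3 classes.
x = np.random.rand(32, 8).astype("float32")
labels = np.random.randint(0, 3, size=(32,))

model = SomeModel()
model.compile(
    optimizer="adam",
    # One loss per output; loss_weights are applied in the same order.
    loss=[tf.keras.losses.SparseCategoricalCrossentropy() for _ in range(3)],
    loss_weights=[0.4, 0.4, 1.0],
)
model.fit(x, [labels, labels, labels], epochs=1, verbose=0)

preds = model.predict(x, verbose=0)  # a single array, not a list of three
print(preds.shape)
```

With this design, model.fit still sees [aux1, aux2, net] and applies the loss weights, while model.predict returns only the final output, and no Python conditional has to be converted by AutoGraph.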

The problem with earlier versions is that a subclassed model's graph is built by tracing its call method. This means Python conditionals become TensorFlow conditionals, which face several limitations during graph creation and execution.
First, both branches (if and else) have to be defined, and if they produce Python collections (e.g. lists), both branches must have the same structure (e.g. the same number of elements). The AutoGraph documentation on control flow and its limitations covers these constraints in more detail.
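The structure restriction can be reproduced outside Keras. This is a minimal sketch in which forward is a hypothetical function: because training is passed as a tensor, AutoGraph turns the if into a tf.cond, and the mismatch between a three-element list and a single tensor makes tracing fail. The exact exception type and message vary between TensorFlow versions (the question saw an AssertionError), so the example catches Exception broadly.

```python
import tensorflow as tf


@tf.function
def forward(x, training):
    # `training` arrives as a tensor, so AutoGraph converts this `if` into a
    # tf.cond, which requires both branches to produce the same nested structure.
    if training:
        outputs = [x * 0.5, x * 0.5, x]  # a list of three tensors
    else:
        outputs = x  # a single tensor: structure mismatch with the other branch
    return outputs


try:
    forward(tf.constant(1.0), tf.constant(True))
    raised = False
except Exception as err:  # exception type differs across TF versions
    raised = True
    print(type(err).__name__, err)
```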

(Also, if the condition is a Python value rather than a tensor, the conditional is evaluated only while the function is being traced, not on every run.)
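For contrast, a sketch of the Python-value case, again with a hypothetical forward function: when training is a plain Python bool, the if is resolved during tracing, and tf.function traces a separate graph for each distinct value. The print call is a Python side effect, so it should run only when a new trace happens (here, for the first True call and for the False call), not on every call.

```python
import tensorflow as tf


@tf.function
def forward(x, training):
    # Python side effect: runs only while tracing, not on every call.
    print("tracing forward with training =", training)
    # `training` is a Python bool, so this `if` is an ordinary Python branch
    # evaluated at trace time; each value of `training` gets its own graph.
    if training:
        return [x * 0.5, x * 0.5, x]
    return x


x = tf.constant(1.0)
out_train = forward(x, True)   # traces the "training" graph
forward(x, True)               # same Python value: reuses the existing graph
out_infer = forward(x, False)  # new Python value: traces a second graph
```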
