TensorFlow & Keras: Different output shape depending on training or inference

I am subclassing tensorflow.keras.Model to implement a certain model. Expected behavior:

  1. Training (fitting) time: returns a list of tensors including the final output and auxiliary outputs;
  2. Inference (predicting) time: returns a single output tensor.

And the code is:

class SomeModel(tensorflow.keras.Model):
    # ......
    def call(self, x, training=True):
        # ......
        return [aux1, aux2, net] if training else net

This is how I use it:

model=SomeModel(...)
model.compile(...,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    loss_weights=[0.4, 0.4, 1],...)
# ......
model.fit(data, [labels, labels, labels])

And got:

AssertionError: in converted code:

ipython-input-33-862e679ab098:140 call *

 `return [aux1, aux2, net] if training else net`

...\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt

So the problem is that the if statement is converted into the computation graph, which of course causes the error. The full stack trace is long and unhelpful, so it is not included here.

So, is there any way to make TensorFlow generate a different graph depending on whether it is training or not?

Which TensorFlow version are you using? In TensorFlow 2.2 you can override the behaviour of the .fit, .predict and .evaluate methods, which would generate different graphs for these methods (I assume) and potentially work for your use case.
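One possible way to act on this, sketched under assumptions rather than taken from the question: keep call returning the same three-element list in every mode, and override predict_step (a hook available on tf.keras.Model from TensorFlow 2.2) so that only .predict drops the auxiliary outputs. The layer names, layer sizes and random input data below are hypothetical stand-ins for the elided parts of SomeModel.

```python
import numpy as np
import tensorflow as tf


class SomeModel(tf.keras.Model):
    """Toy stand-in: the layers and sizes here are illustrative assumptions."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.aux1_head = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.aux2_head = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.net_head = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, x, training=None):
        h = self.hidden(x)
        # Always return the same structure, so no branch depends on `training`.
        return [self.aux1_head(h), self.aux2_head(h), self.net_head(h)]

    def predict_step(self, data):
        # predict() may hand over a bare batch or a tuple wrapping it.
        x = data[0] if isinstance(data, (tuple, list)) else data
        # Keep only the final output at prediction time.
        return self(x, training=False)[-1]


model = SomeModel()
features = np.random.rand(16, 4).astype("float32")  # made-up input batch

training_outputs = model(features, training=True)  # list of three tensors
predictions = model.predict(features, verbose=0)   # single array
print(len(training_outputs), predictions.shape)
```

Because call still yields three outputs, the three-label fit call and the three loss_weights from the question would presumably stay as they are; only the prediction path changes.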

The problem with earlier versions is that subclassed models get created by tracing the call method. This means Python conditionals become TensorFlow conditionals and face several limitations during graph creation and execution.
First, both branches (if-else) have to be defined, and for Python collections (e.g. lists), the branches have to have the same structure (e.g. number of elements). You can read about the limitations and effects of AutoGraph here and here.
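The structure requirement can be reproduced outside Keras with a small tf.function. This is only an illustrative sketch: the function name branchy and its inputs are made up, and the exact exception type differs between TensorFlow versions, so the code catches a generic Exception.

```python
import tensorflow as tf


@tf.function
def branchy(x, training):
    # With a tensor condition, AutoGraph turns this `if` into a graph
    # conditional, so both branches must return the same structure.
    if training:
        return [x, x + 1.0]  # list of two tensors
    else:
        return x             # single tensor


error = None
try:
    # Passing the flag as a tensor forces a graph-level conditional.
    branchy(tf.constant(1.0), tf.constant(True))
except Exception as exc:  # exception type varies across TF versions
    error = exc

print(type(error).__name__)
```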

(Also, a conditional may not get evaluated at every run if the condition is based on a Python variable and not a tensor.)
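A minimal sketch of that last point, again using a made-up function name: when the flag is a plain Python bool, the if is resolved while tracing, each traced graph contains only one branch, and the two branches may return different structures.

```python
import tensorflow as tf


@tf.function
def branchy(x, training):
    # `training` is a Python bool here, so this `if` runs as ordinary
    # Python during tracing and never becomes a graph conditional.
    if training:
        return [x, x + 1.0]
    return x


train_out = branchy(tf.constant(1.0), True)   # traced with training=True
infer_out = branchy(tf.constant(1.0), False)  # separate trace for False

print(len(train_out), float(infer_out))
```

The condition is baked into each traced graph, which is why it is not re-evaluated on later runs that reuse the same trace.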

