繁体   English   中英

tf.function 输入参数

[英]tf.function input parameters

我在tensorflow 2中写了一个函数,用tf.keras写了一个模型。 该函数定义如下:

@tf.function
def mask_output(input_tensor,mask):
    if tf.math.count_nonzero(mask) > 0:
        output_tensor = tf.multiply(input_tensor, mask)
    else:
        output_tensor = input_tensor
    return output_tensor

我给它的两个参数是模型中的张量。 但是,当我定义模型并在模型定义中调用该函数时,它说:

{_SymbolicException}急切执行函数的输入不能是 Keras 符号张量,但已找到

[<tf.Tensor 'a_dense2/Identity:0' shape=(None, 12, 5) dtype=float32>, <tf.Tensor 'a_mask_input:0' shape=(None, 12, 5) dtype=float32>]

如何解决? 为什么我不能用两个 keras 张量输入调用那个函数?

如果在 Eager 模式下运行,tensorflow 操作将检查输入是否为tensorflow.python.framework.ops.EagerTensor类型,并且 keras ops 实现为 DAG。 因此,如果tensorflow.python.framework.ops.Tensor模式的输入是tensorflow.python.framework.ops.Tensor ,则会引发错误。

您可以通过使用tf.config.experimental_run_functions_eagerly(True)明确告诉 tensorflow 在 keras 的热切模式下运行,将输入类型更改为 EagerTensor。 添加此语句应该可以解决您的问题。

例如,这个程序抛出你面临的错误——

重现错误的代码-

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import layers, losses, models

def get_loss_fcn(w):
    def loss_fcn(y_true, y_pred):
        loss = w * losses.mse(y_true, y_pred)
        return loss
    return loss_fcn

data_x = np.random.rand(5, 4, 1)
data_w = np.random.rand(5, 4)
data_y = np.random.rand(5, 4, 1)

x = layers.Input([4, 1])
w = layers.Input([4])
y = layers.Activation('tanh')(x)
model = models.Model(inputs=[x, w], outputs=y)
loss = get_loss_fcn(model.input[1])

model.compile(loss=loss)
model.fit((data_x, data_w), data_y)

输出 -

2.2.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: input_8:0

During handling of the above exception, another exception occurred:

_SymbolicException                        Traceback (most recent call last)
8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     72       raise core._SymbolicException(
     73           "Inputs to eager execution function cannot be Keras symbolic "
---> 74           "tensors, but found {}".format(keras_symbolic_tensors))
     75     raise e
     76   # pylint: enable=protected-access

_SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'input_8:0' shape=(None, 4) dtype=float32>]

解决方案 -在程序顶部添加此tf.config.experimental_run_functions_eagerly(True)运行程序成功。 同时在程序顶部添加tf.compat.v1.disable_eager_execution()以禁用急切执行也可以成功运行程序。

固定码——

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import layers, losses, models

tf.config.experimental_run_functions_eagerly(True)

def get_loss_fcn(w):
    def loss_fcn(y_true, y_pred):
        loss = w * losses.mse(y_true, y_pred)
        return loss
    return loss_fcn

data_x = np.random.rand(5, 4, 1)
data_w = np.random.rand(5, 4)
data_y = np.random.rand(5, 4, 1)

x = layers.Input([4, 1])
w = layers.Input([4])
y = layers.Activation('tanh')(x)
model = models.Model(inputs=[x, w], outputs=y)
loss = get_loss_fcn(model.input[1])

model.compile(loss=loss)
model.fit((data_x, data_w), data_y)

print('Done.')

输出 -

2.2.0
1/1 [==============================] - 0s 1ms/step - loss: 0.0000e+00
Done.

希望这能回答你的问题。 快乐学习。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM