[英]Unit test Tensorflow Transform operations
我想对包含 Tensorflow 变换操作的函数进行单元测试。 像这样的东西:
@pytest.mark.parametrize("inputs,expected_result",
[(
tf.linspace(start=0, stop=10, num=10),
tf.linspace(start=0, stop=1, num=10),
)]
)
def test_tft_scale(inputs, expected_result):
out = tft.scale_to_0_1(inputs)
assert tf.experimental.numpy.allclose(out, expected_result)
我收到与急切执行相关的错误:
RuntimeError: tf.placeholder() is not compatible with eager execution
我尝试通过放置tf.compat.v1.disable_eager_execution()
来禁用急切执行,但我遇到了另一个错误:
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
如何解决这个问题? 谢谢!
因此,您正在尝试将 Python 急切代码转换为与图形兼容的 TensorFlow 操作,您可以使用AutoGraph
库执行此操作。 您可以通过简单地添加@tf.function
装饰器并将其与您的@pytest.mark.parametrize()
装饰器链接来使用它。 @tf.function
AutoGraph
使用 AutoGraph,您可以在此处阅读更多信息。
我认为您最好像这样使用tf.test.TestCase
,删除使用tft.scale_to_0_1()
因为您可以使用tensor = tensor / tf.norm(tensor)
将张量缩放到 0 到 1 的范围. 您不需要启用tf.compat.v1.enable_eager_execution()
。
import tensorflow as tf
class TFTScaleTest(tf.test.TestCase):
def setUp(self):
super(TFTScaleTest, self).setUp()
self.input = tf.linspace(start=0, stop=10, num=10)
self.expected_result = tf.linspace(start=0, stop=1, num=10)
def test_tft_scale(self):
tensor = tf.cast(self.input, dtype=tf.float64)
tensor = tensor / tf.norm(tensor)
self.assertAllClose(tensor, self.expected_result)
if __name__ == "__main__":
tf.test.main()
AssertionError:
Not equal to tolerance rtol=1e-06, atol=1e-06
Mismatched value: a is different from b.
not close where = (array([1, 2, 3, 4, 5, 6, 7, 8, 9]),)
not close lhs = [0.05923489 0.11846978 0.17770466 0.23693955 0.29617444 0.35540933
0.41464421 0.4738791 0.53311399]
not close rhs = [0.11111111 0.22222222 0.33333333 0.44444444 0.55555556 0.66666667
0.77777778 0.88888889 1. ]
not close dif = [0.05187622 0.10375245 0.15562867 0.20750489 0.25938112 0.31125734
0.36313356 0.41500979 0.46688601]
not close tol = [1.11111111e-06 1.22222222e-06 1.33333333e-06 1.44444444e-06
1.55555556e-06 1.66666667e-06 1.77777778e-06 1.88888889e-06
2.00000000e-06]
dtype = float64, shape = (10,)
Mismatched elements: 9 / 10 (90%)
Max absolute difference: 0.46688601
Max relative difference: 0.46688601
x: array([0. , 0.059235, 0.11847 , 0.177705, 0.23694 , 0.296174,
0.355409, 0.414644, 0.473879, 0.533114])
y: array([0. , 0.111111, 0.222222, 0.333333, 0.444444, 0.555556,
0.666667, 0.777778, 0.888889, 1. ])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.