繁体   English   中英

单元测试 Tensorflow 变换操作

[英]Unit test Tensorflow Transform operations

我想对包含 Tensorflow 变换操作的函数进行单元测试。 像这样的东西:

@pytest.mark.parametrize("inputs,expected_result",
    [(
        tf.linspace(start=0, stop=10, num=10),
        tf.linspace(start=0, stop=1, num=10),
    )]
)
def test_tft_scale(inputs, expected_result):
    out = tft.scale_to_0_1(inputs)
    assert tf.experimental.numpy.allclose(out, expected_result)

我收到与急切执行相关的错误:

RuntimeError: tf.placeholder() is not compatible with eager execution

我尝试通过放置tf.compat.v1.disable_eager_execution()来禁用急切执行,但我遇到了另一个错误:

tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

如何解决这个问题? 谢谢!

因此,您正在尝试将 Python 急切代码转换为与图形兼容的 TensorFlow 操作,您可以使用AutoGraph库执行此操作。 您可以通过简单地添加@tf.function装饰器并将其与您的@pytest.mark.parametrize()装饰器链接来使用它。 @tf.function AutoGraph使用 AutoGraph,您可以在此处阅读更多信息。

编辑:

我认为您最好像这样使用tf.test.TestCase ,删除使用tft.scale_to_0_1()因为您可以使用tensor = tensor / tf.norm(tensor)将张量缩放到 0 到 1 的范围. 您不需要启用tf.compat.v1.enable_eager_execution()

import tensorflow as tf

class TFTScaleTest(tf.test.TestCase):
    def setUp(self):
        super(TFTScaleTest, self).setUp()
        self.input = tf.linspace(start=0, stop=10, num=10)
        self.expected_result = tf.linspace(start=0, stop=1, num=10)

    def test_tft_scale(self):
        tensor = tf.cast(self.input, dtype=tf.float64)
        tensor = tensor / tf.norm(tensor)
        self.assertAllClose(tensor, self.expected_result)


if __name__ == "__main__":
    tf.test.main()

测试用例结果:

AssertionError: 
Not equal to tolerance rtol=1e-06, atol=1e-06
Mismatched value: a is different from b. 
not close where = (array([1, 2, 3, 4, 5, 6, 7, 8, 9]),)
not close lhs = [0.05923489 0.11846978 0.17770466 0.23693955 0.29617444 0.35540933
 0.41464421 0.4738791  0.53311399]
not close rhs = [0.11111111 0.22222222 0.33333333 0.44444444 0.55555556 0.66666667
 0.77777778 0.88888889 1.        ]
not close dif = [0.05187622 0.10375245 0.15562867 0.20750489 0.25938112 0.31125734
 0.36313356 0.41500979 0.46688601]
not close tol = [1.11111111e-06 1.22222222e-06 1.33333333e-06 1.44444444e-06
 1.55555556e-06 1.66666667e-06 1.77777778e-06 1.88888889e-06
 2.00000000e-06]
dtype = float64, shape = (10,)
Mismatched elements: 9 / 10 (90%)
Max absolute difference: 0.46688601
Max relative difference: 0.46688601
 x: array([0.      , 0.059235, 0.11847 , 0.177705, 0.23694 , 0.296174,
       0.355409, 0.414644, 0.473879, 0.533114])
 y: array([0.      , 0.111111, 0.222222, 0.333333, 0.444444, 0.555556,
       0.666667, 0.777778, 0.888889, 1.      ])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM