简体   繁体   English

查找失败的索引numpy.assert_almost_equal

[英]Find indexes that fail numpy.assert_almost_equal

I'm dealing with large 3D arrays like (110,80,817) and wanting to compare two arrays in some unit tests. 我正在处理像(110,80,817)这样的大型3D数组,并希望在某些单元测试中比较两个数组。 However, the default output from numpy.assert_almost_equal doesn't help me track down the errors very easily. 但是, numpy.assert_almost_equal的默认输出并不能帮助我很容易地找到错误。 For example: 例如:

>                   raise AssertionError(msg)
E                   AssertionError:
E                   Arrays are not almost equal to 7 decimals
E
E                   (mismatch 0.0314621119395%)
E                    x: array([[[ 0.,  0.,  0., ...,  0.,  0.,  0.],
E                           [ 0.,  0.,  0., ...,  0.,  0.,  0.],
E                           [ 0.,  0.,  0., ...,  0.,  0.,  0.],...
E                    y: array([[[ 0.,  0.,  0., ...,  0.,  0.,  0.],
E                           [ 0.,  0.,  0., ...,  0.,  0.,  0.],
E                           [ 0.,  0.,  0., ...,  0.,  0.,  0.],...

Is there a way to easily see which 3D indexes are failing this assertion? 有没有一种方法可以轻松查看哪些3D索引未通过该声明?

You can use np.isclose combined with np.where for this 您可以将np.isclosenp.where结合使用

idx = zip(*np.where(~np.isclose(a, b, atol=0, rtol=1e-7)))

Now idx will be a list of all the indices (x,y,z) where the assertion fails. 现在, idx将是断言失败的所有索引(x,y,z)的列表。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM