[英]Numpy arctan2 of multidimensional array
我正在嘗試整理一些編寫為采用單個float
值的代碼,因此使用1D(最終是2D) numpy.arrays
作為輸入可以正常工作。
簡化為一個最小的示例,該函數看起來像這樣(該示例沒有做任何有用的事情,但是如果刪除了do_math
和do_some_more_math
,它將完全產生所描述的行為):
def do_complicated_math(r, g, b):
rgb = numpy.array([r, g, b])
# Math! No change in array shape. To run example just comment out.
rgb = do_math(rgb)
m_2 = numpy.array([[rgb[0], 0, 0], [0, rgb[1], 0], [0, 0, rgb[2]]])
# Get additional matrices needed for transformation.
# These are actually predefined 3x3 float arrays
m_1 = numpy.ones((3, 3))
m_3 = numpy.ones((3, 3))
# Transform the rgb array
rgb_transformed = m_1.dot(m_2).dot(m_3).dot(rgb)
# More math! No change in array shape. To run example just comment out.
rgb_transformed = do_some_more_math(rgb_transformed)
# Almost done just one more thing...
return numpy.arctan2(rgb_transformed, rgb_transformed)
# Works fine
do_complicated_math(1, 1, 1)
# Fails
x = numpy.ones(6)
do_complicated_math(x, x, x)
只要r
, g
和b
是單個數字,此函數就可以正常工作,但是,如果將它們指定為numpy.array
(例如,為了一次轉換多個rgb值),則numpy.arctan2
引發以下異常:
Traceback (most recent call last):
(...) line 32, in do_complicated_math
numpy.arctan2(rgb_transformed, rgb_transformed)
AttributeError: 'numpy.ndarray' object has no attribute 'arctan2'
關於這試圖告訴我的內容,我沒有找到任何明確的答案。 arctan2
似乎可以很好地用於多維數組,如下所示:
numpy.arctan2(numpy.ones((3,4,5)), numpy.ones((3,4,5)))
因此,我認為問題一定在於如何創建m_2
或如何傳播m_1
, m_2
, m_3
和rgb
的乘法,但是我似乎無法弄清楚它在哪里中斷。
問題是, rgb_transformed
不再是一個標准numpy的陣列,當您將它傳遞給arctan2
,它已成為一個對象數組:
print rgb_transformed
"""[[array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])]
[array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])]
[array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])
array([ 9., 9., 9., 9., 9., 9.])]]"""
print rgb_transformed.shape
#(3, 6)
print rgb_transformed.dtype
#object
所以這個問題比我想象的簡單:
這行:
m_2 = numpy.array([[rgb[0], 0, 0], [0, rgb[1], 0], [0, 0, rgb[2]]])
print m_2
#array([[array([ 1., 1., 1., 1., 1., 1.]), 0, 0],
# [0, array([ 1., 1., 1., 1., 1., 1.]), 0],
# [0, 0, array([ 1., 1., 1., 1., 1., 1.])]], dtype=object)
在這里創建對象數組,遍歷其余代碼。
編輯
要解決此問題,您可能需要稍微不同地廣播陣列。 基本上更改外部尺寸以反映不斷變化的rgb
值。 免責聲明:在您的問題中,我沒有很好的方法來驗證此結果,因此請謹慎對待輸出。
import numpy as np
def do_complicated_math(r, g, b):
rgb = np.array([r, g, b])
# create a transposed version of the m_2 array
m_2 = np.zeros((r.size,3,3))
for ii,ar in enumerate(rgb):
m_2[:,ii][:,ii][:] = ar
m_1 = np.ones((3, 3))
m_3 = np.ones((3, 3))
rgb_transformed = m_1.dot(m_2).dot(m_3).dot(rgb)
print rgb_transformed
return np.arctan2(rgb_transformed, rgb_transformed)
x = np.ones(6)
do_complicated_math(x, x, x)
r = np.array([0.2,0.3,0.1])
g = np.array([1.0,1.0,0.2])
b = np.array([0.3,0.3,0.3])
do_complicated_math(r, g, b)
這僅適用於作為輸入的數組,但是添加對單個值作為輸入的處理應該很簡單。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.