簡體   English   中英

當函數包含條件時,使用Numpy將函數應用於數組

[英]Applying a function to an array using Numpy when the function contains a condition

當函數包含條件時,我很難將函數應用於數組。 我的解決方法效率低下,正在尋找一種高效(快速)的方法。 在一個簡單的示例中:

pts = np.linspace(0,1,11)
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

現在,如果我運行:

result = fun(pts, pts)

然后我得到錯誤

ValueError:具有多個元素的數組的真值不明確。 使用a.any()或a.all()

if x > y行處引發。 我的效率不高的解決方法是:給出正確的結果但太慢了:

result = np.full([len(pts)]*2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        result[i,j] = fun(pts[i], pts[j])

以更好(更重要的是,更快)的方式獲得此效果的最佳方法是什么?

當函數包含條件時,我很難將函數應用於數組。 我的解決方法效率低下,正在尋找一種高效(快速)的方法。 在一個簡單的示例中:

pts = np.linspace(0,1,11)
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

現在,如果我運行:

result = fun(pts, pts)

然后我得到了錯誤

ValueError:具有多個元素的數組的真值不明確。 使用a.any()或a.all()

if x > y行處引發。 我的效率不高的解決方法是:給出正確的結果但太慢了:

result = np.full([len(pts)]*2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        result[i,j] = fun(pts[i], pts[j])

以更好(更重要的是,更快)的方式獲得此效果的最佳方法是什么?

編輯 :使用

def fun(x, y):
    if x > y:
        return 0
    else:
        return 1
x = np.array(range(10))
y = np.array(range(10))
xv,yv = np.meshgrid(x,y)
result = fun(xv, yv)  

仍然引發相同的ValueError

該錯誤非常明顯-假設您有

x = np.array([1,2])
y = np.array([2,1])

這樣

(x>y) == np.array([0,1])

if np.array([0,1])語句的結果應該是什么? 是真的還是假的? numpy告訴您這是模棱兩可的。 運用

(x>y).all()

要么

(x>y).any()

是明確的,因此numpy可以為您提供解決方案-任何一個單元對都滿足條件,或者全部滿足-兩者都是明確的真實值。 您必須自己定義向量x大於向量y的含義。

用於對所有xy對進行操作的numpy解決方案,以使x[i]>y[j]使用網格網格生成所有對:

>>> import numpy as np
>>> x=np.array(range(10))
>>> y=np.array(range(10))
>>> xv,yv=np.meshgrid(x,y)
>>> xv[xv>yv]
array([1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8,
       9, 4, 5, 6, 7, 8, 9, 5, 6, 7, 8, 9, 6, 7, 8, 9, 7, 8, 9, 8, 9, 9])
>>> yv[xv>yv]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
       2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8])

發送xvyvfun ,或者在函數中創建網格yv ,這取決於更合理的選擇。 這將生成所有對xi,yj ,使得xi>yj 如果您想要實際的索引,只需返回xv>yv ,其中每個單元格ij對應於x[i]y[j] 在您的情況下:

def fun(x, y):
    xv,yv=np.meshgrid(x,y)
    return xv>yv

如果x[i]>y[j]則返回fun(x,y)[i][j]為True或否則為False的x[i]>y[j] 另外

return  np.where(xv>yv)

將返回兩個成對的索引對的數組的元組,這樣

for i,j in fun(x,y):

也會保證x[i]>y[j]

In [253]: x = np.random.randint(0,10,5)
In [254]: y = np.random.randint(0,10,5)
In [255]: x
Out[255]: array([3, 2, 2, 2, 5])
In [256]: y
Out[256]: array([2, 6, 7, 6, 5])
In [257]: x>y
Out[257]: array([ True, False, False, False, False])
In [258]: np.where(x>y,0,1)
Out[258]: array([0, 1, 1, 1, 1])

為了與這兩個1d數組進行笛卡爾比較,請重塑一個數組,以便可以使用broadcasting

In [259]: x[:,None]>y
Out[259]: 
array([[ True, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False]])
In [260]: np.where(x[:,None]>y,0,1)
Out[260]: 
array([[0, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [0, 1, 1, 1, 1]])

您的函數(帶有if僅適用於標量輸入。 如果給定數組,則a>b會生成一個布爾數組,該布爾數組不能在if語句中使用。 您的迭代有效,因為它傳遞了標量值。 對於某些最好的復雜函數,您可以做到( np.vectorize可以使迭代更簡單,但不能更快)。

我的答案是看一下數組比較,然后從中得出答案。 在這種情況下,3參數where不布爾陣列映射到所期望的1/0的一個很好的工作。 還有其他方法可以執行此映射。

您的雙循環需要添加一層編碼,即廣播的None

對於更復雜的示例,或者如果要處理的數組更大,或者可以寫入已經預先分配的數組,可以考慮使用Numba

import numba as nb
import numpy as np

@nb.njit()
def fun(x, y):
  if x > y:
    return 0
  else:
    return 1

@nb.njit(parallel=False)
#@nb.njit(parallel=True)
def loop(x,y):
  result=np.empty((x.shape[0],y.shape[0]),dtype=np.int32)
  for i in nb.prange(x.shape[0]):
    for j in range(y.shape[0]):
      result[i,j] = fun(x[i], y[j])
  return result

@nb.njit(parallel=False)
def loop_preallocated(x,y,result):
  for i in nb.prange(x.shape[0]):
    for j in range(y.shape[0]):
      result[i,j] = fun(x[i], y[j])
  return result

計時

x = np.array(range(1000))
y = np.array(range(1000))

#Compilation overhead of the first call is neglected

res=np.where(x[:,None]>y,0,1) -> 2.46ms
loop(single_threaded)         -> 1.23ms
loop(parallel)                -> 1.0ms
loop(single_threaded)*        -> 0.27ms
loop(parallel)*               -> 0.058ms

*可能受緩存影響。 測試您自己的示例。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM