[英]How to use np.vectorize?
我有這個 function 來矢量化:
if x >= y, then x*y
else x/y
我的代碼是:
def vector_function(x, y):
if y >= x:
return x*y
else:
return x/y
vfunc = np.vectorize(vector_function)
return vfunc
raise NotImplementedError
但我收到錯誤:
'>=' not supported between instances of 'int' and 'list'
有人可以幫忙嗎?
問題是 function內部的vectorize
調用。
import numpy as np
# first define the function
def vector_function(x, y):
if y >= x:
return x * y
else:
return x / y
# vectorize it
vfunc = np.vectorize(vector_function)
# validation
print(vfunc([1, 2, 3, 4], 2.5)) # [2.5 5. 1.2 1.6]
但是請注意,從numpy.vectorize 開始: vectorize
function 主要是為了方便,而不是為了性能。 該實現本質上是一個 for 循環。
純“矢量化”版本是:
def foo(x,y):
return np.where(y>=x, x*y, x/y)
In [317]: foo(np.array([1,2,3,4]), 2.5)
Out[317]: array([2.5, 5. , 1.2, 1.6])
根據 arrays 的大小,這比Stefans 的回答快 2 到 10 倍
我選擇了這種where
方法,因為它是用y
廣播x
的最簡單和最緊湊的方法。 它可能不是最快的,具體取決於/
和*
的“成本”。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.