簡體   English   中英

如何使用 numpy 在 Python 中定義分段函數?

[英]How to define piecewise function in Python using numpy?

以下是我想在python中實現的功能。 定義函數時出現類型錯誤。 我嘗試使用numpy.piecewise函數對象進行定義,並僅使用elif命令作為定義。 我希望能夠在不同的點以及f(X-1)等表達式評估這個函數

這是我的代碼:

from numpy import piecewise 
from scipy import *
from sympy.abc import x
from sympy.utilities.lambdify import lambdify, implemented_function
from sympy import Function
from sympy import *
h = 0.5 
a = -1
n = 2
x = Symbol('x')
expr = piecewise((0, x-a <=  -2*h), ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<=-h), (2*h**3/3-0.5*(x-a)**2*(2*h+(x-a)), -h<= x-a<= 0), (2*(h**3/3)-0.5*(x-a)**2*(2*h+(x-a)), 0<=x-a<=2*h), ((1/6)*(2*h-(x-a))**3, h<=x-a<=2*h), (0, x-a<=2*h))
p = lambdify((x, a,b,h), expr)

def basis(x,a,b, h):
    if x <= a-2*h:
        return 0;
    elif (x<=a-h) or (x >=2*h):
        return (1/6)*(2*h+(x-a))**3
    elif  (x-a<= 0) or (x-a >= -h):
        return (2*h**3/3-0.5*(x-a)**2*(2*h+(x-a)));
    elif (x<=2*h+a) or (x >= 0):
        return  (2*(h**3/3)-0.5*(x-a)**2*(2*h+(x-a)));
    elif (x<=a+2*h) or (x >= h):
        return (1/6)*(2*h-(x-a))**3; 
    elif x-a<=2*h:
        return 0

basis(x, -1,0.5,0)

我得到的兩種方式:

raise TypeError("cannot determine truth value of Relational")

TypeError: cannot determine truth value of Relational

您可以使用 sympy 的 lamdify 函數來生成 numpy 分段函數。 這是一個更簡單的示例,但顯示了總體思路:

In [15]: from sympy import symbols, Piecewise                                                                                               

In [16]: x, a = symbols('x, a')                                                                                                   

In [17]: expr = Piecewise((x, x>a), (0, True))                                                                                    

In [18]: expr                                                                                                                     
Out[18]: 
⎧x  for a < x
⎨            
⎩0  otherwise

In [19]: from sympy import lambdify                                                                                               

In [20]: fun = lambdify((x, a), expr)                                                                                             

In [21]: fun([1, 3], [4, 2])                                                                                                      
Out[21]: array([0., 3.])

In [22]: import inspect                                                                                                           

In [23]: print(inspect.getsource(fun))                                                                                            
def _lambdifygenerated(x, a):
    return (select([less(a, x),True], [x,0], default=nan))

很抱歉這個答案的長度,但我認為您需要查看完整的調試過程。 我不得不查看回溯並測試您的代碼的一小部分以確定確切的問題。 我見過很多numpy歧義錯誤,但沒有看到這種sympy關系錯誤。

===

讓我們看看整個回溯,而不僅僅是其中的一行。 至少我們需要確定您的哪一行代碼產生了問題。

In [4]: expr = np.piecewise((0, x-a <=  -2*h), ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<
   ...: =-h), (2*h**3/3-0.5*(x-a)**2*(2*h+(x-a)), -h<= x-a<= 0), (2*(h**3/3)-0.5
   ...: *(x-a)**2*(2*h+(x-a)), 0<=x-a<=2*h), ((1/6)*(2*h-(x-a))**3, h<=x-a<=2*h)
   ...: , (0, x-a<=2*h))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-893bb4b36321> in <module>
----> 1 expr = np.piecewise((0, x-a <=  -2*h), ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<=-h), (2*h**3/3-0.5*(x-a)**2*(2*h+(x-a)), -h<= x-a<= 0), (2*(h**3/3)-0.5*(x-a)**2*(2*h+(x-a)), 0<=x-a<=2*h), ((1/6)*(2*h-(x-a))**3, h<=x-a<=2*h), (0, x-a<=2*h))

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

雖然np.piecewise是一個 numpy 函數,因為x是一個sympy.Symbol ,方程是 sympy 表達式。 numpysympy沒有很好地集成。 有些東西有效,許多其他的則無效。

你嘗試過小表情嗎? 好的編程習慣是從小塊開始,確保它們首先起作用。

讓我們嘗試更小的東西:

In [8]: expr = np.piecewise((0, x-a <=  -2*h),
   ...:  ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<=-h))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-37ff62e49efb> in <module>
      1 expr = np.piecewise((0, x-a <=  -2*h),
----> 2  ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<=-h))

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

和較小的部分:

In [10]: (0, x-a <=  -2*h)
Out[10]: (0, x + 1 ≤ -1.0)

In [11]: ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<=-h)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-11-7bd9f95d077d> in <module>
----> 1 ((1/6)*(2*h+(x-a))**3, -2*h<=x-a<=-h)

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

In [12]: (1/6)*(2*h+(x-a))**3
Out[12]: 
                            3
1.33333333333333⋅(0.5⋅x + 1) 

但:

In [13]: -2*h<=x-a<=-h
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-13-5ffb419cd443> in <module>
----> 1 -2*h<=x-a<=-h

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

進一步簡化:

In [14]: 0 < x < 3
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-14-59ba4ce00627> in <module>
----> 1 0 < x < 3

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

雖然a < b < c允許用於常規 Python 變量和標量,但它不適用於numpy數組,顯然也不適用於sympy變量。

所以眼前的問題與numpy無關。 您正在使用無效的sympy表達式!

===

您的basis函數揭示了同一問題的一個方面。 我們再次需要查看完整的回溯,然后測試部分以識別確切的問題表達。

In [16]: basis(x, -1,0.5,0)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-16-b328f95b3c79> in <module>
----> 1 basis(x, -1,0.5,0)

<ipython-input-15-c6436540e3f3> in basis(x, a, b, h)
      1 def basis(x,a,b, h):
----> 2     if x <= a-2*h:
      3         return 0;
      4     elif (x<=a-h) or (x >=2*h):
      5         return (1/6)*(2*h+(x-a))**3

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

這個表達式是一個sympy關系:

In [17]: x <= -1
Out[17]: x ≤ -1

但是我們不能在 Python 的if語句中使用這樣的關系。

In [18]: if x <= -1: pass
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-18-b56148a48367> in <module>
----> 1 if x <= -1: pass

/usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __nonzero__(self)
    382 
    383     def __nonzero__(self):
--> 384         raise TypeError("cannot determine truth value of Relational")
    385 
    386     __bool__ = __nonzero__

TypeError: cannot determine truth value of Relational

Python if是簡單的真/假切換; 它的論點必須評估為一個或另一個。 錯誤告訴我們sympy.Relational不起作用。 0 < x < 1是基本 Python if變體(它測試0<xx<1並執行 a and )。

我們經常在numpy (和pandas )中看到的一個變體是:

In [20]: 0 < np.array([0,1,2])
Out[20]: array([False,  True,  True])

In [21]: 0 < np.array([0,1,2])<1
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-bc1039cec1fc> in <module>
----> 1 0 < np.array([0,1,2])<1

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

numpy表達式有多個 True/False 值,不能在需要簡單 True/False 的 Python 表達式中使用。

編輯

正確展開兩側測試:

In [23]: expr = np.piecewise((0, x-a <=  -2*h),
    ...:  ((1/6)*(2*h+(x-a))**3, (-2*h<=x-a)&(x-a<=-h)),
    ...:  (2*h**3/3-0.5*(x-a)**2*(2*h+(x-a)), (-h<= x-a)&(x-a<= 0)),
    ...:  (2*(h**3/3)-0.5*(x-a)**2*(2*h+(x-a)), (0<=x-a)&(x-a<=2*h)),
    ...:  ((1/6)*(2*h-(x-a))**3, (h<=x-a)&(x-a<=2*h)), (0, x-a<=2*h))

In [24]: expr
Out[24]: 
array([-0.5*(x + 1)**2*(x + 2.0) + 0.0833333333333333,
       -0.5*(x + 1)**2*(x + 2.0) + 0.0833333333333333], dtype=object)

In [26]: p = lambdify((x,), expr)

xexpr唯一的符號。

查看生成的函數:

In [27]: print(p.__doc__)
Created with lambdify. Signature:

func(x)

Expression:

[-0.5*(x + 1)**2*(x + 2.0) + 0.0833333333333333  -0.5*(x + 1)**2*(x + 2.0)...

Source code:

def _lambdifygenerated(x):
    return ([-0.5*(x + 1)**2*(x + 2.0) + 0.0833333333333333, -0.5*(x + 1)**2*(x + 2.0) + 0.0833333333333333])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM