
Cardano's formula not working with numpy?

--- using python 3 ---

Following the equations here, I tried to find all real roots of an arbitrary third-order polynomial. Unfortunately, my implementation does not give the correct result and I cannot find the error. Maybe you can spot it at a glance and tell me.

(As you can see in the plot, only the roots of the green curve are wrong.)

With best regards

```python
import numpy as np
def find_cubic_roots(a,b,c,d):
    # with ax³ + bx² + cx + d = 0
    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    S = (R + np.sqrt(D))**(1/3)
    T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)

    return result
```
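For reference, these are the formulas the function implements, read directly off the code (the linked page may use different notation):

```latex
\begin{aligned}
Q &= \frac{3ac - b^2}{9a^2}, \qquad
R = \frac{9abc - 27a^2 d - 2b^3}{54a^3}, \qquad
D = Q^3 + R^2,\\
S &= \sqrt[3]{R + \sqrt{D}}, \qquad T = \sqrt[3]{R - \sqrt{D}},\\
x_1 &= -\frac{b}{3a} + (S + T),\\
x_{2,3} &= -\frac{b}{3a} - \frac{S + T}{2} \pm \frac{i\sqrt{3}}{2}\,(S - T).
\end{aligned}
```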

Here is an example showing that it does not work:

```python
import numpy as np
import matplotlib.pyplot as plt

# `colors` is not defined in the original post; this list is an assumption
# chosen so that the third curve (d = 8) is drawn in green.
colors = ["blue", "orange", "green"]

fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])

x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    roots = find_cubic_roots(a,b,c,d)
    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d), color = colors[i])
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()
```

(Figure: simple example plot)

Output:

```
[[ 2.50852567+0.j        -0.25426283+1.1004545j -0.25426283-1.1004545j]]
[[ 2.+0.j  0.+0.j  0.-0.j]]
[[ 1.51400399+1.46763129j  1.02750817-1.1867528j  -0.54151216-0.28087849j]]
```
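To see which printed values are wrong, numpy's np.roots (which computes the eigenvalues of the companion matrix instead of using Cardano's formula) gives reference roots for the same three polynomials. A quick comparison sketch:

```python
import numpy as np

# Reference roots for 2.5x^3 - 5x^2 + 0x + d with d in (-8, 0, 8),
# computed by np.roots (companion-matrix eigenvalues, not Cardano's formula).
for d in (-8, 0, 8):
    ref = np.roots([2.5, -5.0, 0.0, d])
    print(d, ref)
```

A cubic with real coefficients always has at least one real root, so the third printed line, in which every entry has a nonzero imaginary part, cannot be right.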

Here is my stab at a solution. Your code fails when R + np.sqrt(D) or R - np.sqrt(D) is a negative real number, which is what happens for the green curve (d = 8). The reason is explained in this post. Basically, since your inputs are converted to complex (the +0j), raising a negative base to the power 1/3 makes numpy return the principal complex cube root instead of the real one. However, we in fact want S and T to be real here, since the cube root of a negative real number is simply a negative real number (let's ignore De Moivre's theorem for now and focus on the code rather than the math). The workaround is to check whether R + np.sqrt(D) is real; if it is, cast it to a real number and pass it to cbrt (from scipy.special import cbrt), which returns the real cube root. Do the same for T. Example code:
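A small check with the green curve's coefficients (a = 2.5, b = -5, c = 0, d = 8) makes the failure concrete. This sketch uses np.cbrt, numpy's real cube root, which for real input behaves like scipy.special.cbrt:

```python
import numpy as np

# Coefficients of the green curve from the question (d = 8).
a, b, c, d = 2.5, -5.0, 0.0, 8.0
Q = (3*a*c - b**2) / (9*a**2)
R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54*a**3)
D = Q**3 + R**2            # positive here, so sqrt(D) is real

u = R + np.sqrt(D)         # both of these turn out to be negative
v = R - np.sqrt(D)

principal = (u + 0j) ** (1/3)   # principal complex cube root, as in the question's code
real_cbrt = np.cbrt(u)          # real cube root (negative for a negative base)
print(u, v, principal, real_cbrt)

# With real cube roots, Cardano's first root really is a root of the cubic.
x1 = -b / (3*a) + np.cbrt(u) + np.cbrt(v)
print(x1, a*x1**3 + b*x1**2 + c*x1 + d)
```

The principal cube root of a negative number lies at an angle of 60 degrees in the complex plane, so S + T picks up an imaginary part that should not be there.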

```python
import numpy as np
import pdb
from scipy.special import cbrt

def find_cubic_roots(a,b,c,d, bp = False):
    # with ax³ + bx² + cx + d = 0; bp only toggles the commented-out breakpoint below
    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    # NEW CALCULATION FOR S STARTS HERE
    # If R + sqrt(D) is real (i.e. D >= 0 for real coefficients), use the real
    # cube root, which is negative for a negative base; otherwise keep the
    # principal complex cube root.
    # Note: `if` on np.isreal(...) only works when a single polynomial is passed.
    if np.isreal(R + np.sqrt(D)):
        S = cbrt(np.real(R + np.sqrt(D)))
    else:
        S = (R + np.sqrt(D))**(1/3)
    # NEW CALCULATION FOR T STARTS HERE (same idea as for S)
    if np.isreal(R - np.sqrt(D)):
        T = cbrt(np.real(R - np.sqrt(D)))
    else:
        T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)
    #if bp:
        #pdb.set_trace()
    return result
```
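The `if np.isreal(...)` checks only work when a single polynomial is passed: with several polynomials in one array, `if` on a multi-element boolean array raises a ValueError. Since the original function seems designed to accept arrays of coefficients, here is a vectorized sketch of the same fix. It assumes real coefficients with a != 0, uses np.cbrt instead of scipy.special.cbrt, and the name find_cubic_roots_vectorized is made up for this example:

```python
import numpy as np

def find_cubic_roots_vectorized(a, b, c, d):
    """Cardano's formula for arrays of real coefficients (sketch; assumes a != 0).

    Returns a complex array of shape broadcast(a, b, c, d).shape + (3,).
    """
    a, b, c, d = np.broadcast_arrays(*(np.asarray(v, dtype=float) for v in (a, b, c, d)))
    Q = (3*a*c - b**2) / (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54*a**3)
    D = Q**3 + R**2

    # D >= 0: sqrt(D) is real, so S and T must be real cube roots (np.cbrt).
    sqrt_pos = np.sqrt(np.maximum(D, 0.0))
    S_real, T_real = np.cbrt(R + sqrt_pos), np.cbrt(R - sqrt_pos)

    # D < 0: sqrt(D) is imaginary, S and T are complex conjugates and the
    # principal complex cube roots are the right choice.
    sqrt_c = np.sqrt(D.astype(complex))
    S_cplx, T_cplx = (R + sqrt_c) ** (1/3), (R - sqrt_c) ** (1/3)

    # Choose the branch element-wise instead of with a Python `if`.
    S = np.where(D >= 0, S_real, S_cplx)
    T = np.where(D >= 0, T_real, T_cplx)

    shift = -b / (3*a)
    x1 = shift + (S + T)
    x2 = shift - (S + T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    x3 = shift - (S + T) / 2 - 0.5j * np.sqrt(3) * (S - T)
    return np.stack([x1, x2, x3], axis=-1)
```

Both branches are evaluated for every element; np.where only selects which result to keep, so the whole batch is handled without a Python loop.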


```python
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])
x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    if d == 8:
        roots = find_cubic_roots(a,b,c,d, True)
    else:
        roots = find_cubic_roots(a,b,c,d)

    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d))
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()
```

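DISCLAIMER: Running this prints some warnings, which you can probably ignore; the printed roots are correct. However, the plot shows an extra root. This is most likely caused by the plotting code rather than the root finding: ax.scatter(roots, roots*0, s = 80) receives complex roots, and their imaginary parts are discarded when they are converted to real plot coordinates (which is probably also where numpy's ComplexWarning comes from), so each complex root is drawn at its real part on the x-axis.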
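To plot only the real roots, the roots can be filtered by the size of their imaginary part before calling ax.scatter. A minimal sketch; the helper name real_roots and the default tolerance are assumptions, not part of either post:

```python
import numpy as np

def real_roots(roots, tol=1e-9):
    """Return the (numerically) real entries of `roots` as a flat float array.

    `tol` is an assumed absolute tolerance on the imaginary part; pick one
    that suits the scale of your coefficients.
    """
    roots = np.ravel(roots)
    return roots[np.abs(roots.imag) < tol].real

# Example with the first printed line of output from the question.
r = real_roots(np.array([[2.50852567+0j, -0.25426283+1.1004545j, -0.25426283-1.1004545j]]))
print(r)
```

In the plotting loop, ax.scatter(roots, roots*0, s = 80) would then become rr = real_roots(roots) followed by ax.scatter(rr, np.zeros_like(rr), s = 80).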
