简体   繁体   English

卡尔达诺的公式不适用于 numpy?

[英]Cardano's formula not working with numpy?

--- using python 3 --- --- 使用 python 3 ---

Following the equations here , I tried to find all real roots of an arbitrary third-order-polynomial.按照这里的方程,我试图找到任意三阶多项式的所有实根。 Unfortunatelly, my implementation does not yield the correct result and I cannot find the error.不幸的是,我的实现没有产生正确的结果,我找不到错误。 Maybe you are able to spot it within a blink of an eye and tell me.也许你能在眨眼间发现它并告诉我。

(As you notice, only the roots of the green curve are wrong.) (如您所见,只有绿色曲线的根是错误的。)

With best regards最诚挚的问候

import numpy as np
def find_cubic_roots(a,b,c,d):
    # with ax³ + bx² + cx + d = 0
    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    S = (R + np.sqrt(D))**(1/3)
    T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)

    return result

The example where you see it does not work:您看到的示例不起作用:

import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])

x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    roots = find_cubic_roots(a,b,c,d)
    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d), color = colors[i])
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()

简单示例

Output:输出:

[[ 2.50852567+0.j        -0.25426283+1.1004545j -0.25426283-1.1004545j]]
[[ 2.+0.j  0.+0.j  0.-0.j]]
[[ 1.51400399+1.46763129j  1.02750817-1.1867528j  -0.54151216-0.28087849j]]

Here is my stab at the solution.这是我对解决方案的尝试。 Your code fails for the case where R + np.sqrt(D) or R - np.sqrt(D) is negative.R + np.sqrt(D)R - np.sqrt(D)为负的情况下,您的代码失败。 The reason is in this post .原因在这个帖子里 Basically if you do a**(1/3) where a is negative, numpy returns a complex number.基本上如果你做a**(1/3)其中a是负数,numpy 会返回一个复数。 However, we infact, want S and T to be real since cube root of a negative real number is simply a negative real number (let's ignore De Moivre's theorem for now and focus on the code and not the math).然而,我们实际上希望ST为实数,因为负实数的立方根只是一个负实数(让我们暂时忽略De Moivre定理,专注于代码而不是数学)。 The way to work around it is to check if S is real, cast it to real and pass S to the function from scipy.special import cbrt .解决它的方法是检查S是否为实数,将其转换为实数并将S传递给函数from scipy.special import cbrt Similarly for T .同样对于T Example code:示例代码:

import numpy as np
import pdb
import math
from scipy.special import cbrt
def find_cubic_roots(a,b,c,d, bp = False):

    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    S = 0 #NEW CALCULATION FOR S STARTS HERE
    if np.isreal(R + np.sqrt(D)):
        S = cbrt(np.real(R + np.sqrt(D)))
    else:
        S = (R + np.sqrt(D))**(1/3)
    T = 0 #NEW CALCULATION FOR T STARTS HERE
    if np.isreal(R - np.sqrt(D)):
        T = cbrt(np.real(R - np.sqrt(D)))
    else:
        T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)
    #if bp:
        #pdb.set_trace()
    return result


import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])
x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    if d == 8:
        roots = find_cubic_roots(a,b,c,d, True)
    else:
        roots = find_cubic_roots(a,b,c,d)

    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d))
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()

DISCLAIMER: The output root gives some warning, which you can probably ignore.免责声明:输出 root 给出了一些警告,您可能可以忽略。 The output is correct.输出是正确的。 However, the plotting shows an extra root for some reasons.但是,由于某些原因,绘图显示了一个额外的根。 This is likely due to your plotting code.这可能是由于您的绘图代码。 The printed roots look fine though.不过印刷的根部看起来不错。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM