繁体   English   中英

卡尔达诺的公式不适用于 numpy?

[英]Cardano's formula not working with numpy?

--- 使用 python 3 ---

按照这里的方程,我试图找到任意三阶多项式的所有实根。 不幸的是,我的实现没有产生正确的结果,我找不到错误。 也许你能在眨眼间发现它并告诉我。

(如您所见,只有绿色曲线的根是错误的。)

最诚挚的问候

import numpy as np
def find_cubic_roots(a,b,c,d):
    # with ax³ + bx² + cx + d = 0
    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    S = (R + np.sqrt(D))**(1/3)
    T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)

    return result

您看到的示例不起作用:

import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])

x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    roots = find_cubic_roots(a,b,c,d)
    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d), color = colors[i])
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()

简单示例

输出:

[[ 2.50852567+0.j        -0.25426283+1.1004545j -0.25426283-1.1004545j]]
[[ 2.+0.j  0.+0.j  0.-0.j]]
[[ 1.51400399+1.46763129j  1.02750817-1.1867528j  -0.54151216-0.28087849j]]

这是我对解决方案的尝试。 R + np.sqrt(D)R - np.sqrt(D)为负的情况下,您的代码失败。 原因在这个帖子里 基本上如果你做a**(1/3)其中a是负数,numpy 会返回一个复数。 然而,我们实际上希望ST为实数,因为负实数的立方根只是一个负实数(让我们暂时忽略De Moivre定理,专注于代码而不是数学)。 解决它的方法是检查S是否为实数,将其转换为实数并将S传递给函数from scipy.special import cbrt 同样对于T 示例代码:

import numpy as np
import pdb
import math
from scipy.special import cbrt
def find_cubic_roots(a,b,c,d, bp = False):

    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    S = 0 #NEW CALCULATION FOR S STARTS HERE
    if np.isreal(R + np.sqrt(D)):
        S = cbrt(np.real(R + np.sqrt(D)))
    else:
        S = (R + np.sqrt(D))**(1/3)
    T = 0 #NEW CALCULATION FOR T STARTS HERE
    if np.isreal(R - np.sqrt(D)):
        T = cbrt(np.real(R - np.sqrt(D)))
    else:
        T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)
    #if bp:
        #pdb.set_trace()
    return result


import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])
x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    if d == 8:
        roots = find_cubic_roots(a,b,c,d, True)
    else:
        roots = find_cubic_roots(a,b,c,d)

    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d))
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()

免责声明:输出 root 给出了一些警告,您可能可以忽略。 输出是正确的。 但是,由于某些原因,绘图显示了一个额外的根。 这可能是由于您的绘图代码。 不过印刷的根部看起来不错。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM