
NaN type in python after sympy substitution

I am trying to substitute values for the variables in a SymPy expression, but the result has type NaN after the substitution and I don't understand why.

Here is the code:

import sympy as sp
import copy
import numpy as np
import itertools as it
import matplotlib.pyplot as plt

alpha_set_values = np.linspace(0, 5, 10000)
beta_set_values = np.linspace(0, 6, 10000)

def plot_expr(exprVal, points):
  for point in points:
    # each point is a (beta, alpha) pair
    value = exprVal.subs( [ (beta,point[0]), (alpha, point[1]) ] )
    print(type(value))
    if value > 0:
      plt.scatter([point[0]], [point[1]], color = 'r')
    else:
      plt.scatter([point[0]], [point[1]], color = 'b')

  plt.show()

plot_expr(expr1, points)

expr1 is a SymPy expression in the symbols alpha and beta: α*(1 - 0.1/β) + α - 0.3*α/β + 2 - 1.9*(α*β - 0.1*α - β)/β. After the substitution, the type of value is NaN.

The full code is in a Google Colab notebook. The last two cells are the important ones and must be run; the error occurs in the last cell.

You are getting NaN because your first point is (0, 0). Take a look at your expression, expr1:

α*(1 - 0.1/β) + α - 0.3*α/β + 2 - 1.9*(α*β - 0.1*α - β)/β

In particular, the terms:

  • α*(1 - 0.1/β): after the substitution, SymPy evaluates 0 * (1 - zoo), where zoo is SymPy's complex infinity (the result of 0.1/0), and that product is NaN.
  • Similarly, α/β becomes 0/0, which can also result in NaN.
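The two points above can be checked directly with a minimal sketch. It assumes alpha and beta were created with sp.symbols (the question does not show their definition) and uses the expression exactly as written above:

```python
import sympy as sp

# Assumption: the symbols are plain SymPy symbols; the question does not
# show how they were defined.
alpha, beta = sp.symbols("alpha beta")

expr1 = (alpha*(1 - 0.1/beta) + alpha - 0.3*alpha/beta + 2
         - 1.9*(alpha*beta - 0.1*alpha - beta)/beta)

# Dividing by zero symbolically gives zoo (complex infinity) ...
inv_at_zero = (1/beta).subs(beta, 0)

# ... and zero times zoo is NaN.
product = sp.Integer(0) * sp.zoo

# Substituting the first grid point (0, 0) into the full expression.
value = expr1.subs([(beta, 0), (alpha, 0)])

print(inv_at_zero, product, value)
```

Any point with beta different from zero, for example expr1.subs([(beta, 1), (alpha, 1)]), evaluates to an ordinary number, so only the grid points on the beta = 0 line are affected.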

I see what you are trying to do: the best way to achieve your goal is to use plot_implicit, which shades the region where the inequality holds:

sp.plot_implicit(expr1 > 0, (beta, 0, 6), (alpha, 0, 5))

Alternatively, if you'd like to keep your approach, start both ranges at a small positive value instead of zero. Also, to speed up computation and plotting, use sp.lambdify to convert the symbolic expression to a numerical function:

n = 100
alpha_set_values = np.linspace(1e-06, 5, n)
beta_set_values = np.linspace(1e-06, 6, n)
alpha_set_values, beta_set_values = np.meshgrid(alpha_set_values, beta_set_values)

f = sp.lambdify([alpha, beta], expr1)
res = f(alpha_set_values, beta_set_values)

alpha_set_values = alpha_set_values.flatten()
beta_set_values = beta_set_values.flatten()
res = res.flatten()
idx = res > 0

plt.figure()
plt.scatter(beta_set_values[idx], alpha_set_values[idx], color="r")
plt.scatter(beta_set_values[~idx], alpha_set_values[~idx], color="b")
plt.show()
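If you would rather keep zero in the grids, another option is to evaluate everywhere and then mask the points where the result is not finite. The following is only a sketch of that idea, not part of the approach above; the names a_vals, b_vals, valid, positive and negative are made up for illustration, and the symbols are again assumed to come from sp.symbols:

```python
import numpy as np
import sympy as sp

# Assumption: plain SymPy symbols, as in the question.
alpha, beta = sp.symbols("alpha beta")
expr1 = (alpha*(1 - 0.1/beta) + alpha - 0.3*alpha/beta + 2
         - 1.9*(alpha*beta - 0.1*alpha - beta)/beta)

f = sp.lambdify([alpha, beta], expr1, "numpy")

n = 50
# Grids that still include alpha = 0 and beta = 0.
a_vals, b_vals = np.meshgrid(np.linspace(0, 5, n), np.linspace(0, 6, n))

# Silence the divide-by-zero and invalid-value warnings raised at beta = 0.
with np.errstate(divide="ignore", invalid="ignore"):
    res = f(a_vals, b_vals)
    valid = np.isfinite(res)           # False where res is inf or nan
    positive = valid & (res > 0)
    negative = valid & ~(res > 0)

print(valid.sum(), "of", res.size, "grid points are usable")
```

The boolean masks positive and negative can then be used to index a_vals and b_vals for the two plt.scatter calls, in the same way as idx and ~idx above.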
