I want to turn the numbers in a sympy expression into parameters. I have this code:
```python
import sympy
from sympy.parsing.sympy_parser import parse_expr


def sympy_param(math_expr):
    param_dict = {}
    unsnapped_param_dict = {'p': 1}

    def unsnap_recur(expr, param_dict, unsnapped_param_dict):
        """Recursively transform each numerical value into a learnable parameter."""
        # Float, Integer, Rational or the constant pi become a new symbol p0, p1, ...
        if isinstance(expr, (sympy.Float, sympy.Integer, sympy.Rational)) or expr is sympy.pi:
            used_param_names = list(param_dict.keys()) + list(unsnapped_param_dict)
            unsnapped_param_name = get_next_available_key(used_param_names, "p", is_underscore=False)
            unsnapped_param_dict[unsnapped_param_name] = float(expr)
            unsnapped_expr = sympy.Symbol(unsnapped_param_name)
            return unsnapped_expr
        elif isinstance(expr, sympy.Symbol):
            return expr
        else:
            # rebuild the node from its transformed children
            unsnapped_sub_expr_list = []
            for sub_expr in expr.args:
                unsnapped_sub_expr = unsnap_recur(sub_expr, param_dict, unsnapped_param_dict)
                unsnapped_sub_expr_list.append(unsnapped_sub_expr)
            return expr.func(*unsnapped_sub_expr_list)

    def get_next_available_key(iterable, key, midfix="", suffix="", is_underscore=True):
        """Get the next available key that does not collide with the keys in the dictionary."""
        if key + suffix not in iterable:
            return key + suffix
        else:
            i = 0
            underscore = "_" if is_underscore else ""
            while "{}{}{}{}{}".format(key, underscore, midfix, i, suffix) in iterable:
                i += 1
            new_key = "{}{}{}{}{}".format(key, underscore, midfix, i, suffix)
            return new_key

    eq = parse_expr(math_expr)
    eq = unsnap_recur(eq, param_dict, unsnapped_param_dict)
    return eq
```
It works well in most cases. For example, if I run:

```python
math_expr = "3.1+exp(1.1*x0)-0.3*log(x1**7)"
print(sympy_param(math_expr))
```

I get this output:

```
p0 + p1*log(x1**p2) + exp(p3*x0)
```
which is what I need. However, when I try:

```python
math_expr = "-5.4-1.6/x0"
print(sympy_param(math_expr))
```

I get this:

```
p0 + p1*x0**p2
```
It is not technically wrong, but that power messes up the next step of my code. Is there a way to prevent `p2` from appearing in the exponent (that exponent is not an explicit number in my equation)? Ideally I would like to get:

```
p0 + p1/x0
```

Can someone help me with this? Thank you!
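A quick way to see where `p2` comes from is to print SymPy's internal tree for the parsed expression. This small check is not part of the original post; it only uses `sympy.srepr` and the `atoms` method that every SymPy expression has:

```python
import sympy
from sympy.parsing.sympy_parser import parse_expr

expr = parse_expr("-5.4-1.6/x0")

# Full internal tree: the division appears as Pow(Symbol('x0'), Integer(-1))
print(sympy.srepr(expr))

# List every power node together with its exponent
for node in expr.atoms(sympy.Pow):
    print(node, "has exponent", node.exp)
```

The recursive function visits that `Integer(-1)` like any other number, which is why it gets replaced by a parameter.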
The problem is that SymPy represents `1/x` as `Pow(x, -1)`. You may also run into a similar issue with something like `x - y`, which SymPy represents as `Add(x, Mul(-1, y))`. The best way to avoid this issue is to check for those cases specifically (like `isinstance(expr, Pow) and expr.exp == -1`).
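A minimal sketch of that check, assuming only the `-1` exponent (from `1/x`) and the `-1` coefficient (from `x - y`) should be left untouched while every other number still becomes a parameter. The helper names `unsnap` and `new_param`, the `prefix` argument, and the `(expression, params)` return value are illustrative choices, not part of the question's code; the sketch also drops the `get_next_available_key` bookkeeping and instead skips any parameter name already used as a variable.

```python
import sympy
from sympy.parsing.sympy_parser import parse_expr


def sympy_param(math_expr, prefix="p"):
    """Replace numeric constants with symbols p0, p1, ...

    The -1 that SymPy inserts for 1/x (Pow(x, -1)) and for x - y
    (Mul(-1, y)) is kept as-is instead of becoming a parameter.
    Returns (new_expression, {parameter_name: initial_value}).
    """
    expr = parse_expr(math_expr)
    taken = {s.name for s in expr.free_symbols}  # never reuse a variable name
    params = {}

    def new_param(value):
        # pick the first unused name p0, p1, ...
        i = 0
        while f"{prefix}{i}" in taken:
            i += 1
        name = f"{prefix}{i}"
        taken.add(name)
        params[name] = float(value)
        return sympy.Symbol(name)

    def unsnap(e):
        # 1/x is stored as Pow(x, -1): keep the -1, recurse into the base only
        if isinstance(e, sympy.Pow) and e.exp is sympy.S.NegativeOne:
            return sympy.Pow(unsnap(e.base), -1)
        # x - y is stored as Add(x, Mul(-1, y)): keep the -1 coefficient
        if isinstance(e, sympy.Mul) and e.args[0] is sympy.S.NegativeOne:
            return -sympy.Mul(*[unsnap(a) for a in e.args[1:]])
        # any other explicit number (Integer is a subclass of Rational) or pi
        if isinstance(e, (sympy.Float, sympy.Rational)) or e is sympy.pi:
            return new_param(e)
        # symbols and other leaves are returned unchanged
        if not e.args:
            return e
        # rebuild the node from its transformed children
        return e.func(*[unsnap(a) for a in e.args])

    return unsnap(expr), params


eq, params = sympy_param("-5.4-1.6/x0")
print(eq)
print(params)
```

For `-5.4-1.6/x0` this should yield the `p0 + p1/x0` form from the question, and `x0 - x1` should come back unchanged with no parameters. The trade-off is that an exponent the user really wrote as `-1` (e.g. `x0**(-1)`) is also kept, while other implicit numbers, such as the `1/2` in `sqrt(x0)` (stored as `Pow(x0, 1/2)`), still become parameters.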