Turn numbers to parameters in sympy

I want to turn the numbers in a sympy expression into parameters. I have this code:

import sympy
from sympy import Symbol
from sympy.parsing.sympy_parser import parse_expr

def sympy_param(math_expr):
    param_dict = {}
    unsnapped_param_dict = {'p':1}

    def unsnap_recur(expr, param_dict, unsnapped_param_dict):
        """Recursively transform each numerical value into a learnable parameter."""
        import sympy
        from sympy import Symbol
        if isinstance(expr, sympy.numbers.Float) or isinstance(expr, sympy.numbers.Integer) or isinstance(expr, sympy.numbers.Rational) or isinstance(expr, sympy.numbers.Pi):
            used_param_names = list(param_dict.keys()) + list(unsnapped_param_dict)
            unsnapped_param_name = get_next_available_key(used_param_names, "p", is_underscore=False)
            unsnapped_param_dict[unsnapped_param_name] = float(expr)
            unsnapped_expr = Symbol(unsnapped_param_name)
            return unsnapped_expr
        elif isinstance(expr, sympy.symbol.Symbol):
            return expr
        else:
            unsnapped_sub_expr_list = []
            for sub_expr in expr.args:
                unsnapped_sub_expr = unsnap_recur(sub_expr, param_dict, unsnapped_param_dict)
                unsnapped_sub_expr_list.append(unsnapped_sub_expr)
            return expr.func(*unsnapped_sub_expr_list)

    def get_next_available_key(iterable, key, midfix="", suffix="", is_underscore=True):
        """Get the next available key that does not collide with the keys in the dictionary."""
        if key + suffix not in iterable:
            return key + suffix
        else:
            i = 0
            underscore = "_" if is_underscore else ""
            while "{}{}{}{}{}".format(key, underscore, midfix, i, suffix) in iterable:
                i += 1
            new_key = "{}{}{}{}{}".format(key, underscore, midfix, i, suffix)
            return new_key

    eq = parse_expr(math_expr)
    eq = unsnap_recur(eq,param_dict,unsnapped_param_dict)
    return eq

It works well in most cases. For example, if I run:

math_expr = "3.1+exp(1.1*x0)-0.3*log(x1**7)"
print(sympy_param(math_expr))

I get as output:

p0 + p1*log(x1**p2) + exp(p3*x0)

which is what I need. However, when I try:

math_expr = "-5.4-1.6/x0"
print(sympy_param(math_expr))

I get this:

p0 + p1*x0**p2

This is not technically wrong, but the parameter in the exponent breaks the next step of my code. Is there a way to keep p2 from appearing as the power? The -1 exponent is not a number that appears explicitly in my equation. Ideally I would like to get:

p0 + p1/x0

Can someone help me with this? Thank you!

The problem is that SymPy represents 1/x as Pow(x, -1), so your recursion sees the -1 as just another number and replaces it. You may run into a similar issue with x - y, which SymPy represents as Add(x, Mul(-1, y)). The best way to avoid this is to check for those cases explicitly in the recursion (for example, isinstance(expr, Pow) and expr.exp == -1) and leave the -1 in place instead of turning it into a parameter.
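
As a rough illustration of that check, here is a minimal sketch rather than a drop-in replacement for the code in the question. The function name parametrize and the simplified naming scheme (p0, p1, ... taken from the size of the params dict) are assumptions made for this example. It also assumes the input contains no symbols already named p0, p1, and so on.

```python
from sympy import Symbol, Pow, Mul, Float, Rational, pi
from sympy.parsing.sympy_parser import parse_expr


def parametrize(expr, params):
    """Replace explicit numbers with symbols, keeping structural -1s.

    `params` is filled in place with {name: float value}.
    Hypothetical helper: names are p0, p1, ... in traversal order.
    """
    if isinstance(expr, Symbol):
        return expr
    # Integer is a subclass of Rational, so this covers the same
    # number types as the original code (Float, Integer, Rational, Pi).
    if isinstance(expr, (Float, Rational)) or expr == pi:
        name = "p{}".format(len(params))
        params[name] = float(expr)
        return Symbol(name)
    # 1/x is stored as Pow(x, -1): recurse into the base, keep the -1.
    if isinstance(expr, Pow) and expr.exp == -1:
        return Pow(parametrize(expr.base, params), -1)
    # -y is stored as Mul(-1, y): keep the leading -1 as it is.
    if isinstance(expr, Mul) and expr.args[0] == -1:
        rest = [parametrize(arg, params) for arg in expr.args[1:]]
        return Mul(-1, *rest)
    # Any other node: rebuild it from its transformed arguments.
    return expr.func(*[parametrize(arg, params) for arg in expr.args])


params = {}
result = parametrize(parse_expr("-5.4-1.6/x0"), params)
print(result)
print(params)
```

One caveat: SymPy does not record whether a -1 was typed by the user or produced by the parser, so an explicit x**-1 or -1*x in the input is also left untouched. Other negative powers such as 1/x**2, stored as Pow(x, -2), are still turned into a parameter by this sketch.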
