簡體   English   中英

將數字轉換為 sympy 中的參數

[英]Turn numbers to parameters in sympy

我想將 sympy 表達式中的數字轉換為參數。 我有這個代碼:

import numpy as np
import torch
import sympy

from sympy import *
from sympy.abc import x,y
from sympy.parsing.sympy_parser import parse_expr
from sympy import Symbol, lambdify, N

def sympy_param(math_expr):
    param_dict = {}
    unsnapped_param_dict = {'p':1}

    def unsnap_recur(expr, param_dict, unsnapped_param_dict):
        """Recursively transform each numerical value into a learnable parameter."""
        import sympy
        from sympy import Symbol
        if isinstance(expr, sympy.numbers.Float) or isinstance(expr, sympy.numbers.Integer) or isinstance(expr, sympy.numbers.Rational) or isinstance(expr, sympy.numbers.Pi):
            used_param_names = list(param_dict.keys()) + list(unsnapped_param_dict)
            unsnapped_param_name = get_next_available_key(used_param_names, "p", is_underscore=False)
            unsnapped_param_dict[unsnapped_param_name] = float(expr)
            unsnapped_expr = Symbol(unsnapped_param_name)
            return unsnapped_expr
        elif isinstance(expr, sympy.symbol.Symbol):
            return expr
        else:
            unsnapped_sub_expr_list = []
            for sub_expr in expr.args:
                unsnapped_sub_expr = unsnap_recur(sub_expr, param_dict, unsnapped_param_dict)
                unsnapped_sub_expr_list.append(unsnapped_sub_expr)
            return expr.func(*unsnapped_sub_expr_list)

    def get_next_available_key(iterable, key, midfix="", suffix="", is_underscore=True):
        """Get the next available key that does not collide with the keys in the dictionary."""
        if key + suffix not in iterable:
            return key + suffix
        else:
            i = 0
            underscore = "_" if is_underscore else ""
            while "{}{}{}{}{}".format(key, underscore, midfix, i, suffix) in iterable:
                i += 1
            new_key = "{}{}{}{}{}".format(key, underscore, midfix, i, suffix)
            return new_key

    eq = parse_expr(math_expr)
    eq = unsnap_recur(eq,param_dict,unsnapped_param_dict)
    return eq

它適用於大多數情況。 例如,如果我運行:

math_expr = "3.1+exp(1.1*x0)-0.3*log(x1**7)"
print(sympy_param(math_expr))

我得到 output:

p0 + p1*log(x1**p2) + exp(p3*x0)

這就是我需要的。 但是,當我嘗試:

math_expr = "-5.4-1.6/x0"
print(sympy_param(math_expr))

我明白了:

p0 + p1*x0**p2

從技術上講,這並沒有錯,但是這種力量有點弄亂了我的整體代碼(在此之后的下一步)。 有沒有辦法阻止 p2 出現在那里(這並不是我等式中出現的真正明確的數字),所以理想情況下我想得到:

p0 + p1/x0

有人可以幫我弄這個嗎? 謝謝!

問題是 SymPy 將1/x表示為Pow(x, -1) 您也可能會遇到類似x - y的問題,SymPy 表示為Add(x, Mul(-1, y)) 避免此問題的最佳方法是專門檢查這些情況(例如isinstance(expr, Pow) and expr.exp == -1 )。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM