[英]Turn numbers to parameters in sympy
我想將 sympy 表達式中的數字轉換為參數。 我有這個代碼:
import numpy as np
import torch
import sympy
from sympy import *
from sympy.abc import x,y
from sympy.parsing.sympy_parser import parse_expr
from sympy import Symbol, lambdify, N
def sympy_param(math_expr):
param_dict = {}
unsnapped_param_dict = {'p':1}
def unsnap_recur(expr, param_dict, unsnapped_param_dict):
"""Recursively transform each numerical value into a learnable parameter."""
import sympy
from sympy import Symbol
if isinstance(expr, sympy.numbers.Float) or isinstance(expr, sympy.numbers.Integer) or isinstance(expr, sympy.numbers.Rational) or isinstance(expr, sympy.numbers.Pi):
used_param_names = list(param_dict.keys()) + list(unsnapped_param_dict)
unsnapped_param_name = get_next_available_key(used_param_names, "p", is_underscore=False)
unsnapped_param_dict[unsnapped_param_name] = float(expr)
unsnapped_expr = Symbol(unsnapped_param_name)
return unsnapped_expr
elif isinstance(expr, sympy.symbol.Symbol):
return expr
else:
unsnapped_sub_expr_list = []
for sub_expr in expr.args:
unsnapped_sub_expr = unsnap_recur(sub_expr, param_dict, unsnapped_param_dict)
unsnapped_sub_expr_list.append(unsnapped_sub_expr)
return expr.func(*unsnapped_sub_expr_list)
def get_next_available_key(iterable, key, midfix="", suffix="", is_underscore=True):
"""Get the next available key that does not collide with the keys in the dictionary."""
if key + suffix not in iterable:
return key + suffix
else:
i = 0
underscore = "_" if is_underscore else ""
while "{}{}{}{}{}".format(key, underscore, midfix, i, suffix) in iterable:
i += 1
new_key = "{}{}{}{}{}".format(key, underscore, midfix, i, suffix)
return new_key
eq = parse_expr(math_expr)
eq = unsnap_recur(eq,param_dict,unsnapped_param_dict)
return eq
它適用於大多數情況。 例如,如果我運行:
math_expr = "3.1+exp(1.1*x0)-0.3*log(x1**7)"
print(sympy_param(math_expr))
我得到 output:
p0 + p1*log(x1**p2) + exp(p3*x0)
這就是我需要的。 但是,當我嘗試:
math_expr = "-5.4-1.6/x0"
print(sympy_param(math_expr))
我明白了:
p0 + p1*x0**p2
從技術上講,這並沒有錯,但是這種力量有點弄亂了我的整體代碼(在此之后的下一步)。 有沒有辦法阻止 p2 出現在那里(這並不是我等式中出現的真正明確的數字),所以理想情況下我想得到:
p0 + p1/x0
有人可以幫我弄這個嗎? 謝謝!
問題是 SymPy 將1/x
表示為Pow(x, -1)
。 您也可能會遇到類似x - y
的問題,SymPy 表示為Add(x, Mul(-1, y))
。 避免此問題的最佳方法是專門檢查這些情況(例如isinstance(expr, Pow) and expr.exp == -1
)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.