简体   繁体   English

使用 Sympy 生成 C 代码。 将 Pow(x,2) 替换为 x*x

[英]Generate C code with Sympy. Replace Pow(x,2) by x*x

I am generating C code with sympy the using the Common Subexpression Elimination (CSE) routine and the ccode printer.我正在使用通用子表达式消除 (CSE) 例程和 ccode 打印机生成 C 代码。

However, I would like to have powered expressions as (x*x) instead of pow(x,2).但是,我希望将动力表达式作为 (x*x) 而不是 pow(x,2)。

Anyway to do this?无论如何要这样做?

Example:例子:

import sympy as sp
a= sp.MatrixSymbol('a',3,3)
b=sp.Matrix(a)*sp.Matrix(a)

res = sp.cse(b)

lines = []
     
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1], tmp[0]))

for i,result in enumerate(res[1]):
    lines.append(sp.ccode(result,"result_%i"%i))

Will output:请问output:

x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;

Best Regards此致

You can subclass the code printer , and only change the one function you want different.您可以对代码打印机进行子类化,并且只更改您想要不同的一个 function。 You'd need to investigate the original sympy code to find the correct function names and default implementation, so you can make sure you don't make errors.您需要调查原始 sympy 代码以找到正确的 function 名称和默认实现,因此您可以确保不会出错。 With a bit of care, the needed brackets can be generated automatically exact when and where they are needed.稍加注意,可以在需要的时间和地点自动生成所需的括号。

Here is a minimal example:这是一个最小的例子:

import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp == 2:
            return '({0} * {0})'.format(self.parenthesize(expr.base, PREC))
        else:
            return super()._print_Pow(expr)

default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint

expressions = [x, (2 + x) ** 2, x ** 3, x ** 15, sp.sqrt(5), sp.sqrt(x)**4, 1 / x, 1 / (x * x)]
print("Default: {}".format(default_printer(expressions)))
print("Custom: {}".format(custom_printer(expressions)))

Output: Output:

Default: [x, pow(x + 2, 2), pow(x, 3), pow(x, 15), sqrt(5), pow(x, 2), 1.0/x, pow(x, -2)]
Custom: [x, ((x + 2) * (x + 2)), pow(x, 3), pow(x, 15), sqrt(5), (x * x), 1.0/x, pow(x, -2)]

PS: To support a wider range of exponents, you could use eg PS:为了支持更广泛的指数,您可以使用例如

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp in range(2, 7):
            return '*'.join([self.parenthesize(expr.base, PREC)] * int(expr.exp))
        elif expr.exp in range(-6, 0):
            return '1.0/(' + ('*'.join([self.parenthesize(expr.base, PREC)] * int(-expr.exp))) + ')'
        else:
            return super()._print_Pow(expr)

There is a function called create_expand_pow_optimization that creates a wrapper to optimise your expressions in this respect.有一个名为create_expand_pow_optimization的 function 创建了一个包装器来优化您在这方面的表达式。 It takes as an argument the highest power it will replace by explicit multiplications.它将被显式乘法取代的最高幂作为参数。

The wrapper returns an UnevaluatedExpr that is protected against automatic simplifications that would revert this change.包装器返回一个UnevaluatedExpr ,它受到保护,不会自动简化会恢复此更改。

import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

expand_opt = create_expand_pow_optimization(3)

a = sp.Matrix(sp.MatrixSymbol('a',3,3))
res = sp.cse(a@a)

for i,result in enumerate(res[1]):
    print(sp.ccode(expand_opt(result),"result_%i"%i))

Finally, be aware that for sufficiently high optimisation levels, your compiler will take care of this (and is probably better at this).最后,请注意,对于足够高的优化级别,您的编译器会处理这个问题(并且可能在这方面做得更好)。

I think I will go with the user_function approach:我想我会用 user_function 方法 go :

As suggested in the comment above I will be using the user_functions functionality of sp.ccode: Assuming we have a number like a^3正如上面评论中所建议的,我将使用 sp.ccode 的user_functions功能:假设我们有一个像a^3这样的数字

sp.ccode(a**3, user_functions={'Pow': [(lambda x, y: y.is_integer, lambda x, y: '*'.join(['('+x+')']*int(y))),(lambda x, y: not y.is_integer, 'pow')]})

should output: '(a)*(a)*(a)'应该 output: '(a)*(a)*(a)'

In the future, I will try to improve the function to only put parenthesis when needed.将来,我会尝试改进 function 以仅在需要时放置括号。

Any improvements are welcome!欢迎任何改进!

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM