简体   繁体   English

使用C ++中的表达式模板进行符号区分

[英]Symbolic differentiation using expression templates in C++

如何在C ++中使用表达式模板实现符号区分

In general you'd want a way to represent your symbols (ie the expressions templates that encode eg 3 * x * x + 42 ), and a meta-function that can compute a derivative. 通常,您需要一种表示符号的方法(即编码例如3 * x * x + 42的表达式模板),以及可以计算导数的元函数。 Hopefully you're familiar enough with metaprogramming in C++ to know what that means and entails but to give you an idea: 希望你对C ++中的元编程非常熟悉,知道这意味着什么,但要给你一个想法:

// This should come from the expression templates
template<typename Lhs, typename Rhs>
struct plus_node;

// Metafunction that computes a derivative
template<typename T>
struct derivative;

// derivative<foo>::type is the result of computing the derivative of foo

// Derivative of lhs + rhs
template<typename Lhs, typename Rhs>
struct derivative<plus_node<Lhs, Rhs> > {
    typedef plus_node<
        typename derivative<Lhs>::type
        , typename derivative<Rhs>::type
    > type;
};

// and so on

You'd then tie up the two parts (representation and computation) such that it would be convenient to use. 然后,您将绑定两个部分(表示和计算),以便使用方便。 Eg derivative(3 * x * x + 42)(6) could mean 'compute the derivative of 3 * x * x + 42 in x at 6'. 例如, derivative(3 * x * x + 42)(6)可以表示'计算6'处3 * x * x + 423 * x * x + 42的导数。

However even if you do know what it takes to write expression templates and what it takes to write a metaprogram in C++ I wouldn't recommend going about it this way. 但是,即使你知道编写表达式模板需要什么以及用C ++编写元程序需要什么,我也不建议这样做。 Template metaprogramming requires a lot of boilerplate and can be tedious. 模板元编程需要很多样板,并且可能很乏味。 Instead, I direct you to the genius Boost.Proto library, which is precisely designed to help write EDSLs (using expression templates) and operate on those expression templates. 相反,我将您引导到天才Boost.Proto库,它精确地设计用于帮助编写EDSL(使用表达式模板)并对这些表达式模板进行操作。 It it not necessarily easy to learn to use but I've found that learning how to achieve the same thing without using it is harder . 它不一定容易学会使用,但我发现学习如何在不使用它的情况下实现同样的事情更难 Here's a sample program that can in fact understand and compute derivative(3 * x * x + 42)(6) : 这是一个实际上可以理解和计算derivative(3 * x * x + 42)(6)的示例程序derivative(3 * x * x + 42)(6)

#include <iostream>

#include <boost/proto/proto.hpp>

using namespace boost::proto;

// Assuming derivative of one variable, the 'unknown'
struct unknown {};

// Boost.Proto calls this the expression wrapper
// elements of the EDSL will have this type
template<typename Expr>
struct expression;

// Boost.Proto calls this the domain
struct derived_domain
: domain<generator<expression>> {};

// We will use a context to evaluate expression templates
struct evaluation_context: callable_context<evaluation_context const> {
    double value;

    explicit evaluation_context(double value)
        : value(value)
    {}

    typedef double result_type;

    double operator()(tag::terminal, unknown) const
    { return value; }
};
// And now we can do:
// evalutation_context context(42);
// eval(expr, context);
// to evaluate an expression as though the unknown had value 42

template<typename Expr>
struct expression: extends<Expr, expression<Expr>, derived_domain> {
    typedef extends<Expr, expression<Expr>, derived_domain> base_type;

    expression(Expr const& expr = Expr())
        : base_type(expr)
    {}

    typedef double result_type;

    // We spare ourselves the need to write eval(expr, context)
    // Instead, expr(42) is available
    double operator()(double d) const
    {
        evaluation_context context(d);
        return eval(*this, context);
    }
};

// Boost.Proto calls this a transform -- we use this to operate
// on the expression templates
struct Derivative
: or_<
    when<
        terminal<unknown>
        , boost::mpl::int_<1>()
    >
    , when<
        terminal<_>
        , boost::mpl::int_<0>()
    >
    , when<
        plus<Derivative, Derivative>
        , _make_plus(Derivative(_left), Derivative(_right))
    >
    , when<
        multiplies<Derivative, Derivative>
        , _make_plus(
            _make_multiplies(Derivative(_left), _right)
            , _make_multiplies(_left, Derivative(_right))
        )
    >
    , otherwise<_>
> {};

// x is the unknown
expression<terminal<unknown>::type> const x;

// A transform works as a functor
Derivative const derivative;

int
main()
{
    double d = derivative(3 * x * x + 3)(6);
    std::cout << d << '\n';
}

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM