
std::ratio power of a std::ratio at compile-time?

I have a challenging question from a mathematical, algorithmic and metaprogramming recursion point of view. Consider the following declaration:

template<class R1, class R2>
using ratio_power = /* to be defined */;

modeled on the std::ratio operations such as std::ratio_add. Given two std::ratio types R1 and R2, this operation should compute R1^R2 if and only if R1^R2 is a rational number. If it is irrational, the implementation should fail to compile, just as the compiler reports an integer overflow when one tries to multiply two very large ratios.

Three questions:

  1. Do you think this is possible without exploding the compilation time?
  2. What algorithm to use?
  3. How to implement this operation?

You need two building blocks for this calculation:

  • the n-th power of an integer at compile-time
  • the n-th root of an integer at compile-time

Note: I use int as the type for numerator and denominator to save some typing; I hope the main point comes across. I extracted the following code from a working implementation, but I cannot guarantee that I did not introduce a typo somewhere ;)

The first one is rather easy: use x^(2n) = x^n * x^n and x^(2n+1) = x^n * x^n * x. That way you instantiate the fewest templates. For example, x^39 is calculated like this:

x^39 = x^19 * x^19 * x
x^19 = x^9 * x^9 * x
x^9 = x^4 * x^4 * x
x^4 = x^2 * x^2
x^2 = x^1 * x^1
x^1 = x^0 * x^0 * x
x^0 = 1

template <int Base, int Exponent>
struct static_pow
{
  static const int temp = static_pow<Base, Exponent / 2>::value;
  static const int value = temp * temp * (Exponent % 2 == 1 ? Base : 1);
};

template <int Base>
struct static_pow<Base, 0>
{
  static const int value = 1;
};
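
For comparison, the same squaring scheme can be sketched as a constexpr function instead of a class template. This is only an illustration and assumes C++14; the name ipow is made up here, and like static_pow it does no overflow checking.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the answer's code): exponentiation by
// squaring as a C++14 constexpr function. No overflow checking is done.
constexpr std::intmax_t ipow(std::intmax_t base, unsigned exponent)
{
  std::intmax_t result = 1;
  while (exponent != 0) {
    if (exponent % 2 == 1) result *= base;  // odd exponent: one extra factor
    exponent /= 2;
    if (exponent != 0) base *= base;        // square only while still needed
  }
  return result;
}

// Compile-time checks, including the x^39 case from the text with x = 2.
static_assert(ipow(2, 10) == 1024, "2^10");
static_assert(ipow(3, 0) == 1, "x^0 is 1");
static_assert(ipow(2, 39) == 549755813888LL, "2^39");
```

Inside a constant expression a signed overflow is a compile error rather than a silently wrong value, which fits the "fail like std::ratio does" requirement of the question.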

The second one is a bit tricky and works with a bracketing algorithm: Given x and N we want to find a number r so that r^N = x

  • set the interval [low, high] that contains the solution to [1, 1 + x / N]
  • calculate the midpoint mean = (low + high) / 2
  • determine, if mean^N >= x
    • if yes, set the interval to [low, mean]
    • if not, set the interval to [mean+1, high]
  • if the interval contains only one number, the calculation is finished
  • otherwise, iterate again

With the update rule above, this algorithm gives the smallest integer s that fulfills s^N >= x.

So check whether s^N == x. If yes, the N-th root of x is integral, otherwise not.
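
The bracketing steps can be sketched as an ordinary C++14 constexpr loop before turning them into templates. The names naive_pow, bisect_root and is_perfect_power are made up for this illustration, and the comparison mean^N >= x is done with a plain power that is not protected against overflow, so this version is only suitable for small inputs (x >= 1, N >= 1).

```cpp
#include <cstdint>

// Hypothetical helper: plain base^n, no overflow protection.
constexpr std::intmax_t naive_pow(std::intmax_t base, unsigned n)
{
  std::intmax_t r = 1;
  for (unsigned i = 0; i < n; ++i) r *= base;
  return r;
}

// Bisection over [1, 1 + x / n]; returns the smallest s with s^n >= x.
// Assumes x >= 1 and n >= 1.
constexpr std::intmax_t bisect_root(std::intmax_t x, unsigned n)
{
  std::intmax_t low = 1;
  std::intmax_t high = 1 + x / static_cast<std::intmax_t>(n);
  while (low != high) {
    const std::intmax_t mean = (low + high) / 2;
    if (naive_pow(mean, n) >= x)
      high = mean;     // solution is in [low, mean]
    else
      low = mean + 1;  // solution is in [mean + 1, high]
  }
  return low;
}

// The n-th root of x is integral exactly when the candidate reproduces x.
constexpr bool is_perfect_power(std::intmax_t x, unsigned n)
{
  return naive_pow(bisect_root(x, n), n) == x;
}

static_assert(bisect_root(27, 3) == 3, "cube root of 27");
static_assert(!is_perfect_power(28, 3), "28 is not a cube");
```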

Now let's write that as a compile-time program:

basic interface:

template <int x, int N>
struct static_root : static_root_helper<x, N, 1, 1 + x / N> { };

helper:

template <int x, int N, int low, int high>
struct static_root_helper
{
  static const int mean = (low + high) / 2;
  static const bool is_left = calculate_left<mean, N, x>::value;
  static const int value = static_root_helper<x, N, (is_left ? low : mean + 1), (is_left ? mean : high)>::value;
};

endpoint of recursion where the interval consists of only one entry:

template <int x, int N, int mid>
struct static_root_helper<x, N, mid, mid>
{
  static const int value = mid;
};

helper to detect multiplication overflow (you can replace the Boost header with the C++11 constexpr std::numeric_limits, I think). value is true if the multiplication a * b would overflow; the check assumes a >= 0 and b > 0.

#include <type_traits>  // std::is_integral
#include "boost/integer_traits.hpp"

template <typename T, T a, T b>
struct mul_overflow
{
  static_assert(std::is_integral<T>::value, "T must be integral");
  static const bool value = (a > boost::integer_traits<T>::const_max / b);
};
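
A Boost-free variant along the lines suggested above might look as follows. This is a sketch, not the code from the working implementation: the name mul_overflow_std is made up, and the extra static_assert encodes the assumption that the multiplier is positive.

```cpp
#include <limits>
#include <type_traits>

// Hypothetical Boost-free variant of mul_overflow; assumes a >= 0 and b > 0.
template <typename T, T a, T b>
struct mul_overflow_std
{
  static_assert(std::is_integral<T>::value, "T must be integral");
  static_assert(b > 0, "this sketch assumes a positive multiplier");
  // a * b > max  <=>  a > max / b  (integer division, b > 0)
  static const bool value = (a > std::numeric_limits<T>::max() / b);
};

static_assert(!mul_overflow_std<int, 100, 100>::value, "100 * 100 fits");
static_assert(mul_overflow_std<int, std::numeric_limits<int>::max(), 2>::value,
              "max * 2 overflows");
```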

Now we need to implement calculate_left, which decides whether the root of x lies at or to the left of mean, i.e. whether mean^N >= x. We want to be able to calculate arbitrary roots, so a naive comparison of static_pow<mean, N>::value against x would overflow very quickly and give wrong results. Therefore we use the following scheme. We want to determine whether A^N >= B:

  • set x = A and i = 1 (x is the running product A^i)
  • if x >= B, we are already finished -> A^N will surely be at least B
  • will x * A overflow?
    • if yes -> A^N will surely be larger than B
    • if not -> x *= A and i += 1
  • if i == N, we are finished and can do a simple comparison with B
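
As a plain-function illustration of this scheme, here is a C++14 constexpr sketch. The name pow_at_least is made up, and the function assumes base >= 1 and n >= 1, which holds for the values the bisection produces.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper: true if base^n >= bound, computed without ever
// forming a product that overflows. Assumes base >= 1 and n >= 1.
constexpr bool pow_at_least(std::intmax_t base, unsigned n, std::intmax_t bound)
{
  std::intmax_t acc = base;  // running product base^i
  for (unsigned i = 1; i < n; ++i) {
    if (acc >= bound) return true;  // later factors are >= 1, cannot shrink
    if (acc > std::numeric_limits<std::intmax_t>::max() / base)
      return true;                  // next product would exceed every bound
    acc *= base;
  }
  return acc >= bound;              // i == n: simple comparison
}

static_assert(pow_at_least(3, 4, 81), "3^4 >= 81");
static_assert(!pow_at_least(3, 4, 82), "3^4 < 82");
```

The second early return is what makes large exponents safe: a product that does not fit into the integer type is necessarily larger than any representable bound.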

Now let's write this as a metaprogram:

template <int A, int N, int B>
struct calculate_left : calculate_left_helper<A, 1, A, N, B, (A >= B)> { };

template <int x, int i, int A, int N, int B, bool short_circuit>
struct calculate_left_helper
{
  static const bool overflow = mul_overflow<int, x, A>::value;
  static const int next = calculate_next<x, A, overflow>::value;
  static const bool value = calculate_left_helper<next, i + 1, A, N, B, (overflow || next >= B)>::value;
};

endpoint where i == N

template <int x, int A, int N, int B, bool short_circuit>
struct calculate_left_helper<x, N, A, N, B, short_circuit>
{
  static const bool value = (x >= B);
};

endpoints for short-circuit

template <int x, int i, int A, int N, int B>
struct calculate_left_helper<x, i, A, N, B, true>
{
  static const bool value = true;
};

template <int x, int A, int N, int B>
struct calculate_left_helper<x, N, A, N, B, true>
{
  static const bool value = true;
};

helper to calculate the next value of x * A; it takes overflow into account to eliminate compiler warnings:

template <int a, int b, bool overflow>
struct calculate_next
{
  static const int value = a * b;
};

template <int a, int b>
struct calculate_next<a, b, true>
{
  static const int value = 0; // any value will do here, calculation will short-circuit anyway
};

So, that should be it. We need an additional helper

template <int x, int N>
struct has_integral_root
{
  static const int root = static_root<x, N>::value;
  static const bool value = (static_pow<root, N>::value == x);
};

Now we can implement ratio_pow as follows:

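#include <cstdint>
#include <ratio>

template <typename, typename> struct ratio_pow;

// std::ratio's template parameters are std::intmax_t, so the partial
// specialization must use the same type or deduction fails
template <std::intmax_t N1, std::intmax_t D1, std::intmax_t N2, std::intmax_t D2>
struct ratio_pow<std::ratio<N1, D1>, std::ratio<N2, D2>>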
{
  // ensure that all roots are integral
  static_assert(has_integral_root<std::ratio<N1, D1>::num, std::ratio<N2, D2>::den>::value, "numerator has no integral root");
  static_assert(has_integral_root<std::ratio<N1, D1>::den, std::ratio<N2, D2>::den>::value, "denominator has no integral root");
  // calculate the "D2"-th root of (N1 / D1)
  static const int num1 = static_root<std::ratio<N1, D1>::num, std::ratio<N2, D2>::den>::value;
  static const int den1 = static_root<std::ratio<N1, D1>::den, std::ratio<N2, D2>::den>::value;
  // exchange num1 and den1 if the exponent is negative and set the exp to the absolute value of the exponent
  static const bool positive_exponent = std::ratio<N2, D2>::num >= 0;
  static const int num2 = positive_exponent ? num1 : den1;
  static const int den2 = positive_exponent ? den1 : num1;
  static const int exp = positive_exponent ? std::ratio<N2, D2>::num : - std::ratio<N2, D2>::num;
  // calculate (num2 / den2) ^ exp
  typedef std::ratio<static_pow<num2, exp>::value, static_pow<den2, exp>::value> type;
};

So, I hope at least the basic idea comes across.

Yes, it's possible.

Let's define R1 = P1/Q1, R2 = P2/Q2, and R1^R2 = R3 = P3/Q3. Assume further that each pair P, Q is coprime.

R1^R2 = R1^(P2/Q2) = R3
R1 ^ P2 = R3 ^ Q2.

R1^P2 is known and has a unique factorization into primes 2^a * 3^b * 5^c * ... Note that a, b, c can be negative, as R1 is P1/Q1. Now the first question is whether all of a, b, c are multiples of the known factor Q2. If not, then you fail directly. If they are, then R3 = 2^(a/Q2) * 3^(b/Q2) * 5^(c/Q2) * ...

All divisions are either exact or the result does not exist, so we can use pure integer math in our templates. Factoring a number is fairly straightforward in templates (partial specialization on x%y==0 ).

Example: 2^(1/2) = R3 -> a=1, b=0, c=0, ... and a%2 != 0 -> impossible. (1/9)^(1/2) -> a=0, b=-2, b%2 = 0, possible, result = 3^(-2/2).
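
A minimal sketch of this idea, assuming C++14 and a positive R1. It uses constexpr functions rather than partial specialization, and the names sketch::int_pow, sketch::exact_root and sketch::ratio_power are made up for the illustration. It also takes the Q2-th root first and applies P2 afterwards; because P2 and Q2 are coprime, the divisibility test on the prime exponents of R1 gives the same answer as the test on R1^P2, and the intermediate values stay smaller.

```cpp
#include <cstdint>
#include <ratio>

namespace sketch {

using I = std::intmax_t;

// base^n by repeated multiplication; overflow is a compile error when
// evaluated in a constant expression.
constexpr I int_pow(I base, I n)
{
  I r = 1;
  for (I i = 0; i < n; ++i) r *= base;
  return r;
}

// Exact q-th root of x >= 1 by trial division, or 0 if some prime
// exponent of x is not a multiple of q.
constexpr I exact_root(I x, I q)
{
  I root = 1;
  for (I p = 2; p <= x / p; ++p) {  // p * p <= x, written without overflow
    I a = 0;                        // exponent of p in x
    while (x % p == 0) { x /= p; ++a; }
    if (a % q != 0) return 0;
    root *= int_pow(p, a / q);
  }
  if (x > 1) {                      // one prime left, with exponent 1
    if (q != 1) return 0;
    root *= x;
  }
  return root;
}

template <class R1, class R2>
struct ratio_power_impl
{
  static_assert(R1::num > 0, "this sketch handles positive bases only");
  // R2::den is positive because std::ratio normalizes the sign.
  static constexpr I root_num = exact_root(R1::num, R2::den);
  static constexpr I root_den = exact_root(R1::den, R2::den);
  static_assert(root_num != 0 && root_den != 0, "R1^R2 is irrational");
  static constexpr bool positive = R2::num >= 0;
  static constexpr I exponent = positive ? R2::num : -R2::num;
  // A negative exponent swaps numerator and denominator.
  using type = std::ratio<int_pow(positive ? root_num : root_den, exponent),
                          int_pow(positive ? root_den : root_num, exponent)>;
};

template <class R1, class R2>
using ratio_power = typename ratio_power_impl<R1, R2>::type;

}  // namespace sketch

// (1/9)^(1/2) == 1/3, the example from the text.
static_assert(std::ratio_equal<
                sketch::ratio_power<std::ratio<1, 9>, std::ratio<1, 2>>,
                std::ratio<1, 3>>::value, "(1/9)^(1/2)");
// 2^(1/2) has no exact root, so exact_root reports failure.
static_assert(sketch::exact_root(2, 2) == 0, "sqrt(2) is irrational");
```

Trial division costs on the order of sqrt(x) constexpr steps, so very large prime numerators or denominators may hit the compiler's constexpr evaluation limits.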
