繁体   English   中英

根据另一个变量参数包查找变量包装的收缩

[英]Find contractions of a variadic pack based on another variadic parameter pack

我正在使用静态多维数组压缩框架,但遇到了一个难以解释的问题,但我会尽力而为。 假设我们有一个N维数组类

template<typename T, int ... dims>
class Array {}

可以实例化为

Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on

现在我们有另一个类叫做Indices

template<int ... Idx>
struct Indices {}

我现在有一个函数contraction其签名应如下所示

template<T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>> 
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a)

我可能没有在这里得到语法,但是我实质上希望返回的Array具有基于Indices条目的维。 让我提供contraction可以执行的示例。 注意,在这种情况下, 收缩表示删除索引列表中参数相等的尺寸

auto arr = contraction(Indices<0,0>, Array<double,3,3>) 
// arr is Array<double> as both indices contract 0==0

auto arr = contraction(Indices<0,1>, Array<double,3,3>) 
// arr is Array<double,3,3> as no contraction happens here, 0!=1

auto arr = contraction(Indices<0,1,0>, Array<double,3,4,3>) 
// arr is Array<double,4> as 1st and 3rd indices contract 0==0  

auto arr = contraction(Indices<0,1,0,7,7,2>, Array<double,3,4,3,5,5,6>) 
// arr is Array<double,4,6> as (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract

auto arr = contraction(Indices<10,10,2,3>, Array<double,5,6,4,4>
// should not compile as contraction between 1st and 2nd arguments 
// requested but dimensions don't match 5!=6

// The parameters of Indices really do not matter as long as 
// we can identify contractions. They are typically expressed as enums, I,J,K...

因此,从本质Idx... ,给定Idx...Dims...的大小都应相等,请检查Idx...中的哪些值相等,获取它们出现的位置并删除Dims...中的相应条目(位置) Dims... 这本质上是张量收缩规则

数组收缩规则:

  1. 索引的参数数和数组的维数/秩应相同,即sizeof...(Idx)==sizeof...(Dims)
  2. 一对一对应bewteen IdxDims即,如果我们具有Indices<0,1,2>Array<double,4,5,6> 0映射到41映射到52映射到6
  3. 如果Idx存在相同/相等的值,则表示收缩,这意味着Dims的相应尺寸应消失,例如,如果我们有Indices<0,0,3>Array<double,4,4,6> ,然后0==0并且这些值映射到的相应维度44都需要消失,并且结果数组应为Array<double,6>
  4. 如果Idx具有相同的值,但对应的Dims不匹配,则应触发编译时错误,例如, Indices<0,0,3>Array<double,4,5,6>不可能,因为4!=5 ,类似地,由于4!=6Indices<0,1,0>是不可能的,这导致
  5. 不同尺寸的数组无法收缩,例如Array<double,4,5,6>不能以任何方式收缩。
  6. 只要对应的Dims也匹配, Idx就可以使用多对,三胞胎,四胞胎等,例如Indices<0,0,0,0,1,1,4,3,3,7,7,7>如果输入数组为Array<double,2,2,2,2,3,3,6,2,2,3,3,3> Indices<0,0,0,0,1,1,4,3,3,7,7,7>会收缩为Array<double,6> Array<double,2,2,2,2,3,3,6,2,2,3,3,3>

我对元编程的了解对实现此功能没有多大帮助,但是我希望我已经明确了意图,以便有人指导我朝正确的方向发展。

一堆执行实际检查的constexpr函数:

// is ind[i] unique in ind?
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
    return cur == N ? true : 
           (cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}

// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
                            const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
    return cur == N ? true :
           (ind[cur] != index || dim[cur] == dimension) ? 
                check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}

// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
    return is_uniq(ind, i) ? dim[i] :
           check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}

现在,我们需要一种摆脱-1 s的方法:

template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
    :  concat<Indices<I1..., I2...>, Inds...> {};

// filter out all instances of I from Is...,
// return the rest as an Indices    
template<int I, int... Is>
struct filter
    :  concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};

使用它们:

template<class Ind, class Arr, class Seq>
struct contraction_impl;

template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
    static constexpr int ind[] = { Ind... };
    static constexpr int dim[] = { Dim... };
    static constexpr int result[] = {calc(Seq, ind, dim)...};

    template<int... Dims>
    static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;

    using type = decltype(unpack_helper(typename filter<-1,  result[Seq]...>::type{}));
};


template<class T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>, 
                          std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);

除了make_index_sequence之外的所有东西都是C ++ 11。 您可以在SO上找到大量的实现。

这是一团糟,但我认为它可以完成您想要的操作。 几乎可以肯定,可以对此进行许多简化,但这是我通过测试的第一遍。 请注意,这不会实现收缩,而只是确定应为哪种类型。 如果那不是您所需要的,我事先表示歉意。

#include <type_traits>

template <std::size_t...>
struct Indices {};

template <typename, std::size_t...>
struct Array {};

// Count number of 'i' in 'rest...', base case
template <std::size_t i, std::size_t... rest>
struct Count : std::integral_constant<std::size_t, 0>
{};

// Count number of 'i' in 'rest...', inductive case
template <std::size_t i, std::size_t j, std::size_t... rest>
struct Count<i, j, rest...> :
    std::integral_constant<std::size_t,
                           Count<i, rest...>::value + ((i == j) ? 1 : 0)>
{};

// Is 'i' contained in 'rest...'?
template <std::size_t i, std::size_t... rest>
struct Contains :
    std::integral_constant<bool, (Count<i, rest...>::value > 0)>
{};


// Accumulation of counts of indices in all, base case
template <typename All, typename Remainder,
          typename AccIdx, typename AccCount>
struct Counts {
    using indices = AccIdx;
    using counts = AccCount;
};

// Accumulation of counts of indices in all, inductive case
template <std::size_t... all, std::size_t i, std::size_t... rest,
          std::size_t... indices, std::size_t... counts>
struct Counts<Indices<all...>, Indices<i, rest...>,
              Indices<indices...>, Indices<counts...>>
    : std::conditional<Contains<i, indices...>::value,
                       Counts<Indices<all...>, Indices<rest...>,
                              Indices<indices...>,
                              Indices<counts...>>,
                       Counts<Indices<all...>, Indices<rest...>,
                              Indices<indices..., i>,
                              Indices<counts...,
                                      Count<i, all...>::value>>>::type
{};

// Get value in From that matched the first value of Idx that matched idx
template <std::size_t idx, typename Idx, typename From>
struct First : std::integral_constant<std::size_t, 0>
{};
template <std::size_t i, std::size_t j, std::size_t k,
          std::size_t... indices, std::size_t... values>
struct First<i, Indices<j, indices...>, Indices<k, values...>>
    : std::conditional<i == j,
                       std::integral_constant<std::size_t, k>,
                       First<i, Indices<indices...>,
                             Indices<values...>>>::type
{};

// Return whether all values in From that match Idx being idx are tgt
template <std::size_t idx, std::size_t tgt, typename Idx, typename From>
struct AllMatchTarget : std::true_type
{};
template <std::size_t idx, std::size_t tgt,
          std::size_t i, std::size_t j,
          std::size_t... indices, std::size_t... values>
struct AllMatchTarget<idx, tgt,
                      Indices<i, indices...>, Indices<j, values...>>
    : std::conditional<i == idx && j != tgt, std::false_type,
                       AllMatchTarget<idx, tgt, Indices<indices...>,
                                      Indices<values...>>>::type
{};

/* Generate the dimensions, given the counts, indices, and values */
template <typename Counts, typename Indices,
          typename AllIndices, typename Values, typename Accum>
struct GenDims;

template <typename A, typename V, typename R>
struct GenDims<Indices<>, Indices<>, A, V, R> {
    using type = R;
};
template <typename T, std::size_t i, std::size_t c,
          std::size_t... counts, std::size_t... indices,
          std::size_t... dims, typename AllIndices, typename Values>
struct GenDims<Indices<c, counts...>, Indices<i, indices...>,
               AllIndices, Values, Array<T, dims...>>
{
    static constexpr auto value = First<i, AllIndices, Values>::value;
    static_assert(AllMatchTarget<i, value, AllIndices, Values>::value,
                  "Index doesn't correspond to matching dimensions");
    using type = typename GenDims<
        Indices<counts...>, Indices<indices...>,
        AllIndices, Values,
        typename std::conditional<c == 1,
                                  Array<T, dims..., value>,
                                  Array<T, dims...>>::type>::type;
};

/* Put it all together */
template <typename I, typename A>
struct ContractionType;

template <typename T, std::size_t... indices, std::size_t... values>
struct ContractionType<Indices<indices...>, Array<T, values...>> {
    static_assert(sizeof...(indices) == sizeof...(values),
                   "Number of indices and dimensions do not match");
    using counts = Counts<Indices<indices...>,
                          Indices<indices...>,
                          Indices<>, Indices<>>;
    using type = typename GenDims<typename counts::counts,
                                  typename counts::indices,
                                  Indices<indices...>, Indices<values...>,
                                  Array<T>>::type;
};

static_assert(std::is_same<typename
              ContractionType<Indices<0, 0>, Array<double, 3, 3>>::type,
              Array<double>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1>, Array<double, 3, 3>>::type,
              Array<double, 3, 3>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1, 0>, Array<double, 3, 4, 3>>::type,
              Array<double, 4>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1, 0, 7, 7, 2>,
              Array<double, 3, 4, 3, 5, 5, 6>>::type,
              Array<double, 4, 6>>::value, "");

// Errors appropriately when uncommented
/* static_assert(std::is_same<typename */
/*               ContractionType<Indices<10,10, 2, 3>, */
/*               Array<double, 5,6,4,4>>::type, */
/*               Array<double>::value, ""); */

以下是对这里发生的情况的解释:

  • 首先,我使用Counts生成唯一索引的列表( Counts::indices )以及每个索引出现在序列中的次数( Counts::counts )。
  • 然后,我遍历索引,从Counts对数进行Counts ,对于每个索引,如果计数为1,我将累加值并递归。 否则,我将累加值继续传递并递归。

GenDims部分是GenDimsstatic_assert ,它会验证索引是否所有匹配的尺寸都相同。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM