[英]Find contractions of a variadic pack based on another variadic parameter pack
我正在使用静态多维数组压缩框架,但遇到了一个难以解释的问题,但我会尽力而为。 假设我们有一个N
维数组类
template<typename T, int ... dims>
class Array {}
可以实例化为
Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on
现在我们有另一个类叫做Indices
template<int ... Idx>
struct Indices {}
我现在有一个函数contraction
其签名应如下所示
template<T, int ... Dims, int ... Idx,
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>>
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a)
我可能没有在这里得到语法,但是我实质上希望返回的Array
具有基于Indices
条目的维。 让我提供contraction
可以执行的示例。 注意,在这种情况下, 收缩表示删除索引列表中参数相等的尺寸 。
auto arr = contraction(Indices<0,0>, Array<double,3,3>)
// arr is Array<double> as both indices contract 0==0
auto arr = contraction(Indices<0,1>, Array<double,3,3>)
// arr is Array<double,3,3> as no contraction happens here, 0!=1
auto arr = contraction(Indices<0,1,0>, Array<double,3,4,3>)
// arr is Array<double,4> as 1st and 3rd indices contract 0==0
auto arr = contraction(Indices<0,1,0,7,7,2>, Array<double,3,4,3,5,5,6>)
// arr is Array<double,4,6> as (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract
auto arr = contraction(Indices<10,10,2,3>, Array<double,5,6,4,4>
// should not compile as contraction between 1st and 2nd arguments
// requested but dimensions don't match 5!=6
// The parameters of Indices really do not matter as long as
// we can identify contractions. They are typically expressed as enums, I,J,K...
因此,从本质Idx...
,给定Idx...
和Dims...
的大小都应相等,请检查Idx...
中的哪些值相等,获取它们出现的位置并删除Dims...
中的相应条目(位置) Dims...
这本质上是张量收缩规则 。
数组收缩规则:
sizeof...(Idx)==sizeof...(Dims)
Idx
和Dims
即,如果我们具有Indices<0,1,2>
和Array<double,4,5,6>
0
映射到4
, 1
映射到5
和2
映射到6
。 Idx
存在相同/相等的值,则表示收缩,这意味着Dims
的相应尺寸应消失,例如,如果我们有Indices<0,0,3>
和Array<double,4,4,6>
,然后0==0
并且这些值映射到的相应维度4
和4
都需要消失,并且结果数组应为Array<double,6>
Idx
具有相同的值,但对应的Dims
不匹配,则应触发编译时错误,例如, Indices<0,0,3>
和Array<double,4,5,6>
不可能,因为4!=5
,类似地,由于4!=6
, Indices<0,1,0>
是不可能的,这导致 Array<double,4,5,6>
不能以任何方式收缩。 Dims
也匹配, Idx
就可以使用多对,三胞胎,四胞胎等,例如Indices<0,0,0,0,1,1,4,3,3,7,7,7>
如果输入数组为Array<double,2,2,2,2,3,3,6,2,2,3,3,3>
Indices<0,0,0,0,1,1,4,3,3,7,7,7>
会收缩为Array<double,6>
Array<double,2,2,2,2,3,3,6,2,2,3,3,3>
。 我对元编程的了解对实现此功能没有多大帮助,但是我希望我已经明确了意图,以便有人指导我朝正确的方向发展。
一堆执行实际检查的constexpr
函数:
// is ind[i] unique in ind?
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
return cur == N ? true :
(cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}
// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
return cur == N ? true :
(ind[cur] != index || dim[cur] == dimension) ?
check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}
// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
return is_uniq(ind, i) ? dim[i] :
check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}
现在,我们需要一种摆脱-1
s的方法:
template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
: concat<Indices<I1..., I2...>, Inds...> {};
// filter out all instances of I from Is...,
// return the rest as an Indices
template<int I, int... Is>
struct filter
: concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};
使用它们:
template<class Ind, class Arr, class Seq>
struct contraction_impl;
template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
static constexpr int ind[] = { Ind... };
static constexpr int dim[] = { Dim... };
static constexpr int result[] = {calc(Seq, ind, dim)...};
template<int... Dims>
static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;
using type = decltype(unpack_helper(typename filter<-1, result[Seq]...>::type{}));
};
template<class T, int ... Dims, int ... Idx,
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>,
std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);
除了make_index_sequence
之外的所有东西都是C ++ 11。 您可以在SO上找到大量的实现。
这是一团糟,但我认为它可以完成您想要的操作。 几乎可以肯定,可以对此进行许多简化,但这是我通过测试的第一遍。 请注意,这不会实现收缩,而只是确定应为哪种类型。 如果那不是您所需要的,我事先表示歉意。
#include <type_traits>
template <std::size_t...>
struct Indices {};
template <typename, std::size_t...>
struct Array {};
// Count number of 'i' in 'rest...', base case
template <std::size_t i, std::size_t... rest>
struct Count : std::integral_constant<std::size_t, 0>
{};
// Count number of 'i' in 'rest...', inductive case
template <std::size_t i, std::size_t j, std::size_t... rest>
struct Count<i, j, rest...> :
std::integral_constant<std::size_t,
Count<i, rest...>::value + ((i == j) ? 1 : 0)>
{};
// Is 'i' contained in 'rest...'?
template <std::size_t i, std::size_t... rest>
struct Contains :
std::integral_constant<bool, (Count<i, rest...>::value > 0)>
{};
// Accumulation of counts of indices in all, base case
template <typename All, typename Remainder,
typename AccIdx, typename AccCount>
struct Counts {
using indices = AccIdx;
using counts = AccCount;
};
// Accumulation of counts of indices in all, inductive case
template <std::size_t... all, std::size_t i, std::size_t... rest,
std::size_t... indices, std::size_t... counts>
struct Counts<Indices<all...>, Indices<i, rest...>,
Indices<indices...>, Indices<counts...>>
: std::conditional<Contains<i, indices...>::value,
Counts<Indices<all...>, Indices<rest...>,
Indices<indices...>,
Indices<counts...>>,
Counts<Indices<all...>, Indices<rest...>,
Indices<indices..., i>,
Indices<counts...,
Count<i, all...>::value>>>::type
{};
// Get value in From that matched the first value of Idx that matched idx
template <std::size_t idx, typename Idx, typename From>
struct First : std::integral_constant<std::size_t, 0>
{};
template <std::size_t i, std::size_t j, std::size_t k,
std::size_t... indices, std::size_t... values>
struct First<i, Indices<j, indices...>, Indices<k, values...>>
: std::conditional<i == j,
std::integral_constant<std::size_t, k>,
First<i, Indices<indices...>,
Indices<values...>>>::type
{};
// Return whether all values in From that match Idx being idx are tgt
template <std::size_t idx, std::size_t tgt, typename Idx, typename From>
struct AllMatchTarget : std::true_type
{};
template <std::size_t idx, std::size_t tgt,
std::size_t i, std::size_t j,
std::size_t... indices, std::size_t... values>
struct AllMatchTarget<idx, tgt,
Indices<i, indices...>, Indices<j, values...>>
: std::conditional<i == idx && j != tgt, std::false_type,
AllMatchTarget<idx, tgt, Indices<indices...>,
Indices<values...>>>::type
{};
/* Generate the dimensions, given the counts, indices, and values */
template <typename Counts, typename Indices,
typename AllIndices, typename Values, typename Accum>
struct GenDims;
template <typename A, typename V, typename R>
struct GenDims<Indices<>, Indices<>, A, V, R> {
using type = R;
};
template <typename T, std::size_t i, std::size_t c,
std::size_t... counts, std::size_t... indices,
std::size_t... dims, typename AllIndices, typename Values>
struct GenDims<Indices<c, counts...>, Indices<i, indices...>,
AllIndices, Values, Array<T, dims...>>
{
static constexpr auto value = First<i, AllIndices, Values>::value;
static_assert(AllMatchTarget<i, value, AllIndices, Values>::value,
"Index doesn't correspond to matching dimensions");
using type = typename GenDims<
Indices<counts...>, Indices<indices...>,
AllIndices, Values,
typename std::conditional<c == 1,
Array<T, dims..., value>,
Array<T, dims...>>::type>::type;
};
/* Put it all together */
template <typename I, typename A>
struct ContractionType;
template <typename T, std::size_t... indices, std::size_t... values>
struct ContractionType<Indices<indices...>, Array<T, values...>> {
static_assert(sizeof...(indices) == sizeof...(values),
"Number of indices and dimensions do not match");
using counts = Counts<Indices<indices...>,
Indices<indices...>,
Indices<>, Indices<>>;
using type = typename GenDims<typename counts::counts,
typename counts::indices,
Indices<indices...>, Indices<values...>,
Array<T>>::type;
};
static_assert(std::is_same<typename
ContractionType<Indices<0, 0>, Array<double, 3, 3>>::type,
Array<double>>::value, "");
static_assert(std::is_same<typename
ContractionType<Indices<0, 1>, Array<double, 3, 3>>::type,
Array<double, 3, 3>>::value, "");
static_assert(std::is_same<typename
ContractionType<Indices<0, 1, 0>, Array<double, 3, 4, 3>>::type,
Array<double, 4>>::value, "");
static_assert(std::is_same<typename
ContractionType<Indices<0, 1, 0, 7, 7, 2>,
Array<double, 3, 4, 3, 5, 5, 6>>::type,
Array<double, 4, 6>>::value, "");
// Errors appropriately when uncommented
/* static_assert(std::is_same<typename */
/* ContractionType<Indices<10,10, 2, 3>, */
/* Array<double, 5,6,4,4>>::type, */
/* Array<double>::value, ""); */
以下是对这里发生的情况的解释:
Counts
生成唯一索引的列表( Counts::indices
)以及每个索引出现在序列中的次数( Counts::counts
)。 Counts
对数进行Counts
,对于每个索引,如果计数为1,我将累加值并递归。 否则,我将累加值继续传递并递归。 最GenDims
部分是GenDims
的static_assert
,它会验证索引是否所有匹配的尺寸都相同。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.