簡體   English   中英

擴展 std::index_sequence 和可變參數包時參數包長度不匹配

[英]Mismatched argument pack lengths while expanding std::index_sequence and variadic argument pack

我一直在研究一個小型的多維張量數組實現; 並且遇到了std::make_index_sequence與可變參數模板參數結合的問題。 給出以下剝離的實現:

template <class scalar_t, std::size_t ... Dims>
class tensor {
public:
    using value_type = scalar_t;
    using size_type = std::size_t;
    using index_type = std::size_t;
    using container_type = std::array<value_type, (Dims * ...)>;
    using shape_type = std::array<size_type, sizeof...(Dims)>;
    using stride_type = std::array<index_type, sizeof...(Dims)>;

    constexpr static inline size_type size = (Dims * ...);
    constexpr static inline size_type rank = sizeof...(Dims);
    constexpr static inline shape_type shape = /* omitted for brevity */;
    constexpr static inline stride_type stride = /* omitted for brevity */;

    /* constructors omitted */

private:
    container_type m_data{};

我現在想在調用運算符operator()()上有一個可變模板來訪問私有容器中的元素:

template <class ... Indices, std::enable_if_t<sizeof...(Indices) == rank, int> = 0>
[[nodiscard]] constexpr value_type &operator()(Indices && ... index) noexcept {
    const index_type data_index = resolve_index(std::make_index_sequence<rank>(), std::forward<Indices>(index)...);
    return m_data[data_index];
}

這個想法很簡單。 為張量的rank生成一個index_sequence ,然后將請求轉發給一個私有助手resolve_index ,它根據跨步內存布局解析索引。 這是一個有效的解決方案


工作解決方案

template <class Indices>
[[nodiscard]] constexpr index_type resolve_index(const std::size_t axis, Indices && index) noexcept {
    return index * strides[axis];
}

template <std::size_t ... Axes, class ... Indices>
[[nodiscard]] constexpr index_type resolve_index(std::index_sequence<Axes...>, Indices && ... index) noexcept {
    return (resolve_index(Axes, std::forward<Indices>(index)) + ...);
}

resolve_index超載。 我確信編譯器可以在編譯時擴展折疊表達式,但是,每次調用resolve_index(std::size_t, Indices&&)都會在運行時執行(這很好)。

我對這個解決方案的不滿是,我通常更喜歡盡可能使用if constexpr (...)來消除上面的簡單函數重載; 特別是如果函數的返回不必由autodecltype(auto)自動推導。


因此,我想寫如下內容:

template <class ... Axes, class ... Indices>
[[nodiscard]] constexpr index_type resolve_index(Axes && ... axis, Indices && ... index) noexcept {
    if constexpr (sizeof...(Indices) == 1)
        return (index + ...) * strides[(axis + ...)];
    else
        return (resolve_index(std::forward<Axes>(axis), std::forward<Indices>(index)) + ...);
}

不幸的是,在編譯時出錯了:

error: mismatched argument pack lengths while expanding ‘((tecra::tensor<scalar_t, Dims>*)this)->tecra::tensor<scalar_t, Dims>::resolve_index(forward<Axes>(axis), forward<Indices>(index))’
  108 |                 return (resolve_index(std::forward<Axes>(axis), std::forward<Indices>(index)) + ...);
      |                                                                                                    ^

我哪里做錯了? 這是一個有效的 godbolt 示例: https ://godbolt.org/z/qsY51n8f7(隨意忽略internal內容)。 感謝任何對此進行調查的人!

template <class ... Axes, class ... Indices>
[[nodiscard]] constexpr index_type
resolve_index(Axes&& ... axis, Indices&& ... index) noexcept;

有兩個問題:

  • Axes&&...不可推導(不是最后一個參數)。
  • Axes&&...std::size_tstd::index_sequence<Is...> (所以它應該只是Axe&& )但是你不能根據沒有助手(函數或 lambda)的Is展開。

它會是這樣的:

template <class Axe, class ... Indices>
[[nodiscard]] constexpr index_type resolve_index(Axe axe, Indices && ... index) noexcept {
    if constexpr (std::is_same_v<Axe, std::size_t>) {
        static_assert(sizeof...(Indices) == 1);
        return (index + ...) * strides[axe];
    } else {
        return [&]<std::size_t...Is>(std::index_sequence<Is...>){
            static_assert(sizeof...(Indices) == sizeof...(Is));

            return (resolve_index(Is, std::forward<Indices>(index)) + ...);
            }(axe);
    }
}

演示

我建議直接在類參數中添加序列:

template <class scalar_t, typename seq_dim, std::size_t ... Dims>
class tensor_impl;
    
template <class scalar_t, std::size_t... Is, std::size_t ... Dims>
class tensor_impl<scalar_t, std::index_sequence<Is...>, Dims...>
{
// You might directly use Is...
// simplifying your interface (you might get rid of some template)
// ...
};

template <class scalar_t, std::size_t ... Dims>
using tensor = tensor_impl<scalar_t, std::index_sequence_for<Dims...>, Dims...>;

演示

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM