Concisely declare and initialize a multi-dimensional array in C++

For example, in 3 dimensions I would normally do something like:

vector<vector<vector<T>>> v(x, vector<vector<T>>(y, vector<T>(z, val)));

However, this gets tedious for complex types and higher dimensions. Is it possible to define a type, say tensor, whose usage would be like so:

tensor<T> t(x, y, z, val1);
t[i][j][k] = val2;

It's possible with template metaprogramming.

Define an N-dimensional vector type, NVector, recursively:

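#include <vector>

// D-dimensional vector: a std::vector of (D-1)-dimensional NVectors.
// The first argument is the length of this dimension; the remaining arguments
// build the (D-1)-dimensional element that is copied n times.
template<int D, typename T>
struct NVector : public std::vector<NVector<D - 1, T>> {
    template<typename... Args>
    NVector(int n = 0, Args... args) : std::vector<NVector<D - 1, T>>(n, NVector<D - 1, T>(args...)) {
    }
};

// Base case: a 1-dimensional NVector is a std::vector<T> holding n copies of val.
template<typename T>
struct NVector<1, T> : public std::vector<T> {
    NVector(int n = 0, const T &val = T()) : std::vector<T>(n, val) {
    }
};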

You can use it like this:

    const int n = 5, m = 5, k = 5;
    NVector<3, int> a(n, m, k, 0);
    std::cout << a[0][0][0] << '\n';

The usage should be clear, but to spell it out: NVector<number of dimensions, element type> a(length of each dimension, comma-separated (optional)..., default value (optional));

The other answer shows a good way of making a vector of vectors with template metaprogramming. If you want a multidimensional array data structure with a single allocation and contiguous storage underneath, here is an example of how to achieve that with an NDArray template class wrapping access to an underlying vector. It could be extended for extra convenience with a user-defined operator= (as written, the const members make the class non-assignable), copy operations, per-dimension bounds checking in debug builds, etc.

NDArray.h

#pragma once

#include <algorithm>  // std::fill
#include <array>
#include <vector>

template<int N, typename ValueType>
class NDArray {
public:
    template<typename... Args>
    NDArray(Args... args)
    : dims({{static_cast<int>(args)...}}),
      offsets(compute_offsets(dims)),
      data(compute_size(dims), ValueType{})
    {
        static_assert(sizeof...(args) == N, 
            "Incorrect number of NDArray dimension arguments");
    }

    void fill(ValueType val) {
        std::fill(data.begin(), data.end(), val);
    }

    template<typename... Args>
    inline ValueType operator()(Args... args) const {
        static_assert(sizeof...(args) == N, 
            "Incorrect number of NDArray index arguments");
        return data[calc_index({ {static_cast<int>(args)...} })];
    }

    template<typename... Args>
    inline ValueType& operator()(Args... args) {
        static_assert(sizeof...(args) == N, 
            "Incorrect number of NDArray index arguments");
        return data[calc_index({ {static_cast<int>(args)...} })];
    }

    int length(int axis) const { return dims[axis]; }

    static constexpr int num_dims = N;

private:
    // Column-major strides: the first index varies fastest in memory.
    static std::array<int, N> compute_offsets(const std::array<int, N>& dims) {
        std::array<int, N> offsets{};
        offsets[0] = 1;
        for (int i = 1; i < N; ++i) {
            offsets[i] = offsets[i - 1] * dims[i - 1];
        }
        return offsets;
    }

    static int compute_size(const std::array<int, N>& dims) {
        int size = 1;
        for (auto&& d : dims) size *= d;
        return size;
    }

    inline int calc_index(const std::array<int, N>& indices) const {
        int idx = 0;
        for (int i = 0; i < N; ++i) idx += offsets[i] * indices[i];
        return idx;
    }

    const std::array<int, N> dims;
    const std::array<int, N> offsets;
    std::vector<ValueType> data;
};

This defines operator() to take exactly N index arguments; thanks to the static_assert, a call with the wrong number of arguments will not compile. Some example usage:

using Array2D = NDArray<2,double>;
using Array3D = NDArray<3,double>;

auto a = Array2D(3, 6);
a.fill(1.0);
a(2, 4) = 2.0;
//a(2,4,4) will not compile
std::cout << "a = " << std::endl << a << std::endl;

//auto b = Array3D(4, 4); // will not compile

auto b = Array3D(4, 3, 2);
b.fill(-1.0);
b(0, 0, 0) = 4.0;
b(1, 1, 1) = 2.0;
std::cout << "b = " << std::endl << b << std::endl;

(This relies on the following output operators for 2D and 3D arrays, which must be declared before they are used.)

std::ostream& operator<<(std::ostream& os, const Array2D& arr) {
    for (int i = 0; i < arr.length(0); ++i) {
        for (int j = 0; j < arr.length(1); ++j) {
            os << arr(i,j) << " ";
        }
        os << std::endl;
    }
    return os;
}

std::ostream& operator<<(std::ostream& os, const Array3D& arr) {
    for (int k = 0; k < arr.length(2); ++k) {
        os << "array(:,:,"<<k<<") = " << std::endl;
        for (int i = 0; i < arr.length(0); ++i) {
            os << "  ";
            for (int j = 0; j < arr.length(1); ++j) {
                os << arr(i, j, k) << " ";
            }
            os << std::endl;
        }
        os << std::endl;
    }
    return os;
}
