25 StrideIterator(T *ptr,
long stride) : m_ptr(ptr), m_stride(stride) {}
27 T &operator*() {
return *m_ptr; }
29 const T &operator*()
const {
return *m_ptr; }
32 return m_ptr == other.m_ptr;
35 return m_ptr < other.m_ptr;
58 m_ptr += n * m_stride;
63 m_ptr -= n * m_stride;
68 auto out_iter = *
this;
73 auto out_iter = *
this;
88 const T &operator*()
const {
return *m_ptr; }
91 return m_ptr == other.m_ptr;
94 return m_ptr < other.m_ptr;
117 m_ptr += n * m_stride;
122 m_ptr -= n * m_stride;
127 auto out_iter = *
this;
128 return out_iter += n;
132 auto out_iter = *
this;
133 return out_iter -= n;
144template <
typename T =
double>
149 std::size_t m_stride;
153 ArrayView(
T *data, std::size_t size, std::size_t stride = 1)
154 : m_size(size), m_stride(stride), m_data(data) {}
156 std::size_t size()
const {
return m_size; }
158 T &operator[](std::size_t
i) {
return m_data[
i * m_stride]; }
160 T operator[](std::size_t
i)
const {
return m_data[
i * m_stride]; }
162 T &at(std::size_t
i) {
164 return m_data[
i * m_stride];
167 T at(std::size_t
i)
const {
169 return m_data[
i * m_stride];
172 T &operator()(std::size_t
i) {
return at(
i); }
173 T operator()(std::size_t
i)
const {
return at(
i); }
175 T &front() {
return at(0); }
176 T front()
const {
return at(0); }
177 T &back() {
return at(m_size - 1); }
178 T back()
const {
return at(m_size - 1); }
180 T *data() {
return m_data; }
187 return StrideIterator(m_data +
long(m_size * m_stride),
long(m_stride));
190 return ConstStrideIterator(m_data +
long(m_size * m_stride),
195 return StrideIterator(m_data +
long(m_size * m_stride) -
long(m_stride),
198 auto crbegin()
const {
199 return ConstStrideIterator(
200 m_data +
long(m_size * m_stride) -
long(m_stride), -
long(m_stride));
204 return StrideIterator(m_data -
long(m_stride), -
long(m_stride));
207 return ConstStrideIterator(m_data -
long(m_stride), -
long(m_stride));
213 for (std::size_t
i = 0;
i < m_size; ++
i) {
214 out.push_back(m_data[
i * m_stride]);
223template <
typename T,
typename... Args>
225 if constexpr (
sizeof...(rest) == 0) {
228 return first *
product(rest...);
232template <std::
size_t N>
233void NDrange_impl(std::vector<std::array<std::size_t, N>> &result,
234 std::array<std::size_t, N> ¤t,
235 const std::array<std::size_t, N> &maxValues,
238 result.push_back(current);
242 for (std::size_t i = 0; i < maxValues[index]; ++i) {
244 NDrange_impl<N>(result, current, maxValues, index + 1);
256template <
typename... Args>
257auto NDrange(std::size_t first, Args... rest) {
258 constexpr std::size_t N =
sizeof...(rest) + 1;
260 const std::array<std::size_t, N> maxValues = {
261 first,
static_cast<std::size_t
>(rest)...};
263 std::vector<std::array<std::size_t, N>> result;
264 result.reserve(
product(first,
static_cast<std::size_t
>(rest)...));
266 std::array<std::size_t, N> current = {0};
268 NDrange_impl<N>(result, current, maxValues, 0);
273template <
typename T =
double>
274class Array :
public Arithmetic<Array<T>>, Arithmetic2<Array<T>, T> {
278 std::vector<std::size_t> m_sizes;
282 std::vector<std::size_t> m_cumulative_sizes;
284 std::size_t m_total_size;
286 std::vector<T> m_data;
291 template <
typename... Args>
292 Array(std::size_t first, Args... rest);
295 template <
typename... Args>
296 void resize(std::size_t first, Args... rest);
299 std::size_t size()
const {
return m_total_size; }
302 std::size_t size(std::size_t dim)
const {
return m_sizes.at(dim); }
305 std::size_t dimensions()
const {
return m_Ndim; }
308 const std::vector<std::size_t> &shape()
const {
return m_sizes; }
311 template <
typename... Args>
312 T &at(std::size_t first, Args... rest);
315 template <
typename... Args>
316 T at(std::size_t first, Args... rest)
const;
319 template <
typename... Args>
320 T &operator()(std::size_t first, Args... rest);
323 template <
typename... Args>
324 T operator()(std::size_t first, Args... rest)
const;
328 T *data() {
return m_data.data(); }
329 const T *data()
const {
return m_data.data(); }
332 const std::vector<T> &vector()
const {
return m_data; }
335 std::size_t rows()
const {
return size(0); }
337 std::size_t cols()
const {
return size(1); }
340 ArrayView<T> row(std::size_t i);
341 ArrayView<const T> row(std::size_t i)
const;
343 ArrayView<T> col(std::size_t j);
344 ArrayView<const T> col(std::size_t j)
const;
347 auto begin() {
return m_data.begin(); }
349 auto cbegin()
const {
return m_data.cbegin(); }
351 auto end() {
return m_data.end(); }
353 auto cend()
const {
return m_data.cend(); }
356 auto rbegin() {
return m_data.rbegin(); }
358 auto crbegin()
const {
return m_data.crbegin(); }
360 auto rend() {
return m_data.rend(); }
362 auto crend()
const {
return m_data.crend(); }
367 Array<T> &operator-=(
const Array<T> &other);
368 Array<T> &operator*=(
const Array<T> &other);
369 Array<T> &operator/=(
const Array<T> &other);
376 Array<T> &operator-=(
const T &t);
377 Array<T> &operator*=(
const T &t);
378 Array<T> &operator/=(
const T &t);
382 std::vector<std::size_t> calc_cumulative_size()
const;
385 template <
typename... Args>
386 std::size_t unchecked_index(std::size_t first, Args... rest)
const;
389 template <
typename... Args>
390 std::size_t unchecked_index_impl(std::size_t dim, std::size_t first,
394 template <
typename... Args>
395 std::size_t checked_index(std::size_t first, Args... rest)
const;
398 template <
typename... Args>
399 std::size_t checked_index_impl(std::size_t dim, std::size_t first,
407std::vector<std::size_t> Array<T>::calc_cumulative_size()
const {
408 std::vector<std::size_t> cumulative_size;
409 cumulative_size.reserve(m_sizes.size());
410 for (std::size_t i = 0; i < m_Ndim; ++i) {
411 cumulative_size.push_back(std::accumulate(m_sizes.cbegin() +
long(i) + 1,
413 std::multiplies<std::size_t>()));
415 return cumulative_size;
419template <
typename... Args>
420Array<T>::Array(std::size_t first, Args... rest)
421 : m_sizes({first,
static_cast<std::size_t
>(rest)...}),
422 m_Ndim(m_sizes.size()),
423 m_cumulative_sizes(calc_cumulative_size()),
424 m_total_size(std::accumulate(m_sizes.cbegin(), m_sizes.cend(), 1ul,
425 std::multiplies<std::size_t>())),
426 m_data(m_total_size) {}
429template <
typename... Args>
430void Array<T>::resize(std::size_t first, Args... rest) {
431 m_sizes = std::vector{first,
static_cast<std::size_t
>(rest)...};
432 m_Ndim = m_sizes.size();
433 m_cumulative_sizes = calc_cumulative_size();
434 m_total_size = std::accumulate(m_sizes.cbegin(), m_sizes.cend(), 1ul,
435 std::multiplies<std::size_t>());
436 m_data.resize(m_total_size);
440template <
typename... Args>
441std::size_t Array<T>::unchecked_index(std::size_t first, Args... rest)
const {
442 return unchecked_index_impl(0, first, rest...);
446template <
typename... Args>
447std::size_t Array<T>::unchecked_index_impl(std::size_t dim, std::size_t first,
448 Args... rest)
const {
449 if constexpr (
sizeof...(rest) == 0) {
450 return first * m_cumulative_sizes[dim];
452 return first * m_cumulative_sizes[dim] +
453 unchecked_index_impl(dim + 1, rest...);
458template <
typename... Args>
459std::size_t Array<T>::checked_index(std::size_t first, Args... rest)
const {
460 assert(
sizeof...(rest) + 1 == m_Ndim &&
461 "Number of arguments must match number of dimensions");
462 return checked_index_impl(0, first, rest...);
466template <
typename... Args>
467std::size_t Array<T>::checked_index_impl(std::size_t dim, std::size_t first,
468 Args... rest)
const {
469 assert(first < m_sizes[dim]);
470 if constexpr (
sizeof...(rest) == 0) {
471 return first * m_cumulative_sizes[dim];
473 return first * m_cumulative_sizes[dim] +
474 checked_index_impl(dim + 1, rest...);
479template <
typename... Args>
480T &Array<T>::at(std::size_t first, Args... rest) {
481 return m_data.at(checked_index(first, rest...));
485template <
typename... Args>
486T Array<T>::at(std::size_t first, Args... rest)
const {
487 return m_data.at(checked_index(first, rest...));
491template <
typename... Args>
492T &Array<T>::operator()(std::size_t first, Args... rest) {
493 return m_data[unchecked_index(first, rest...)];
497template <
typename... Args>
498T Array<T>::operator()(std::size_t first, Args... rest)
const {
499 return m_data[unchecked_index(first, rest...)];
503Array<T> &Array<T>::operator+=(
const Array<T> &other) {
504 assert(m_sizes == other.m_sizes &&
505 "Arithmetic only defined for equal-dimension arrays");
506 for (std::size_t i = 0; i < m_data.size(); ++i) {
507 this->m_data[i] += other.m_data[i];
513Array<T> &Array<T>::operator-=(
const Array<T> &other) {
514 assert(m_sizes == other.m_sizes &&
515 "Arithmetic only defined for equal-dimension arrays");
516 for (std::size_t i = 0; i < m_data.size(); ++i) {
517 this->m_data[i] -= other.m_data[i];
523Array<T> &Array<T>::operator*=(
const Array<T> &other) {
524 assert(m_sizes == other.m_sizes &&
525 "Arithmetic only defined for equal-dimension arrays");
526 for (std::size_t i = 0; i < m_data.size(); ++i) {
527 this->m_data[i] *= other.m_data[i];
533Array<T> &Array<T>::operator/=(
const Array<T> &other) {
534 assert(m_sizes == other.m_sizes &&
535 "Arithmetic only defined for equal-dimension arrays");
536 for (std::size_t i = 0; i < m_data.size(); ++i) {
537 this->m_data[i] /= other.m_data[i];
543Array<T> &Array<T>::operator+=(
const T &t) {
544 for (std::size_t i = 0; i < m_data.size(); ++i) {
545 this->m_data[i] += t;
551Array<T> &Array<T>::operator-=(
const T &t) {
552 for (std::size_t i = 0; i < m_data.size(); ++i) {
553 this->m_data[i] -= t;
559Array<T> &Array<T>::operator*=(
const T &t) {
560 for (std::size_t i = 0; i < m_data.size(); ++i) {
561 this->m_data[i] *= t;
567Array<T> &Array<T>::operator/=(
const T &t) {
568 for (std::size_t i = 0; i < m_data.size(); ++i) {
569 this->m_data[i] /= t;
575ArrayView<T> Array<T>::row(std::size_t i) {
576 assert(m_Ndim == 2 &&
"Row only defined for 2D array");
577 return ArrayView(m_data.data() + i * m_sizes[1], m_sizes[1]);
580ArrayView<T> Array<T>::col(std::size_t j) {
581 assert(m_Ndim == 2 &&
"Col only defined for 2D array");
582 return ArrayView(m_data.data() + j, m_sizes[0], m_sizes[1]);
585ArrayView<const T> Array<T>::row(std::size_t i)
const {
586 assert(m_Ndim == 2 &&
"Row only defined for 2D array");
587 return ArrayView(m_data.data() + i * m_sizes[1], m_sizes[1]);
590ArrayView<const T> Array<T>::col(std::size_t j)
const {
591 assert(m_Ndim == 2 &&
"Col only defined for 2D array");
592 return ArrayView<const T>(m_data.data() + j, m_sizes[0], m_sizes[1]);