ampsci
c++ program for high-precision atomic structure calculations of single-valence systems
Loading...
Searching...
No Matches
Array.hpp
1#pragma once
#include "Template.hpp"
#include <array>
#include <cassert>
#include <functional>
#include <numeric>
#include <type_traits>
#include <vector>
7
9namespace qip {
10
11//==============================================================================
13
18template <typename T>
19class StrideIterator : public qip::Comparison<StrideIterator<T>> {
20protected:
21 T *m_ptr;
22 long m_stride;
23
24public:
25 StrideIterator(T *ptr, long stride) : m_ptr(ptr), m_stride(stride) {}
26
27 T &operator*() { return *m_ptr; }
28
29 const T &operator*() const { return *m_ptr; }
30
31 bool operator==(const StrideIterator &other) const {
32 return m_ptr == other.m_ptr;
33 }
34 bool operator<(const StrideIterator &other) const {
35 return m_ptr < other.m_ptr;
36 }
37
38 StrideIterator &operator++() {
39 m_ptr += m_stride;
40 return *this;
41 }
42 StrideIterator &operator--() {
43 m_ptr -= m_stride;
44 return *this;
45 }
46 StrideIterator operator++(int) {
47 auto temp = *this;
48 ++*this;
49 return temp;
50 }
51 StrideIterator operator--(int) {
52 auto temp = *this;
53 --*this;
54 return temp;
55 }
56
57 StrideIterator &operator+=(long n) {
58 m_ptr += n * m_stride;
59 return *this;
60 }
61
62 StrideIterator &operator-=(long n) {
63 m_ptr -= n * m_stride;
64 return *this;
65 }
66
67 StrideIterator operator+(long n) const {
68 auto out_iter = *this;
69 return out_iter += n;
70 }
71
72 StrideIterator operator-(long n) const {
73 auto out_iter = *this;
74 return out_iter -= n;
75 }
76};
77
79template <typename T>
80class ConstStrideIterator : public qip::Comparison<ConstStrideIterator<T>> {
81protected:
82 T *m_ptr;
83 long m_stride;
84
85public:
86 ConstStrideIterator(T *ptr, long stride) : m_ptr(ptr), m_stride(stride) {}
87
88 const T &operator*() const { return *m_ptr; }
89
90 bool operator==(const ConstStrideIterator &other) const {
91 return m_ptr == other.m_ptr;
92 }
93 bool operator<(const ConstStrideIterator &other) const {
94 return m_ptr < other.m_ptr;
95 }
96
97 ConstStrideIterator &operator++() {
98 m_ptr += m_stride;
99 return *this;
100 }
101 ConstStrideIterator &operator--() {
102 m_ptr -= m_stride;
103 return *this;
104 }
105 ConstStrideIterator operator++(int) {
106 auto temp = *this;
107 ++*this;
108 return temp;
109 }
110 ConstStrideIterator operator--(int) {
111 auto temp = *this;
112 --*this;
113 return temp;
114 }
115
116 ConstStrideIterator &operator+=(long n) {
117 m_ptr += n * m_stride;
118 return *this;
119 }
120
121 ConstStrideIterator &operator-=(long n) {
122 m_ptr -= n * m_stride;
123 return *this;
124 }
125
126 ConstStrideIterator operator+(long n) const {
127 auto out_iter = *this;
128 return out_iter += n;
129 }
130
131 ConstStrideIterator operator-(long n) const {
132 auto out_iter = *this;
133 return out_iter -= n;
134 }
135};
136
137//==============================================================================
139
//! A view onto a 1D array; used for rows/columns of ND array. Can have a
//! stride.
/*! @details
  Non-owning: the view stores only a pointer, an element count and a stride.
  Element i of the view is `data[i * stride]`. The pointed-to storage must
  stay alive (and must not be reallocated) for as long as the view is used.
*/
template <typename T = double>
class ArrayView {

private:
  std::size_t m_size;   // number of elements
  std::size_t m_stride; // offset (in elements) between consecutive entries
  T *m_data;            // first element of the view

public:
  //! View of `size` elements, starting at `data`, `stride` elements apart
  ArrayView(T *data, std::size_t size, std::size_t stride = 1)
      : m_size(size), m_stride(stride), m_data(data) {}

  //! Number of elements in the view
  std::size_t size() const { return m_size; }

  //! Element access, no bounds check
  T &operator[](std::size_t i) { return m_data[i * m_stride]; }

  //! Element access (by value), no bounds check
  T operator[](std::size_t i) const { return m_data[i * m_stride]; }

  //! Element access; index is checked by assert (debug builds only)
  T &at(std::size_t i) {
    assert(i < m_size);
    return m_data[i * m_stride];
  }

  //! Element access (by value); index is checked by assert (debug builds only)
  T at(std::size_t i) const {
    assert(i < m_size);
    return m_data[i * m_stride];
  }

  //! Same as at(i)
  T &operator()(std::size_t i) { return at(i); }
  T operator()(std::size_t i) const { return at(i); }

  //! First / last element (uses at(), so an empty view trips the assert)
  T &front() { return at(0); }
  T front() const { return at(0); }
  T &back() { return at(m_size - 1); }
  T back() const { return at(m_size - 1); }

  //! Pointer to the first element. Note: elements are `stride` apart
  T *data() { return m_data; }
  const T *data() const { return m_data; }

  //! Iterator to the beginning
  auto begin() { return StrideIterator(m_data, long(m_stride)); }
  auto cbegin() const { return ConstStrideIterator(m_data, long(m_stride)); }

  //! Iterator to one stride past the last element
  auto end() {
    return StrideIterator(m_data + long(m_size * m_stride), long(m_stride));
  }
  auto cend() const {
    return ConstStrideIterator(m_data + long(m_size * m_stride),
                               long(m_stride));
  }

  //! Reverse iteration: starts on the last element, uses a negated stride
  auto rbegin() {
    return StrideIterator(m_data + long(m_size * m_stride) - long(m_stride),
                          -long(m_stride));
  }
  auto crbegin() const {
    return ConstStrideIterator(
        m_data + long(m_size * m_stride) - long(m_stride), -long(m_stride));
  }

  //! Reverse end: one stride before the first element (never dereferenced)
  auto rend() {
    return StrideIterator(m_data - long(m_stride), -long(m_stride));
  }
  auto crend() const {
    return ConstStrideIterator(m_data - long(m_stride), -long(m_stride));
  }

  //! Returns a copy of the array as a std::vector.
  /*! @details
    The element type has const removed, so a view onto const data
    (ArrayView<const T>) can also be copied out.
  */
  std::vector<std::remove_const_t<T>> vector() const {
    std::vector<std::remove_const_t<T>> out;
    out.reserve(m_size); // final size is known: single allocation
    for (std::size_t i = 0; i < m_size; ++i) {
      out.push_back(m_data[i * m_stride]);
    }
    return out;
  }
};
219
220//==============================================================================
221
//! Variadic product - helper function.
/*! @details
  Multiplies all arguments together, evaluated right-to-left:
  product(a, b, c) = a * product(b, c). Each level returns the type of its own
  first argument.
*/
template <typename T, typename... Args>
T product(T first, Args... rest) {
  if constexpr (sizeof...(Args) > 0) {
    // Reduce the tail first, then fold the leading factor in
    const auto tail = product(rest...);
    return first * tail;
  } else {
    // Single argument: nothing left to multiply
    return first;
  }
}
231
// Recursive helper for NDrange: fills in `current` one dimension at a time,
// starting from dimension `index`. Once every dimension has been assigned
// (index == N) the completed index set is appended to `result`.
// Entries are produced with the last dimension varying fastest.
template <std::size_t N>
void NDrange_impl(std::vector<std::array<std::size_t, N>> &result,
                  std::array<std::size_t, N> &current,
                  const std::array<std::size_t, N> &maxValues,
                  std::size_t index) {
  if (index != N) {
    // Try each allowed value for this dimension, recursing into the next one
    std::size_t value = 0;
    while (value < maxValues[index]) {
      current[index] = value;
      NDrange_impl<N>(result, current, maxValues, index + 1);
      ++value;
    }
  } else {
    // All N dimensions assigned: store this combination
    result.push_back(current);
  }
}
247
249
256template <typename... Args>
257auto NDrange(std::size_t first, Args... rest) {
258 constexpr std::size_t N = sizeof...(rest) + 1;
259
260 const std::array<std::size_t, N> maxValues = {
261 first, static_cast<std::size_t>(rest)...};
262
263 std::vector<std::array<std::size_t, N>> result;
264 result.reserve(product(first, static_cast<std::size_t>(rest)...));
265
266 std::array<std::size_t, N> current = {0};
267
268 NDrange_impl<N>(result, current, maxValues, 0);
269 return result;
270}
271
272//==============================================================================
273template <typename T = double>
274class Array : public Arithmetic<Array<T>>, Arithmetic2<Array<T>, T> {
275
276private:
277 // List of sizes for each array dimension
278 std::vector<std::size_t> m_sizes;
279 // Number of array dimensions
280 std::size_t m_Ndim;
281 // Cumulative sizes (used to index into data)
282 std::vector<std::size_t> m_cumulative_sizes;
283 // Total number of elements
284 std::size_t m_total_size;
285 // Raw data
286 std::vector<T> m_data;
287
288public:
291 template <typename... Args>
292 Array(std::size_t first, Args... rest);
293
295 template <typename... Args>
296 void resize(std::size_t first, Args... rest);
297
299 std::size_t size() const { return m_total_size; }
300
302 std::size_t size(std::size_t dim) const { return m_sizes.at(dim); }
303
305 std::size_t dimensions() const { return m_Ndim; }
306
308 const std::vector<std::size_t> &shape() const { return m_sizes; }
309
311 template <typename... Args>
312 T &at(std::size_t first, Args... rest);
313
315 template <typename... Args>
316 T at(std::size_t first, Args... rest) const;
317
319 template <typename... Args>
320 T &operator()(std::size_t first, Args... rest);
321
323 template <typename... Args>
324 T operator()(std::size_t first, Args... rest) const;
325
328 T *data() { return m_data.data(); }
329 const T *data() const { return m_data.data(); }
330
332 const std::vector<T> &vector() const { return m_data; }
333
335 std::size_t rows() const { return size(0); }
337 std::size_t cols() const { return size(1); }
338
340 ArrayView<T> row(std::size_t i);
341 ArrayView<const T> row(std::size_t i) const;
343 ArrayView<T> col(std::size_t j);
344 ArrayView<const T> col(std::size_t j) const;
345
347 auto begin() { return m_data.begin(); }
349 auto cbegin() const { return m_data.cbegin(); }
351 auto end() { return m_data.end(); }
353 auto cend() const { return m_data.cend(); }
354
356 auto rbegin() { return m_data.rbegin(); }
358 auto crbegin() const { return m_data.crbegin(); }
360 auto rend() { return m_data.rend(); }
362 auto crend() const { return m_data.crend(); }
363
366 Array<T> &operator+=(const Array<T> &other);
367 Array<T> &operator-=(const Array<T> &other);
368 Array<T> &operator*=(const Array<T> &other);
369 Array<T> &operator/=(const Array<T> &other);
370
372
375 Array<T> &operator+=(const T &t);
376 Array<T> &operator-=(const T &t);
377 Array<T> &operator*=(const T &t);
378 Array<T> &operator/=(const T &t);
379
380private:
381 // Calculates the cumulative_sizes array
382 std::vector<std::size_t> calc_cumulative_size() const;
383
384 // Unchecked index calculation
385 template <typename... Args>
386 std::size_t unchecked_index(std::size_t first, Args... rest) const;
387
388 // Helper for unchecked index calculation
389 template <typename... Args>
390 std::size_t unchecked_index_impl(std::size_t dim, std::size_t first,
391 Args... rest) const;
392
393 // Checked index calculation
394 template <typename... Args>
395 std::size_t checked_index(std::size_t first, Args... rest) const;
396
397 // Helper for checked index calculation
398 template <typename... Args>
399 std::size_t checked_index_impl(std::size_t dim, std::size_t first,
400 Args... rest) const;
401};
402
403//==============================================================================
404// Implementations
405//==============================================================================
406template <typename T>
407std::vector<std::size_t> Array<T>::calc_cumulative_size() const {
408 std::vector<std::size_t> cumulative_size;
409 cumulative_size.reserve(m_sizes.size());
410 for (std::size_t i = 0; i < m_Ndim; ++i) {
411 cumulative_size.push_back(std::accumulate(m_sizes.cbegin() + long(i) + 1,
412 m_sizes.cend(), 1ul,
413 std::multiplies<std::size_t>()));
414 }
415 return cumulative_size;
416}
417
418template <typename T>
419template <typename... Args>
420Array<T>::Array(std::size_t first, Args... rest)
421 : m_sizes({first, static_cast<std::size_t>(rest)...}),
422 m_Ndim(m_sizes.size()),
423 m_cumulative_sizes(calc_cumulative_size()),
424 m_total_size(std::accumulate(m_sizes.cbegin(), m_sizes.cend(), 1ul,
425 std::multiplies<std::size_t>())),
426 m_data(m_total_size) {}
427
428template <typename T>
429template <typename... Args>
430void Array<T>::resize(std::size_t first, Args... rest) {
431 m_sizes = std::vector{first, static_cast<std::size_t>(rest)...};
432 m_Ndim = m_sizes.size();
433 m_cumulative_sizes = calc_cumulative_size();
434 m_total_size = std::accumulate(m_sizes.cbegin(), m_sizes.cend(), 1ul,
435 std::multiplies<std::size_t>());
436 m_data.resize(m_total_size);
437}
438
439template <typename T>
440template <typename... Args>
441std::size_t Array<T>::unchecked_index(std::size_t first, Args... rest) const {
442 return unchecked_index_impl(0, first, rest...);
443}
444
445template <typename T>
446template <typename... Args>
447std::size_t Array<T>::unchecked_index_impl(std::size_t dim, std::size_t first,
448 Args... rest) const {
449 if constexpr (sizeof...(rest) == 0) {
450 return first * m_cumulative_sizes[dim];
451 } else {
452 return first * m_cumulative_sizes[dim] +
453 unchecked_index_impl(dim + 1, rest...);
454 }
455}
456
457template <typename T>
458template <typename... Args>
459std::size_t Array<T>::checked_index(std::size_t first, Args... rest) const {
460 assert(sizeof...(rest) + 1 == m_Ndim &&
461 "Number of arguments must match number of dimensions");
462 return checked_index_impl(0, first, rest...);
463}
464
465template <typename T>
466template <typename... Args>
467std::size_t Array<T>::checked_index_impl(std::size_t dim, std::size_t first,
468 Args... rest) const {
469 assert(first < m_sizes[dim]);
470 if constexpr (sizeof...(rest) == 0) {
471 return first * m_cumulative_sizes[dim];
472 } else {
473 return first * m_cumulative_sizes[dim] +
474 checked_index_impl(dim + 1, rest...);
475 }
476}
477
478template <typename T>
479template <typename... Args>
480T &Array<T>::at(std::size_t first, Args... rest) {
481 return m_data.at(checked_index(first, rest...));
482}
483
484template <typename T>
485template <typename... Args>
486T Array<T>::at(std::size_t first, Args... rest) const {
487 return m_data.at(checked_index(first, rest...));
488}
489
490template <typename T>
491template <typename... Args>
492T &Array<T>::operator()(std::size_t first, Args... rest) {
493 return m_data[unchecked_index(first, rest...)];
494}
495
496template <typename T>
497template <typename... Args>
498T Array<T>::operator()(std::size_t first, Args... rest) const {
499 return m_data[unchecked_index(first, rest...)];
500}
501
502template <typename T>
503Array<T> &Array<T>::operator+=(const Array<T> &other) {
504 assert(m_sizes == other.m_sizes &&
505 "Arithmetic only defined for equal-dimension arrays");
506 for (std::size_t i = 0; i < m_data.size(); ++i) {
507 this->m_data[i] += other.m_data[i];
508 }
509 return *this;
510}
511
512template <typename T>
513Array<T> &Array<T>::operator-=(const Array<T> &other) {
514 assert(m_sizes == other.m_sizes &&
515 "Arithmetic only defined for equal-dimension arrays");
516 for (std::size_t i = 0; i < m_data.size(); ++i) {
517 this->m_data[i] -= other.m_data[i];
518 }
519 return *this;
520}
521
522template <typename T>
523Array<T> &Array<T>::operator*=(const Array<T> &other) {
524 assert(m_sizes == other.m_sizes &&
525 "Arithmetic only defined for equal-dimension arrays");
526 for (std::size_t i = 0; i < m_data.size(); ++i) {
527 this->m_data[i] *= other.m_data[i];
528 }
529 return *this;
530}
531
532template <typename T>
533Array<T> &Array<T>::operator/=(const Array<T> &other) {
534 assert(m_sizes == other.m_sizes &&
535 "Arithmetic only defined for equal-dimension arrays");
536 for (std::size_t i = 0; i < m_data.size(); ++i) {
537 this->m_data[i] /= other.m_data[i];
538 }
539 return *this;
540}
541
542template <typename T>
543Array<T> &Array<T>::operator+=(const T &t) {
544 for (std::size_t i = 0; i < m_data.size(); ++i) {
545 this->m_data[i] += t;
546 }
547 return *this;
548}
549
550template <typename T>
551Array<T> &Array<T>::operator-=(const T &t) {
552 for (std::size_t i = 0; i < m_data.size(); ++i) {
553 this->m_data[i] -= t;
554 }
555 return *this;
556}
557
558template <typename T>
559Array<T> &Array<T>::operator*=(const T &t) {
560 for (std::size_t i = 0; i < m_data.size(); ++i) {
561 this->m_data[i] *= t;
562 }
563 return *this;
564}
565
566template <typename T>
567Array<T> &Array<T>::operator/=(const T &t) {
568 for (std::size_t i = 0; i < m_data.size(); ++i) {
569 this->m_data[i] /= t;
570 }
571 return *this;
572}
573
574template <typename T>
575ArrayView<T> Array<T>::row(std::size_t i) {
576 assert(m_Ndim == 2 && "Row only defined for 2D array");
577 return ArrayView(m_data.data() + i * m_sizes[1], m_sizes[1]);
578}
579template <typename T>
580ArrayView<T> Array<T>::col(std::size_t j) {
581 assert(m_Ndim == 2 && "Col only defined for 2D array");
582 return ArrayView(m_data.data() + j, m_sizes[0], m_sizes[1]);
583}
584template <typename T>
585ArrayView<const T> Array<T>::row(std::size_t i) const {
586 assert(m_Ndim == 2 && "Row only defined for 2D array");
587 return ArrayView(m_data.data() + i * m_sizes[1], m_sizes[1]);
588}
589template <typename T>
590ArrayView<const T> Array<T>::col(std::size_t j) const {
591 assert(m_Ndim == 2 && "Col only defined for 2D array");
592 return ArrayView<const T>(m_data.data() + j, m_sizes[0], m_sizes[1]);
593}
594
595} // namespace qip
A view onto a 1D array; used for rows/columns of ND array. Can have a stride.
Definition Array.hpp:145
auto begin()
Iterator to the beginning.
Definition Array.hpp:183
std::vector< T > vector()
Returns a copy of the array as a std::vector.
Definition Array.hpp:211
Helper template for comparisons. Derive from this to provide !=,>,<=,>=, given == and <.
Definition Template.hpp:32
A constant iterator accounting for a stride.
Definition Array.hpp:80
An iterator accounting for a stride.
Definition Array.hpp:19
std::vector< T > & operator+=(std::vector< T > &a, const std::vector< T > &b)
Provide addition of two vectors:
Definition Vector.hpp:454
qip library: A collection of useful functions
Definition Array.hpp:9
auto NDrange(std::size_t first, Args... rest)
Variadic array of all possible indexes.
Definition Array.hpp:257
T product(T first, Args... rest)
Variadic product - helper function.
Definition Array.hpp:224