ampsci
High-precision calculations for one- and two-valence atomic systems
Matrix.ipp
1#pragma once
2
3namespace LinAlg {
4
5//==============================================================================
6// Returns the determinant. Uses GSL; via LU decomposition. Only works for
7// double/complex<double>
8template <typename T>
10 static_assert(std::is_same_v<T, double> ||
11 std::is_same_v<T, std::complex<double>>,
12 "Determinant only works for double");
13
14 assert(rows() == cols() && "Determinant only defined for square matrix");
15 // Make a copy, since this is destructive. (Performs LU decomp)
16 auto LU = *this; // will become LU decomposed version
17 int sLU = 0;
18 auto gsl_view = LU.as_gsl_view();
19 gsl_permutation *permutn = gsl_permutation_alloc(rows());
20 if constexpr (std::is_same_v<T, double>) {
21 gsl_linalg_LU_decomp(&gsl_view.matrix, permutn, &sLU);
22 gsl_permutation_free(permutn);
23 return gsl_linalg_LU_det(&gsl_view.matrix, sLU);
24 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
25 gsl_linalg_complex_LU_decomp(&gsl_view.matrix, permutn, &sLU);
26 gsl_permutation_free(permutn);
27 const auto gsl_cmplx = gsl_linalg_complex_LU_det(&gsl_view.matrix, sLU);
28 // Can probably avoid this copy? doesn't really matter.
29 return {GSL_REAL(gsl_cmplx), GSL_IMAG(gsl_cmplx)};
30 }
31}
32
33//==============================================================================
34// Inverts the matrix, in place. Uses GSL; via LU decomposition. Only works
35// for double/complex<double>.
36template <typename T>
38 static_assert(
39 std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>,
40 "invert only works for Matrix<double> or Matrix<complex<double>>");
41
42 assert(rows() == cols() && "Inverse only defined for square matrix");
43 int sLU = 0;
44 // gsl_linalg_LU_decomp(m, permutn, &sLU);
45 // gsl_linalg_LU_invx(m, permutn);
46 // In-place inversion gsl_linalg_LU_invx added sometime after GSL v:2.1
47 // Getafix only has 2.1 installed, so can't use this for now
48 auto LU = *this; // copy! to be LU decomposed
49 auto LU_gsl = LU.as_gsl_view();
50 auto iverse_gsl = this->as_gsl_view();
51 gsl_permutation *permutn = gsl_permutation_alloc(m_rows);
52 if constexpr (std::is_same_v<T, double>) {
53 gsl_linalg_LU_decomp(&LU_gsl.matrix, permutn, &sLU);
54 gsl_linalg_LU_invert(&LU_gsl.matrix, permutn, &iverse_gsl.matrix);
55 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
56 gsl_linalg_complex_LU_decomp(&LU_gsl.matrix, permutn, &sLU);
57 gsl_linalg_complex_LU_invert(&LU_gsl.matrix, permutn, &iverse_gsl.matrix);
58 }
59 gsl_permutation_free(permutn);
60 return *this;
61}
62
63//==============================================================================
64template <typename T>
66 Matrix<T> Tr(m_cols, m_rows);
67 if constexpr (std::is_same_v<T, double>) {
68 auto Tr_gsl = Tr.as_gsl_view();
69 const auto this_gsl = as_gsl_view();
70 gsl_matrix_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
71 } else if constexpr (std::is_same_v<T, float>) {
72 auto Tr_gsl = Tr.as_gsl_view();
73 const auto this_gsl = as_gsl_view();
74 gsl_matrix_float_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
75 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
76 auto Tr_gsl = Tr.as_gsl_view();
77 const auto this_gsl = as_gsl_view();
78 gsl_matrix_complex_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
79 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
80 auto Tr_gsl = Tr.as_gsl_view();
81 const auto this_gsl = as_gsl_view();
82 gsl_matrix_complex_float_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
83 } else {
84 // backup, works for any type
85 for (auto i = 0ul; i < Tr.rows(); ++i) {
86 for (auto j = 0ul; j < Tr.cols(); ++j) {
87 Tr[i][j] = (*this)[j][i];
88 }
89 }
90 }
91 return Tr;
92}
93
94//==============================================================================
95// Constructs a diagonal unit matrix (identity)
96template <typename T>
98 assert(m_rows == m_cols && "Can only call make_identity() for square matrix");
99 for (auto i = 0ul; i < m_rows; ++i) {
100 for (auto j = 0ul; j < m_cols; ++j) {
101 at(i, j) = i == j ? T(1) : T(0);
102 }
103 }
104 return *this;
105}
106// Sets all elements to zero
107template <typename T>
109 for (std::size_t i = 0; i < size(); ++i) {
110 m_data[i] = T(0);
111 }
112 return *this;
113}
114
115//==============================================================================
116template <typename T>
118 static_assert(is_complex_v<T>, "conj() only available for complex Matrix");
119 std::vector<T> conj_data;
120 conj_data.reserve(m_data.size());
121 for (std::size_t i = 0; i < m_data.size(); ++i) {
122 conj_data.push_back(std::conj(m_data[i]));
123 }
124 return Matrix<T>{m_rows, m_cols, std::move(conj_data)};
125}
126
127template <typename T>
129 static_assert(is_complex_v<T>, "conj() only available for complex Matrix");
130 for (std::size_t i = 0; i < m_data.size(); ++i) {
131 m_data[i] = std::conj(m_data[i]);
132 }
133 return *this;
134}
135//------------------------------------------------------------------------------
136template <typename T>
137auto Matrix<T>::real() const {
138 static_assert(is_complex_v<T>, "real() only available for complex Matrix");
139 std::vector<typename T::value_type> real_data;
140 real_data.reserve(m_data.size());
141 for (std::size_t i = 0; i < m_data.size(); ++i) {
142 real_data.push_back(std::real(m_data[i]));
143 }
144 return Matrix<typename T::value_type>{m_rows, m_cols, std::move(real_data)};
145}
146//------------------------------------------------------------------------------
147template <typename T>
148auto Matrix<T>::imag() const {
149 static_assert(is_complex_v<T>, "imag() only available for complex Matrix");
150 std::vector<typename T::value_type> imag_data;
151 imag_data.reserve(m_data.size());
152 for (std::size_t i = 0; i < m_data.size(); ++i) {
153 imag_data.push_back(std::imag(m_data[i]));
154 }
155 return Matrix<typename T::value_type>{m_rows, m_cols, std::move(imag_data)};
156}
157//------------------------------------------------------------------------------
158template <typename T>
159auto Matrix<T>::complex() const {
160 static_assert(!is_complex_v<T>, "complex() only available for real Matrix");
161 // use move constructor to avoid default Matrix construction
162 std::vector<std::complex<T>> new_data;
163 new_data.reserve(m_data.size());
164 for (std::size_t i = 0; i < m_data.size(); ++i) {
165 new_data.push_back(m_data[i]);
166 }
167 return Matrix<std::complex<T>>{m_rows, m_cols, std::move(new_data)};
168}
169
170//==============================================================================
171template <typename T>
173 assert(rows() == rhs.rows() && cols() == rhs.cols() &&
174 "Matrices must have same dimensions for addition");
175 using namespace qip::overloads;
176 this->m_data += rhs.m_data;
177 return *this;
178}
179template <typename T>
181 assert(rows() == rhs.rows() && cols() == rhs.cols() &&
182 "Matrices must have same dimensions for subtraction");
183 using namespace qip::overloads;
184 this->m_data -= rhs.m_data;
185 return *this;
186}
187template <typename T>
189 using namespace qip::overloads;
190 this->m_data *= x;
191 return *this;
192}
193template <typename T>
195 using namespace qip::overloads;
196 this->m_data /= x;
197 return *this;
198}
199
200//==============================================================================
201// Matrix<T> += T : T assumed to be *Identity!
202template <typename T>
204 // Adds 'a' to diagonal elements (Assume a*Ident)
205 assert(m_rows == m_cols && "Can only call M+a for square matrix");
206 for (auto i = 0ul; i < m_rows; ++i) {
207 at(i, i) += aI;
208 }
209 return *this;
210}
211// Matrix<T> -= T : T assumed to be *Identity!
212template <typename T>
214 // Adds 'a' to diagonal elements (Assume a*Ident)
215 assert(m_rows == m_cols && "Can only call M-a for square matrix");
216 for (auto i = 0ul; i < m_rows; ++i) {
217 at(i, i) -= aI;
218 }
219 return *this;
220}
221
222//==============================================================================
223template <typename T>
225 assert(rows() == a.rows() && cols() == a.cols() &&
226 "Matrices must have same dimensions for mult_elements_by");
227 for (auto i = 0ul; i < m_data.size(); ++i) {
228 m_data[i] *= a.m_data[i];
229 }
230 return *this;
231}
232
233//==============================================================================
234template <typename T>
235[[nodiscard]] Matrix<T> operator*(const Matrix<T> &a, const Matrix<T> &b) {
236 // https://www.gnu.org/software/gsl/doc/html/blas.html
237 assert(a.cols() == b.rows() &&
238 "Matrices a and b must have correct dimension for multiplication");
239 Matrix<T> product(a.rows(), b.cols());
240
241 GEMM(a, b, &product);
242
243 return product;
244}
245
246//==============================================================================
247template <typename T>
248void GEMM(const Matrix<T> &a, const Matrix<T> &b, Matrix<T> *c, bool trans_A,
249 bool trans_B) {
250 assert(c);
251
252 const auto ta = to_cblas_trans(trans_A);
253 const auto tb = to_cblas_trans(trans_B);
254
255 // Effective dimensions:
256 // op(A): (trans_A ? a.cols x a.rows : a.rows x a.cols)
257 // op(B): (trans_B ? b.cols x b.rows : b.rows x b.cols)
258 const int A_rows = static_cast<int>(trans_A ? a.cols() : a.rows());
259 const int A_cols = static_cast<int>(trans_A ? a.rows() : a.cols());
260 const int B_rows = static_cast<int>(trans_B ? b.cols() : b.rows());
261 const int B_cols = static_cast<int>(trans_B ? b.rows() : b.cols());
262
263 // GEMM sizes: C = op(A) * op(B), where
264 // M = rows(op(A)), N = cols(op(B)), K = cols(op(A)) = rows(op(B))
265 const int M = A_rows;
266 const int N = B_cols;
267 const int K = A_cols;
268
269 assert(A_cols == B_rows && "op(A) cols must equal op(B) rows");
270 assert(static_cast<int>(c->rows()) == M && static_cast<int>(c->cols()) == N &&
271 "Output matrix c must be sized MxN");
272
273 // Row-major leading dimensions:
274 // lda = number of columns in A's *storage* (i.e., a.cols()) regardless of trans
275 // same for b, c
276 const int lda = static_cast<int>(a.cols());
277 const int ldb = static_cast<int>(b.cols());
278 const int ldc = static_cast<int>(c->cols());
279
280 if constexpr (std::is_same_v<T, double>) {
281 cblas_dgemm(CblasRowMajor, ta, tb, M, N, K, 1.0, a.data(), lda, b.data(),
282 ldb, 0.0, c->data(), ldc);
283
284 } else if constexpr (std::is_same_v<T, float>) {
285 cblas_sgemm(CblasRowMajor, ta, tb, M, N, K, 1.0f, a.data(), lda, b.data(),
286 ldb, 0.0f, c->data(), ldc);
287
288 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
289 const std::complex<double> alpha{1.0, 0.0};
290 const std::complex<double> beta{0.0, 0.0};
291 cblas_zgemm(CblasRowMajor, ta, tb, M, N, K, &alpha, a.data(), lda, b.data(),
292 ldb, &beta, c->data(), ldc);
293
294 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
295 const std::complex<float> alpha{1.0f, 0.0f};
296 const std::complex<float> beta{0.0f, 0.0f};
297 cblas_cgemm(CblasRowMajor, ta, tb, M, N, K, &alpha, a.data(), lda, b.data(),
298 ldb, &beta, c->data(), ldc);
299
300 } else {
301 static_assert(!sizeof(T), "GEMM: unsupported scalar type");
302 }
303}
304
305//==============================================================================
306// M_ab = A_ai B_aj C_ij D_ib E_jb, using BLAS
307template <typename T>
308void PENTA_GEMM(const Matrix<T> &A, const Matrix<T> &B, const Matrix<T> &C,
309 const Matrix<T> &D, const Matrix<T> &E, Matrix<T> *pM) {
310 //
311 const auto N = A.rows(); // assume all square
312 assert(A.cols() == A.rows() && "Must be square");
313
314 Matrix<T> X(N, N);
315 Matrix<T> Y(N, N);
316 auto &M = *pM;
317
318 // M_ab = A_ai B_aj C_ij D_ib E_jb
319 // = A_ai B_aj X(i)_jb D_ib
320 // = A_ai Y(i)_ab D_ib
321 // X(i)_jb = C_ij * E_j2;
322 // Y(i)_aj = B_ij * X(i)_jb
323
324 for (std::size_t i = 0; i < N; ++i) {
325 for (std::size_t j = 0; j < N; ++j) {
326 const auto cij = C[i][j];
327 for (std::size_t b = 0; b < N; ++b) {
328 X[j][b] = cij * E[j][b];
329 }
330 }
331 GEMM(B, X, &Y);
332 for (std::size_t a = 0; a < N; ++a) {
333 for (std::size_t b = 0; b < N; ++b) {
334 M[a][b] += A[a][i] * Y[a][b] * D[i][b];
335 }
336 }
337 }
338}
339
340// M_ab = A_ai B_aj C_ij D_ib E_jb
341template <typename T, bool PARALLEL>
342void PENTA(const Matrix<T> &A, const Matrix<T> &B, const Matrix<T> &C,
343 const Matrix<T> &D, const Matrix<T> &E, Matrix<T> *pM) {
344 //
345 const auto N = A.rows(); // assume all square
346 assert(A.cols() == A.rows() && "Must be square");
347
348 auto &M = *pM;
349
350 // M_ab = A_ai B_aj C_ij D_ib E_jb
351 if constexpr (PARALLEL) {
352
353#pragma omp parallel for collapse(2)
354 for (std::size_t a = 0; a < N; ++a) {
355 for (std::size_t b = 0; b < N; ++b) {
356 const T *Ba = &B[a][0];
357 T Mab = T(0);
358 for (std::size_t i = 0; i < N; ++i) {
359 const auto AaiDib = A[a][i] * D[i][b];
360 const T *Ci = &C[i][0];
361 for (std::size_t j = 0; j < N; ++j) {
362 Mab += AaiDib * Ba[j] * Ci[j] * E[j][b];
363 }
364 }
365 M[a][b] = Mab;
366 }
367 }
368
369 } else {
370
371 for (std::size_t a = 0; a < N; ++a) {
372 const T *Ba = &B[a][0];
373 for (std::size_t b = 0; b < N; ++b) {
374 T Mab = T(0);
375 for (std::size_t i = 0; i < N; ++i) {
376 const auto AaiDib = A[a][i] * D[i][b];
377 const T *Ci = &C[i][0];
378 for (std::size_t j = 0; j < N; ++j) {
379 Mab += AaiDib * Ba[j] * Ci[j] * E[j][b];
380 }
382 M[a][b] = Mab;
383 }
385 }
386}
387
388//==============================================================================
389template <typename T>
391 if constexpr (std::is_same_v<T, double>) {
392 return gsl_matrix_view_array(m_data.data(), m_rows, m_cols);
393 } else if constexpr (std::is_same_v<T, float>) {
394 return gsl_matrix_float_view_array(m_data.data(), m_rows, m_cols);
395 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
396 // reinterpret_cast OK: cppreference.com/w/cpp/numeric/complex
397 return gsl_matrix_complex_view_array(
398 reinterpret_cast<double *>(m_data.data()), m_rows, m_cols);
399 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
400 return gsl_matrix_complex_float_view_array(
401 reinterpret_cast<float *>(m_data.data()), m_rows, m_cols);
402 } else {
403 assert(false && "as_gsl_view() only available for double/float (or complex "
404 "double/float)");
405 }
406}
407
408template <typename T>
410 if constexpr (std::is_same_v<T, double>) {
411 return gsl_matrix_const_view_array(m_data.data(), m_rows, m_cols);
412 } else if constexpr (std::is_same_v<T, float>) {
413 return gsl_matrix_float_const_view_array(m_data.data(), m_rows, m_cols);
414 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
415 return gsl_matrix_complex_const_view_array(
416 reinterpret_cast<const double *>(m_data.data()), m_rows, m_cols);
417 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
418 return gsl_matrix_complex_float_const_view_array(
419 reinterpret_cast<const float *>(m_data.data()), m_rows, m_cols);
420 } else {
421 assert(false && "as_gsl_view() only for available double/float (or complex "
422 "double/float)");
423 }
424}
425
426//==============================================================================
427template <typename T>
428std::ostream &operator<<(std::ostream &os, const Matrix<T> &a) {
429 for (auto i = 0ul; i < a.rows(); ++i) {
430 for (auto j = 0ul; j < a.cols(); ++j) {
431 os << a(i, j) << " ";
432 }
433 os << "\n";
434 }
435 os << "\n";
436 return os;
438
439//==============================================================================
440//==============================================================================
441//==============================================================================
443//==============================================================================
444// Helper for equal()
445template <typename T>
446constexpr auto myEps() {
447 if constexpr (std::is_same_v<T, float> ||
448 std::is_same_v<T, std::complex<float>>) {
449 return 1.0e-6f;
450 } else if constexpr (std::is_same_v<T, double> ||
451 std::is_same_v<T, std::complex<double>>) {
452 return 1.0e-12;
453 } else {
454 return 0;
455 }
457
458// Compares two matrices; returns true iff all elements compare relatively to
459// better than eps
460template <typename T>
461bool equal(const Matrix<T> &lhs, const Matrix<T> &rhs, T eps) {
462 if (lhs.rows() != rhs.rows())
463 return false;
464 if (lhs.cols() != rhs.cols())
465 return false;
466 for (auto i = 0ul; i < lhs.rows(); ++i) {
467 for (auto j = 0ul; j < lhs.cols(); ++j) {
468 // need abs on eps in case of complex
469 if (std::abs(lhs(i, j) - rhs(i, j)) >
470 std::abs(eps * (lhs(i, j) + rhs(i, j))))
471 return false;
473 }
474 return true;
475}
476
477} // namespace LinAlg
Row-major dense matrix with arithmetic and linear algebra support.
Definition Matrix.hpp:209
Matrix< T > & make_identity()
Constructs a diagonal unit matrix (identity), in place; only for square.
Definition Matrix.ipp:97
Matrix< T > & invert_in_place()
Inverts the matrix in place via LU decomposition (GSL).
Definition Matrix.ipp:37
T * data()
Pointer to first element; for std::complex<T> this is complex<T>*, not T*.
Definition Matrix.hpp:292
std::size_t rows() const
Return rows [major index size].
Definition Matrix.hpp:282
Matrix< T > conj() const
Returns conjugate of matrix.
Definition Matrix.ipp:117
Matrix< T > & operator+=(const Matrix< T > &rhs)
In-place elementwise addition; dimensions must match.
Definition Matrix.ipp:172
auto complex() const
Converts a real to complex matrix (changes type; returns a complex matrix)
Definition Matrix.ipp:159
auto imag() const
Returns imag part of complex matrix (changes type; returns a real matrix)
Definition Matrix.ipp:148
auto as_gsl_view()
Returns a GSL matrix view for use with GSL functions (no copy).
Definition Matrix.ipp:390
Matrix< T > & conj_in_place()
Conjugates matrix, in place.
Definition Matrix.ipp:128
Matrix< T > & operator/=(const T x)
In-place scalar divide: M_ij /= x.
Definition Matrix.ipp:194
Matrix< T > transpose() const
Returns the transpose of the matrix.
Definition Matrix.ipp:65
Matrix< T > & zero()
Sets all elements to zero, in place.
Definition Matrix.ipp:108
Matrix< T > & operator*=(const T x)
In-place scalar multiply: M_ij *= x.
Definition Matrix.ipp:188
std::size_t cols() const
Return columns [minor index size].
Definition Matrix.hpp:284
T determinant() const
Returns the determinant via LU decomposition (GSL).
Definition Matrix.ipp:9
Matrix< T > & mult_elements_by(const Matrix< T > &a)
Elementwise multiply in place: M_ij *= a_ij.
Definition Matrix.ipp:224
auto real() const
Returns real part of complex matrix (changes type; returns a real matrix)
Definition Matrix.ipp:137
Matrix< T > & operator-=(const Matrix< T > &rhs)
In-place elementwise subtraction; dimensions must match.
Definition Matrix.ipp:180
Linear algebra: matrices, vectors, views, and solvers.
Definition Matrix.hpp:54
void PENTA(const Matrix< T > &A, const Matrix< T > &B, const Matrix< T > &C, const Matrix< T > &D, const Matrix< T > &E, Matrix< T > *M)
5-matrix contraction for N*N matrices: M_ab = A_ai B_aj C_ij D_ib E_jb, without BLAS,...
Definition Matrix.ipp:342
void PENTA_GEMM(const Matrix< T > &A, const Matrix< T > &B, const Matrix< T > &C, const Matrix< T > &D, const Matrix< T > &E, Matrix< T > *M)
5-matrix contraction for N*N matrices: M_ab = A_ai B_aj C_ij D_ib E_jb, with BLAS
Definition Matrix.ipp:308
void GEMM(const Matrix< T > &a, const Matrix< T > &b, Matrix< T > *c, bool trans_A=false, bool trans_B=false)
Matrix multiplication C = op(A) * op(B) via CBLAS GEMM (row-major).
Definition Matrix.ipp:248
CBLAS_TRANSPOSE to_cblas_trans(bool trans)
Converts bool to CBLAS_TRANSPOSE enum (CblasTrans if true, CblasNoTrans if false)
Definition Matrix.hpp:547
constexpr auto myEps()
Default relative tolerance for equal(): 1e-6 for float, 1e-12 for double.
Definition Matrix.ipp:446
bool equal(const Matrix< T > &lhs, const Matrix< T > &rhs, T eps=myEps< T >())
Compares two matrices element-wise to within a relative tolerance.
Definition Matrix.ipp:461
Operator overloads for std::vector.
Definition Vector.hpp:503