ampsci
C++ program for high-precision atomic-structure calculations of single-valence systems
Loading...
Searching...
No Matches
Matrix.ipp
1#pragma once
2
3namespace LinAlg {
4
//==============================================================================
/// Provides a "view" onto an array: `size` elements of `data`, beginning at
/// index `start` and separated by `stride`. The view does not own the data;
/// it only stores a pointer into it.
template <typename T>
class View {
  std::size_t m_size;
  std::size_t m_stride;
  T *m_data;

public:
  /// Offset is applied as std::ptrdiff_t (pointer-width). The previous
  /// `long` cast is only 32 bits on LLP64 platforms (e.g. 64-bit Windows)
  /// and would truncate large offsets.
  View(T *data, std::size_t start, std::size_t size, std::size_t stride)
      : m_size(size), m_stride(stride),
        m_data(data + static_cast<std::ptrdiff_t>(start)) {}

  /// Number of elements in the view.
  std::size_t size() const { return m_size; }

  /// [] index access (no range checking): the ith element of the view.
  T &operator[](std::size_t i) { return m_data[i * m_stride]; }
  /// As above, but const (returns a copy).
  T operator[](std::size_t i) const { return m_data[i * m_stride]; }

  /// at() index access, range-checked with assert (debug builds only).
  T &at(std::size_t i) {
    assert(i < m_size);
    return m_data[i * m_stride];
  }
  /// As above, but const (returns a copy).
  T at(std::size_t i) const {
    assert(i < m_size);
    return m_data[i * m_stride];
  }

  /// () index access; identical to at(i).
  T &operator()(std::size_t i) { return at(i); }
  /// As above, but const.
  T operator()(std::size_t i) const { return at(i); }

  /// Pointer to the first element of the view (i.e. data + start, not the
  /// beginning of the underlying array).
  T *data() { return m_data; }
  /// As above, but const.
  const T *data() const { return m_data; }
};
40
41//==============================================================================
42//==============================================================================
43//==============================================================================
44
45//==============================================================================
46// Returns the determinant. Uses GSL; via LU decomposition. Only works for
47// double/complex<double>
48template <typename T>
50 static_assert(std::is_same_v<T, double> ||
51 std::is_same_v<T, std::complex<double>>,
52 "Determinant only works for double");
53
54 assert(rows() == cols() && "Determinant only defined for square matrix");
55 // Make a copy, since this is destructive. (Performs LU decomp)
56 auto LU = *this; // will become LU decomposed version
57 int sLU = 0;
58 auto gsl_view = LU.as_gsl_view();
59 gsl_permutation *permutn = gsl_permutation_alloc(rows());
60 if constexpr (std::is_same_v<T, double>) {
61 gsl_linalg_LU_decomp(&gsl_view.matrix, permutn, &sLU);
62 gsl_permutation_free(permutn);
63 return gsl_linalg_LU_det(&gsl_view.matrix, sLU);
64 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
65 gsl_linalg_complex_LU_decomp(&gsl_view.matrix, permutn, &sLU);
66 gsl_permutation_free(permutn);
67 const auto gsl_cmplx = gsl_linalg_complex_LU_det(&gsl_view.matrix, sLU);
68 // Can probably avoid this copy? doesn't really matter.
69 return {GSL_REAL(gsl_cmplx), GSL_IMAG(gsl_cmplx)};
70 }
71}
72
73//==============================================================================
74// Inverts the matrix, in place. Uses GSL; via LU decomposition. Only works
75// for double/complex<double>.
76template <typename T>
78 static_assert(
79 std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>,
80 "invert only works for Matrix<double> or Matrix<complex<double>>");
81
82 assert(rows() == cols() && "Inverse only defined for square matrix");
83 int sLU = 0;
84 // gsl_linalg_LU_decomp(m, permutn, &sLU);
85 // gsl_linalg_LU_invx(m, permutn);
86 // In-place inversion gsl_linalg_LU_invx added sometime after GSL v:2.1
87 // Getafix only has 2.1 installed, so can't use this for now
88 auto LU = *this; // copy! to be LU decomposed
89 auto LU_gsl = LU.as_gsl_view();
90 auto iverse_gsl = this->as_gsl_view();
91 gsl_permutation *permutn = gsl_permutation_alloc(m_rows);
92 if constexpr (std::is_same_v<T, double>) {
93 gsl_linalg_LU_decomp(&LU_gsl.matrix, permutn, &sLU);
94 gsl_linalg_LU_invert(&LU_gsl.matrix, permutn, &iverse_gsl.matrix);
95 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
96 gsl_linalg_complex_LU_decomp(&LU_gsl.matrix, permutn, &sLU);
97 gsl_linalg_complex_LU_invert(&LU_gsl.matrix, permutn, &iverse_gsl.matrix);
98 }
99 gsl_permutation_free(permutn);
100 return *this;
101}
102
103template <typename T>
105 auto inverse = *this; // copy
106 return inverse.invert_in_place();
107}
108
109//==============================================================================
110template <typename T>
112 Matrix<T> Tr(m_cols, m_rows);
113 if constexpr (std::is_same_v<T, double>) {
114 auto Tr_gsl = Tr.as_gsl_view();
115 const auto this_gsl = as_gsl_view();
116 gsl_matrix_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
117 } else if constexpr (std::is_same_v<T, float>) {
118 auto Tr_gsl = Tr.as_gsl_view();
119 const auto this_gsl = as_gsl_view();
120 gsl_matrix_float_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
121 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
122 auto Tr_gsl = Tr.as_gsl_view();
123 const auto this_gsl = as_gsl_view();
124 gsl_matrix_complex_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
125 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
126 auto Tr_gsl = Tr.as_gsl_view();
127 const auto this_gsl = as_gsl_view();
128 gsl_matrix_complex_float_transpose_memcpy(&Tr_gsl.matrix, &this_gsl.matrix);
129 } else {
130 // backup, works for any type
131 for (auto i = 0ul; i < Tr.rows(); ++i) {
132 for (auto j = 0ul; j < Tr.cols(); ++j) {
133 Tr[i][j] = (*this)[j][i];
134 }
135 }
136 }
137 return Tr;
138}
139
140//==============================================================================
141// Constructs a diagonal unit matrix (identity)
142template <typename T>
144 assert(m_rows == m_cols && "Can only call make_identity() for square matrix");
145 for (auto i = 0ul; i < m_rows; ++i) {
146 for (auto j = 0ul; j < m_cols; ++j) {
147 at(i, j) = i == j ? T(1) : T(0);
148 }
149 }
150 return *this;
151}
152// Sets all elements to zero
153template <typename T>
155 for (std::size_t i = 0; i < size(); ++i) {
156 m_data[i] = T(0);
157 }
158 return *this;
160
161//==============================================================================
162template <typename T>
164 static_assert(is_complex_v<T>, "conj() only available for complex Matrix");
165 std::vector<T> conj_data;
166 conj_data.reserve(m_data.size());
167 for (std::size_t i = 0; i < m_data.size(); ++i) {
168 conj_data.push_back(std::conj(m_data[i]));
169 }
170 return Matrix<T>{m_rows, m_cols, std::move(conj_data)};
172
173template <typename T>
175 static_assert(is_complex_v<T>, "conj() only available for complex Matrix");
176 for (std::size_t i = 0; i < m_data.size(); ++i) {
177 m_data[i] = std::conj(m_data[i]);
178 }
179 return *this;
180}
181//------------------------------------------------------------------------------
182template <typename T>
183auto Matrix<T>::real() const {
184 static_assert(is_complex_v<T>, "real() only available for complex Matrix");
185 std::vector<typename T::value_type> real_data;
186 real_data.reserve(m_data.size());
187 for (std::size_t i = 0; i < m_data.size(); ++i) {
188 real_data.push_back(std::real(m_data[i]));
189 }
190 return Matrix<typename T::value_type>{m_rows, m_cols, std::move(real_data)};
191}
192//------------------------------------------------------------------------------
193template <typename T>
194auto Matrix<T>::imag() const {
195 static_assert(is_complex_v<T>, "imag() only available for complex Matrix");
196 std::vector<typename T::value_type> imag_data;
197 imag_data.reserve(m_data.size());
198 for (std::size_t i = 0; i < m_data.size(); ++i) {
199 imag_data.push_back(std::imag(m_data[i]));
201 return Matrix<typename T::value_type>{m_rows, m_cols, std::move(imag_data)};
202}
203//------------------------------------------------------------------------------
204template <typename T>
205auto Matrix<T>::complex() const {
206 static_assert(!is_complex_v<T>, "complex() only available for real Matrix");
207 // use move constructor to avoid default Matrix construction
208 std::vector<std::complex<T>> new_data;
209 new_data.reserve(m_data.size());
210 for (std::size_t i = 0; i < m_data.size(); ++i) {
211 new_data.push_back(m_data[i]);
212 }
213 return Matrix<std::complex<T>>{m_rows, m_cols, std::move(new_data)};
214}
215
216//==============================================================================
217template <typename T>
219 assert(rows() == rhs.rows() && cols() == rhs.cols() &&
220 "Matrices must have same dimensions for addition");
221 using namespace qip::overloads;
222 this->m_data += rhs.m_data;
223 return *this;
224}
225template <typename T>
227 assert(rows() == rhs.rows() && cols() == rhs.cols() &&
228 "Matrices must have same dimensions for subtraction");
229 using namespace qip::overloads;
230 this->m_data -= rhs.m_data;
231 return *this;
232}
233template <typename T>
234Matrix<T> &Matrix<T>::operator*=(const T x) {
235 using namespace qip::overloads;
236 this->m_data *= x;
237 return *this;
238}
239template <typename T>
240Matrix<T> &Matrix<T>::operator/=(const T x) {
241 using namespace qip::overloads;
242 this->m_data /= x;
243 return *this;
244}
246//==============================================================================
247// Matrix<T> += T : T assumed to be *Identity!
248template <typename T>
250 // Adds 'a' to diagonal elements (Assume a*Ident)
251 assert(m_rows == m_cols && "Can only call M+a for square matrix");
252 for (auto i = 0ul; i < m_rows; ++i) {
253 at(i, i) += aI;
254 }
255 return *this;
256}
257// Matrix<T> -= T : T assumed to be *Identity!
258template <typename T>
260 // Adds 'a' to diagonal elements (Assume a*Ident)
261 assert(m_rows == m_cols && "Can only call M-a for square matrix");
262 for (auto i = 0ul; i < m_rows; ++i) {
263 at(i, i) -= aI;
264 }
265 return *this;
266}
267
268//==============================================================================
269template <typename T>
271 assert(rows() == a.rows() && cols() == a.cols() &&
272 "Matrices must have same dimensions for mult_elements_by");
273 for (auto i = 0ul; i < m_data.size(); ++i) {
274 m_data[i] *= a.m_data[i];
275 }
276 return *this;
277}
278
279//==============================================================================
280template <typename T>
281[[nodiscard]] Matrix<T> operator*(const Matrix<T> &a, const Matrix<T> &b) {
282 // https://www.gnu.org/software/gsl/doc/html/blas.html
283 assert(a.cols() == b.rows() &&
284 "Matrices a and b must have correct dimension for multiplication");
285 Matrix<T> product(a.rows(), b.cols());
286 const auto a_gsl = a.as_gsl_view();
287 const auto b_gsl = b.as_gsl_view();
288 auto product_gsl = product.as_gsl_view();
289 if constexpr (std::is_same_v<T, double>) {
290 gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, &a_gsl.matrix,
291 &b_gsl.matrix, 0.0, &product_gsl.matrix);
292 } else if constexpr (std::is_same_v<T, float>) {
293 gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1.0f, &a_gsl.matrix,
294 &b_gsl.matrix, 0.0f, &product_gsl.matrix);
295 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
296 gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, GSL_COMPLEX_ONE, &a_gsl.matrix,
297 &b_gsl.matrix, GSL_COMPLEX_ZERO, &product_gsl.matrix);
298 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
299 const gsl_complex_float one{1.0f, 0.0f};
300 const gsl_complex_float zero{0.0f, 0.0f};
301 gsl_blas_cgemm(CblasNoTrans, CblasNoTrans, one, &a_gsl.matrix,
302 &b_gsl.matrix, zero, &product_gsl.matrix);
303 }
304
305 return product;
306}
307
308//==============================================================================
309template <typename T>
311 if constexpr (std::is_same_v<T, double>) {
312 return gsl_matrix_view_array(m_data.data(), m_rows, m_cols);
313 } else if constexpr (std::is_same_v<T, float>) {
314 return gsl_matrix_float_view_array(m_data.data(), m_rows, m_cols);
315 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
316 // reinterpret_cast OK: cppreference.com/w/cpp/numeric/complex
317 return gsl_matrix_complex_view_array(
318 reinterpret_cast<double *>(m_data.data()), m_rows, m_cols);
319 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
320 return gsl_matrix_complex_float_view_array(
321 reinterpret_cast<float *>(m_data.data()), m_rows, m_cols);
322 } else {
323 assert(false && "as_gsl_view() only available for double/float (or complex "
324 "double/float)");
325 }
326}
327
328template <typename T>
330 if constexpr (std::is_same_v<T, double>) {
331 return gsl_matrix_const_view_array(m_data.data(), m_rows, m_cols);
332 } else if constexpr (std::is_same_v<T, float>) {
333 return gsl_matrix_float_const_view_array(m_data.data(), m_rows, m_cols);
334 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
335 return gsl_matrix_complex_const_view_array(
336 reinterpret_cast<const double *>(m_data.data()), m_rows, m_cols);
337 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
338 return gsl_matrix_complex_float_const_view_array(
339 reinterpret_cast<const float *>(m_data.data()), m_rows, m_cols);
340 } else {
341 assert(false && "as_gsl_view() only for available double/float (or complex "
342 "double/float)");
343 }
344}
345
346//==============================================================================
347template <typename T>
348std::ostream &operator<<(std::ostream &os, const Matrix<T> &a) {
349 for (auto i = 0ul; i < a.rows(); ++i) {
350 for (auto j = 0ul; j < a.cols(); ++j) {
351 os << a(i, j) << " ";
352 }
353 os << "\n";
354 }
355 os << "\n";
356 return os;
357}
358
359//==============================================================================
360//==============================================================================
361//==============================================================================
362
//==============================================================================
/// Helper for equal(): default relative tolerance for element type T.
/// float / complex<float>   -> 1.0e-6f (float)
/// double / complex<double> -> 1.0e-12 (double)
/// any other type           -> 0 (int), i.e. no tolerance.
template <typename T>
constexpr auto myEps() {
  constexpr bool single_precision =
      std::is_same_v<T, float> || std::is_same_v<T, std::complex<float>>;
  constexpr bool double_precision =
      std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>;
  if constexpr (single_precision) {
    return 1.0e-6f;
  } else if constexpr (double_precision) {
    return 1.0e-12;
  } else {
    return 0;
  }
}
377
378// Compares two matrices; returns true iff all elements compare relatively to
379// better than eps
380template <typename T>
381bool equal(const Matrix<T> &lhs, const Matrix<T> &rhs, T eps) {
382 if (lhs.rows() != rhs.rows())
383 return false;
384 if (lhs.cols() != rhs.cols())
385 return false;
386 for (auto i = 0ul; i < lhs.rows(); ++i) {
387 for (auto j = 0ul; j < lhs.cols(); ++j) {
388 // need abs on eps in case of complex
389 if (std::abs(lhs(i, j) - rhs(i, j)) >
390 std::abs(eps * (lhs(i, j) + rhs(i, j))))
391 return false;
392 }
393 }
394 return true;
395}
396
397} // namespace LinAlg
Matrix class; row-major.
Definition Matrix.hpp:35
Matrix< T > & make_identity()
Constructs a diagonal unit matrix (identity), in place; only for square.
Definition Matrix.ipp:143
Matrix< T > & invert_in_place()
Inverts the matrix, in place. Uses GSL; via LU decomposition. Only works for double/complex<double>.
Definition Matrix.ipp:77
std::size_t rows() const
Return rows [major index size].
Definition Matrix.hpp:89
Matrix< T > conj() const
Returns conjugate of matrix.
Definition Matrix.ipp:163
Matrix< T > & operator+=(const Matrix< T > &rhs)
Overloads the standard operators; they do what is expected.
Definition Matrix.ipp:218
auto complex() const
Converts a real to complex matrix (changes type; returns a complex matrix)
Definition Matrix.ipp:205
auto imag() const
Returns imag part of complex matrix (changes type; returns a real matrix)
Definition Matrix.ipp:194
auto as_gsl_view()
Returns gsl_matrix_view (or _float_view, _complex_view, _complex_float_view). Call ....
Definition Matrix.ipp:310
Matrix< T > & conj_in_place()
Conjugates matrix, in place.
Definition Matrix.ipp:174
Matrix< T > transpose() const
Returns transpose of matrix.
Definition Matrix.ipp:111
Matrix< T > & zero()
Sets all elements to zero, in place.
Definition Matrix.ipp:154
Matrix< T > inverse() const
Returns inverse of the matrix. Leaves original matrix intact. Uses GSL; via LU decomposition....
Definition Matrix.ipp:104
std::size_t cols() const
Return columns [minor index size].
Definition Matrix.hpp:91
T determinant() const
Returns the determinant. Uses GSL; via LU decomposition. Only works for double/complex<double>
Definition Matrix.ipp:49
Matrix< T > & mult_elements_by(const Matrix< T > &a)
Multiplies each element by the corresponding element of matrix a, in place: M_ij *= a_ij.
Definition Matrix.ipp:270
auto real() const
Returns real part of complex matrix (changes type; returns a real matrix)
Definition Matrix.ipp:183
Provides a "view" onto an array.
Definition Matrix.ipp:7
T operator[](std::size_t i) const
As above, but const.
Definition Matrix.ipp:21
T operator()(std::size_t i) const
As above, but const.
Definition Matrix.ipp:36
T & operator[](std::size_t i)
[] index access (with no range checking). [i][j] returns ith row, jth col
Definition Matrix.ipp:19
T at(std::size_t i) const
As above, but const.
Definition Matrix.ipp:29
T & at(std::size_t i)
() index access (with range checking). (i,j) returns ith row, jth col
Definition Matrix.ipp:24
T & operator()(std::size_t i)
() index access (with range checking). (i,j) returns ith row, jth col
Definition Matrix.ipp:34
Defines the Matrix and Vector classes, and some linear algebra functions.
Definition Matrix.hpp:26
bool equal(const Matrix< T > &lhs, const Matrix< T > &rhs, T eps=myEps< T >())
Compares two matrices; returns true iff all elements compare relatively to better than eps.
Definition Matrix.ipp:381
namespace qip::overloads provides operator overloads for std::vector
Definition Vector.hpp:450
T product(T first, Args... rest)
Variadic product - helper function.
Definition Array.hpp:224