soo
3D game math library
Loading...
Searching...
No Matches
matrix.h
1// By notchcamo.
2
3#pragma once
4
#include <algorithm>
#include <array>
#include <cmath>
#include <concepts>
#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <string>
#include <utility>

#include "soo/impl/vector3.h"
9
10namespace soo
11{
18 template <std::floating_point T, size_t Row, size_t Col>
19 requires (Row > 0 && Col > 0)
20 class Matrix
21 {
22 private:
23 using ColArray = std::array<T, Col>;
24 std::array<ColArray, Row> entries;
25
26 public:
27 // Constructors.
28
29 Matrix() : entries{0} {}
30 Matrix(const std::initializer_list<std::initializer_list<T>>& list)
31 {
32 if (list.size() != Row)
33 {
34 throw std::invalid_argument("Number of rows must be " + std::to_string(Row));
35 }
36
37 size_t row = 0;
38 for (auto rowIt = std::cbegin(list); rowIt != std::cend(list); ++rowIt, ++row)
39 {
40 if (rowIt->size() != Col)
41 {
42 throw std::invalid_argument("Number of columns must be " + std::to_string(Col));
43 }
44
45 size_t col = 0;
46 for (auto colIt = std::cbegin(*rowIt); colIt != std::cend(*rowIt); ++colIt, ++col)
47 {
48 entries[row][col] = *colIt;
49 }
50 }
51 }
52
53 Matrix(const Matrix& rhs) noexcept : entries(rhs.entries) {}
54 Matrix(Matrix&& rhs) noexcept : entries(std::move(rhs.entries)) {}
55
56 ~Matrix() noexcept = default;
57
58 // Operators.
59
60 Matrix& operator=(const Matrix& rhs) noexcept
61 {
62 if (*this != rhs)
63 {
64 entries = rhs.entries;
65 }
66
67 return *this;
68 }
69
70 Matrix& operator=(Matrix&& rhs) noexcept
71 {
72 if (*this != rhs)
73 {
74 entries = std::move(rhs.entries);
75 }
76
77 return *this;
78 }
79
80 ColArray& operator[](const size_t row) noexcept(false)
81 {
82 if (row >= Row)
83 {
84 throw std::out_of_range("Max row is " + std::to_string(Row));
85 }
86
87 return entries[row];
88 }
89
90 const ColArray& operator[](const size_t row) const noexcept(false)
91 {
92 if (row >= Row)
93 {
94 throw std::out_of_range("Max row is " + std::to_string(Row));
95 }
96
97 return entries[row];
98 }
99
100 bool operator==(const Matrix& rhs) const noexcept
101 {
102 for (int row = 0; row < Row; ++row)
103 {
104 if (!util::isEqual(entries[row], rhs.entries[row]))
105 {
106 return false;
107 }
108 }
109
110 return true;
111 }
112
113 Matrix operator+(const Matrix& rhs) const noexcept
114 {
115 Matrix newMatrix{};
116
117 newMatrix.forEachEntry([this, &rhs](T& entry, const size_t row, const size_t col)
118 {
119 entry = entries[row][col] + rhs.entries[row][col];
120 });
121
122 return newMatrix;
123 }
124
125 Matrix operator-(const Matrix& rhs) const noexcept
126 {
127 Matrix newMatrix{};
128
129 newMatrix.forEachEntry([this, &rhs](T& entry, const size_t row, const size_t col)
130 {
131 entry = entries[row][col] - rhs.entries[row][col];
132 });
133
134 return newMatrix;
135 }
136
142 template <size_t R, size_t C>
143 Matrix<T, Row, C> operator*(const Matrix<T, R, C>& rhs) const noexcept(false)
144 {
145 if (Col != R)
146 {
147 throw std::invalid_argument("The number of rows must be " + std::to_string(Col));
148 }
149
150 Matrix<T, Row, C> newMatrix{};
151
152 newMatrix.forEachEntry([this, &rhs](T& entry, const size_t row, const size_t col)
153 {
154 for (size_t c = 0; c < Col; ++c)
155 {
156 entry += entries[row][c] * rhs[c][col];
157 }
158 });
159
160 return newMatrix;
161 }
162
166 Matrix operator*(const T val) const noexcept
167 {
168 Matrix newMatrix{};
169
170 newMatrix.forEachEntry([this, val](T& entry, const size_t row, const size_t col)
171 {
172 entry = entries[row][col] * val;
173 });
174
175 return newMatrix;
176 }
177
181 Vector3<T> operator*(const Vector3<T>& rhs) const requires (Row == 4 && Col == 4)
182 {
183 const T a = entries[0][0], b = entries[0][1], c = entries[0][2], d = entries[0][3];
184 const T e = entries[1][0], f = entries[1][1], g = entries[1][2], h = entries[1][3];
185 const T i = entries[2][0], j = entries[2][1], k = entries[2][2], l = entries[2][3];
186 const T m = entries[3][0], n = entries[3][1], o = entries[3][2], p = entries[3][3];
187
188 const T divisor = m*rhs.x + n*rhs.y + o*rhs.z + p;
189
190 return {
191 (a*rhs.x + b*rhs.y + c*rhs.z + d) / divisor,
192 (e*rhs.x + f*rhs.y + g*rhs.z + h) / divisor,
193 (i*rhs.x + j*rhs.y + k*rhs.z + l) / divisor,
194 };
195 }
196
197 Matrix& operator+=(const Matrix& rhs) noexcept
198 {
199 forEachEntry([&rhs](T& entry, const size_t row, const size_t col)
200 {
201 entry += rhs.entries[row][col];
202 });
203
204 return *this;
205 }
206
207 Matrix& operator-=(const Matrix& rhs) noexcept
208 {
209 forEachEntry([&rhs](T& entry, const size_t row, const size_t col)
210 {
211 entry -= rhs.entries[row][col];
212 });
213
214 return *this;
215 }
216
217 Matrix& operator*=(const T val) noexcept
218 {
219 forEachEntry([val](T& entry, const size_t, const size_t)
220 {
221 entry *= val;
222 });
223
224 return *this;
225 }
226
227 // Methods.
228
229 constexpr static size_t getRowSize() noexcept { return Row; }
230 constexpr static size_t getColSize() noexcept { return Col; }
231
232 constexpr static Matrix createIdentity() noexcept requires (Row == Col)
233 {
234 Matrix identity{};
235
236 for (size_t i = 0; i < Row; ++i)
237 {
238 identity[i][i] = T(1);
239 }
240
241 return identity;
242 }
243
249 template <typename Func>
250 requires std::invocable<Func, T&, const size_t, const size_t>
251 void forEachEntry(const Func& action)
252 {
253 for (size_t row = 0; row < Row; ++row)
254 {
255 for (size_t col = 0; col < Col; ++col)
256 {
257 action(entries[row][col], row, col);
258 }
259 }
260 }
261
265 Matrix<T, Col, Row> transpose() const noexcept
266 {
267 Matrix<T, Col, Row> transposed{};
268
269 transposed.forEachEntry([this](T& entry, const size_t row, const size_t col)
270 {
271 entry = entries[col][row];
272 });
273
274 return transposed;
275 }
276
281 Matrix inverse2x2() const noexcept(false) requires (Row == Col && Row == 2)
282 {
283 const T a = entries[0][0], b = entries[0][1];
284 const T c = entries[1][0], d = entries[1][1];
285
286 const T det = a * d - b * c;
287
288 if (util::isZero(det))
289 {
290 throw std::domain_error("Matrix is not invertible");
291 }
292
293 const T invDet = T(1) / det;
294
295 return {
296 {d * invDet, -b * invDet},
297 {-c * invDet, a * invDet}
298 };
299 }
300
305 Matrix inverse3x3() const noexcept(false) requires (Row == Col && Row == 3)
306 {
307 const T a = entries[0][0], b = entries[0][1], c = entries[0][2];
308 const T d = entries[1][0], e = entries[1][1], f = entries[1][2];
309 const T g = entries[2][0], h = entries[2][1], i = entries[2][2];
310
311 const T det = a*(e*i - f*h) - b*(d*i - f*g) + c*(d*h - e*g);
312
313 if (util::isZero(det))
314 {
315 throw std::domain_error("Matrix is not invertible");
316 }
317
318 const T invDet = T(1) / det;
319
320 return {
321 {(e*i - f*h) * invDet, (c*h - b*i) * invDet, (b*f - c*e) * invDet},
322 {(f*g - d*i) * invDet, (a*i - c*g) * invDet, (c*d - a*f) * invDet},
323 {(d*h - e*g) * invDet, (b*g - a*h) * invDet, (a*e - b*d) * invDet},
324 };
325 }
326
332 Matrix inverseNxN() const noexcept(false) requires (Row == Col)
333 {
334 const Matrix identity = Matrix::createIdentity();
335 std::array<std::array<T, Row*2>, Row> augmented{};
336
337 for (int r = 0; r < Row; ++r)
338 {
339 for (int c = 0; c < Row; ++c)
340 {
341 augmented[r][c] = entries[r][c];
342 augmented[r][c + Row] = identity[r][c];
343 }
344 }
345
346 for (int i = 0; i < Row; ++i)
347 {
348 if (util::isZero(augmented[i][i]))
349 {
350 bool swapped = false;
351
352 for (int k = i + 1; k < Row; ++k)
353 {
354 if (!util::isZero(augmented[k][i]))
355 {
356 std::swap(augmented[i], augmented[k]);
357 swapped = true;
358 break;
359 }
360 }
361
362 if (!swapped)
363 {
364 throw std::domain_error("Matrix is not invertible");
365 }
366 }
367
368 const T invPivot = T(1) / augmented[i][i];
369
370 for (int j = i; j < Row * 2; ++j)
371 {
372 augmented[i][j] *= invPivot;
373 }
374
375 for (int k = 0; k < Row; ++k)
376 {
377 if (k != i && !util::isZero(augmented[k][i]))
378 {
379 const T factor = augmented[k][i];
380
381 for (int j = i; j < Row * 2; ++j)
382 {
383 augmented[k][j] -= factor * augmented[i][j];
384 }
385 }
386 }
387 }
388
389 Matrix inversed{};
390
391 for (int r = 0; r < Row; ++r)
392 {
393 for (int c = Row; c < 2 * Row; ++c)
394 {
395 inversed[r][c-Row] = augmented[r][c];
396 }
397 }
398
399 return inversed;
400 }
401 };
402
403 // Type aliases.
404
405 using Matrix2f = Matrix<float, 2, 2>;
406 using Matrix2d = Matrix<double, 2, 2>;
407 using Matrix2L = Matrix<long double, 2, 2>;
408 using Matrix3f = Matrix<float, 3, 3>;
409 using Matrix3d = Matrix<double, 3, 3>;
410 using Matrix3L = Matrix<long double, 3, 3>;
411 using Matrix4f = Matrix<float, 4, 4>;
412 using Matrix4d = Matrix<double, 4, 4>;
413 using Matrix4L = Matrix<long double, 4, 4>;
414} // namespace soo.
Matrix class.
Definition matrix.h:21
Matrix inverseNxN() const noexcept(false)
Definition matrix.h:332
Matrix< T, Col, Row > transpose() const noexcept
Definition matrix.h:265
Matrix inverse2x2() const noexcept(false)
Definition matrix.h:281
Matrix< T, Row, C > operator*(const Matrix< T, R, C > &rhs) const noexcept(false)
Definition matrix.h:143
Vector3< T > operator*(const Vector3< T > &rhs) const
Definition matrix.h:181
Matrix operator*(const T val) const noexcept
Definition matrix.h:166
Matrix inverse3x3() const noexcept(false)
Definition matrix.h:305
void forEachEntry(const Func &action)
Definition matrix.h:251
3D vector class.
Definition vector3.h:22
bool isZero(const T num, const T tolerance=DEFAULT_TOLERANCE)
Definition util.h:66
bool isEqual(const T a, const T b, const T tolerance=DEFAULT_TOLERANCE)
Definition util.h:23
Namespace for math-on-demand.
Definition exception.h:13