1 /**
2  * Type and dimension generic Matrix type.
3  */
4 module bettercmath.matrix;
5 
6 @safe @nogc pure nothrow:
7 
version (unittest)
{
    // Short type names shared by the unittest blocks of this module.
    import bettercmath.vector;

    // Float vectors, named by their dimension.
    private alias Vec2 = Vector!(float, 2),
                  Vec3 = Vector!(float, 3);

    // Float matrices, named MatCR for C columns and R rows
    // (a single digit means a square matrix).
    private alias Mat2 = Matrix!(float, 2),
                  Mat23 = Matrix!(float, 2, 3),
                  Mat32 = Matrix!(float, 3, 2),
                  Mat3 = Matrix!(float, 3),
                  Mat34 = Matrix!(float, 3, 4),
                  Mat43 = Matrix!(float, 4, 3),
                  Mat4 = Matrix!(float, 4);
}
21 
22 /**
23  * Column-major 2D matrix type.
24  */
25 struct Matrix(T, uint numColumns, uint numRows = numColumns)
26 if (numColumns > 0 && numRows > 0)
27 {
28     import std.algorithm : min;
29     /// Alias for Matrix element type.
30     alias ElementType = T;
31     /// Number of elements in each row, same as the number of columns.
32     enum rowSize = numColumns;
33     /// Number of elements in each column, same as the number of rows.
34     enum columnSize = numRows;
35     /// Minimum dimension between number of rows and number of columns.
36     enum minDimension = min(rowSize, columnSize);
37     /// Total number of elements.
38     enum numElements = rowSize * columnSize;
39     /// Whether matrix is square or not.
40     enum isSquare = rowSize == columnSize;
41 
42     /// Matrix underlying elements.
43     T[numElements] elements = 0;
44 
45     /// Constructs a Matrix specifying all elements.
46     this()(const auto ref T[numElements] values)
47     {
48         this.elements = values;
49     }
50     /// Constructs a Matrix specifying the diagonal value.
51     this(const T diag)
52     {
53         foreach (i; 0 .. minDimension)
54         {
55             this[i, i] = diag;
56         }
57     }
58 
59     /// Copy Matrix values into `target` Matrix of any dimensions.
60     /// If dimensions are not the same, the values at non-overlapping indices
61     /// are ignored.
62     auto ref copyInto(uint C, uint R)(ref return Matrix!(T, C, R) target) const
63     {
64         // If matrices have the same column size, underlying array may be copied at once
65         static if (this.columnSize == target.columnSize)
66         {
67             enum copySize = min(this.numElements, target.numElements);
68             target.elements[0 .. copySize] = this.elements[0 .. copySize];
69         }
70         else
71         {
72             enum columnCopySize = min(this.columnSize, target.columnSize);
73             enum rowCopySize = min(this.rowSize, target.rowSize);
74             foreach (i; 0 .. rowCopySize)
75             {
76                 target[i][0 .. columnCopySize] = this[i][0 .. columnCopySize];
77             }
78         }
79         return target;
80     }
81 
82     /// Returns a copy of Matrix, adjusting dimensions as necessary.
83     /// Non-overlapping indices will stay initialized to 0.
84     U opCast(U : Matrix!(T, C, R), uint C, uint R)() const
85     {
86         typeof(return) result;
87         return copyInto(result);
88     }
89 
90     /// Returns a Range of all columns.
91     auto columns()
92     {
93         import std.range : chunks;
94         return elements[].chunks(columnSize);
95     }
96     /// Returns a Range of all columns.
97     auto columns() const
98     {
99         import std.range : chunks;
100         return elements[].chunks(columnSize);
101     }
102     /// Returns a Range of all rows.
103     auto rows()
104     {
105         import std.range : lockstep, StoppingPolicy;
106         return columns.lockstep(StoppingPolicy.requireSameLength);
107     }
108     /// Returns a Range of all rows.
109     auto rows() const
110     {
111         import std.range : lockstep, StoppingPolicy;
112         return columns.lockstep(StoppingPolicy.requireSameLength);
113     }
114     
115     /// Index a column.
116     inout(T)[] opIndex(size_t i) inout
117     in { assert(i < rowSize, "Index out of bounds"); }
118     do
119     {
120         auto initialIndex = i * columnSize;
121         return elements[initialIndex .. initialIndex + columnSize];
122     }
123     /// Index an element directly.
124     /// Params:
125     ///   i = column index
126     ///   j = row index
127     ref inout(T) opIndex(size_t i, size_t j) inout
128     in { assert(i < rowSize && j < columnSize, "Index out of bounds"); }
129     do
130     {
131         return elements[i*columnSize + j];
132     }
133 
134     /// Row size
135     enum opDollar(size_t pos : 0) = rowSize;
136     /// Column size
137     enum opDollar(size_t pos : 1) = columnSize;
138 
139     /// Constructs a Matrix from all elements in column-major format.
140     static Matrix fromColumns(Args...)(const auto ref Args args)
141     if (args.length == numElements)
142     {
143         return Matrix([args]);
144     }
145     /// Constructs a Matrix from an array of all elements in column-major format.
146     static Matrix fromColumns()(const auto ref T[numElements] elements)
147     {
148         return Matrix(elements);
149     }
150     /// Constructs a Matrix from a 2D array of columns.
151     static Matrix fromColumns()(const auto ref T[rowSize][columnSize] columns)
152     {
153         return Matrix(cast(T[numElements]) columns);
154     }
155 
156     /// Constructs a Matrix from row-major format
157     static Matrix fromRows(Args...)(const auto ref Args args)
158     {
159         return Matrix!(T, columnSize, rowSize).fromColumns(args).transposed;
160     }
161 
162     /// Constructs a Matrix with all diagonal values equal to `diag` and all others equal to 0.
163     static Matrix fromDiagonal(const T diag)
164     {
165         return Matrix(diag);
166     }
167     /// Constructs a Matrix with diagonal values from `diag` and all others equal to 0.
168     static Matrix fromDiagonal(uint N)(const auto ref T[N] diag)
169     if (N <= minDimension)
170     {
171         Matrix mat;
172         foreach (i; 0 .. N)
173         {
174             mat[i, i] = diag[i];
175         }
176         return mat;
177     }
178 
179     /// Returns the result of multiplying `vec` by Matrix.
180     /// If matrix is not square, the resulting array dimension will be different from input.
181     T[columnSize] opBinary(string op : "*")(const auto ref T[rowSize] vec) const
182     {
183         typeof(return) result;
184         foreach (i; 0 .. columnSize)
185         {
186             T sum = 0;
187             foreach (j; 0 .. rowSize)
188             {
189                 sum += this[j, i] * vec[j];
190             }
191             result[i] = sum;
192         }
193         return result;
194     }
195     unittest
196     {
197         auto m1 = Mat23.fromRows(1, 2,
198                                  3, 4,
199                                  5, 6);
200         auto v1 = Vec2(1, 2);
201         assert(m1 * v1 == Vec3(1*1 + 2*2,
202                                1*3 + 2*4,
203                                1*5 + 2*6));
204     }
205 
206     /// Returns the result of Matrix multiplication.
207     Matrix!(T, OtherColumns, columnSize) opBinary(string op : "*", uint OtherColumns)(
208         const auto ref Matrix!(T, OtherColumns, rowSize) other
209     ) const
210     {
211         typeof(return) result = void;
212         foreach (i; 0 .. columnSize)
213         {
214             foreach (j; 0 .. OtherColumns)
215             {
216                 T sum = 0;
217                 foreach (k; 0 .. rowSize)
218                 {
219                     sum += this[k, i] * other[j, k];
220                 }
221                 result[j, i] = sum;
222             }
223         }
224         return result;
225     }
226     unittest
227     {
228         alias Mat23 = Matrix!(int, 2, 3);
229         alias Mat12 = Matrix!(int, 1, 2);
230 
231         Mat23 m1 = Mat23.fromRows(1, 1,
232                                   2, 2,
233                                   3, 3);
234         Mat12 m2 = Mat12.fromRows(4,
235                                   5);
236         auto result = m1 * m2;
237         assert(result.elements == [
238             1*4 + 1*5,
239             2*4 + 2*5,
240             3*4 + 3*5,
241         ]);
242     }
243 
244     static if (isSquare)
245     {
246         /// Constant Identity matrix (diagonal values 1).
247         enum identity = fromDiagonal(1);
248 
249         /// Inplace matrix multiplication with "*=" operator, only available for square matrices.
250         ref Matrix opOpAssign(string op : "*")(const auto ref Matrix other) return
251         {
252             foreach (i; 0 .. columnSize)
253             {
254                 foreach (j; 0 .. rowSize)
255                 {
256                     T sum = 0;
257                     foreach (k; 0 .. rowSize)
258                     {
259                         sum += this[k, i] * other[j, k];
260                     }
261                     this[j, i] = sum;
262                 }
263             }
264             return this;
265         }
266 
267         // TODO: determinant, inverse matrix, at least for 2x2, 3x3 and 4x4
268     }
269 
270 
271     // Matrix 4x4 methods
272     static if (rowSize == 4 && columnSize == 4)
273     {
274         /// Returns an orthographic projection matrix.
275         /// See_Also: https://www.khronos.org/registry/OpenGL-Refpages/gl2.1/xhtml/glOrtho.xml
276         static Matrix orthographic(T left, T right, T bottom, T top, T near = -1, T far = 1)
277         {
278             Matrix result;
279 
280             result[0, 0] = 2.0 / (right - left);
281             result[1, 1] = 2.0 / (top - bottom);
282             result[2, 2] = 2.0 / (near - far);
283             result[3, 3] = 1.0;
284 
285             result[3, 0] = (left + right) / (left - right);
286             result[3, 1] = (bottom + top) / (bottom - top);
287             result[3, 2] = (far + near) / (near - far);
288 
289             return result;
290         }
291         alias ortho = orthographic;
292 
293         /// Calls `perspective` converting angle from degrees to radians.
294         /// See_Also: perspective
295         static auto perspectiveDegrees(T fovDegrees, T aspectRatio, T near, T far)
296         {
297             import bettercmath.misc : degreesToRadians;
298             return perspective(degreesToRadians(fovDegrees), aspectRatio, near, far);
299         }
300         /// Returns a perspective projection matrix.
301         /// See_Also: https://www.khronos.org/registry/OpenGL-Refpages/gl2.1/xhtml/gluPerspective.xml
302         static Matrix perspective(T fov, T aspectRatio, T near, T far)
303         in { assert(near > 0, "Near clipping pane should be positive"); assert(far > 0, "Far clipping pane should be positive"); }
304         do
305         {
306             Matrix result;
307 
308             import bettercmath.cmath : tan;
309             T cotangent = 1.0 / tan(fov * 0.5);
310 
311             result[0, 0] = cotangent / aspectRatio;
312             result[1, 1] = cotangent;
313             result[2, 3] = -1.0;
314             result[2, 2] = (near + far) / (near - far);
315             result[3, 2] = (2.0 * near * far) / (near - far);
316 
317             return result;
318         }
319     }
320 }
321 
/// Evaluates to `true` when `T` is an instantiation of `Matrix` (any element
/// type and dimensions) or implicitly converts to one, `false` otherwise.
template isMatrix(T)
{
    enum isMatrix = is(T : Matrix!Args, Args...);
}
324 
/// Transposes the square matrix `mat` in place by swapping every element
/// above the main diagonal with its mirror below it.
/// Returns: `mat`, by reference.
ref Matrix!(T, C, C) transpose(T, uint C)(ref return Matrix!(T, C, C) mat)
{
    import std.algorithm : swap;
    // Visit each off-diagonal pair exactly once: row index strictly less than column index
    foreach (column; 1 .. C)
    {
        foreach (row; 0 .. column)
        {
            swap(mat[column, row], mat[row, column]);
        }
    }
    return mat;
}
unittest
{
    auto square = Mat2.fromRows(1, 2,
                                3, 4);
    transpose(square);
    const expected = Mat2.fromRows(1, 3,
                                   2, 4);
    assert(square == expected);
}
346 
/// Builds a new matrix holding the transpose of `mat`: the element at
/// column `c`, row `r` of `mat` ends up at column `r`, row `c` of the result.
/// `mat` itself is left untouched.
Matrix!(T, R, C) transposed(T, uint C, uint R)(const auto ref Matrix!(T, C, R) mat)
{
    // Every element is assigned below, so skipping default initialization is safe
    typeof(return) result = void;
    foreach (c; 0 .. C)
    {
        foreach (r; 0 .. R)
        {
            result[r, c] = mat[c, r];
        }
    }
    return result;
}
unittest
{
    float[6] columnMajor = [1, 2, 3, 4, 5, 6];
    float[6] expected = [1, 4, 2, 5, 3, 6];
    auto original = Mat23.fromColumns(columnMajor);
    auto flipped = transposed(original);
    assert(flipped.elements == expected);
    // Transposing twice gives back the original matrix
    assert(transposed(original.transposed) == original);
}
369