1 /**
2  * Type and dimension generic Matrix type.
3  */
4 module bettercmath.matrix;
5 
6 @safe @nogc pure nothrow:
7 
version (unittest)
{
    import bettercmath.vector;

    // Shorthand float vector and matrix types shared by this module's unittests.
    private
    {
        alias Vec2 = Vector!(float, 2);
        alias Vec3 = Vector!(float, 3);

        alias Mat2 = Matrix!(float, 2);
        alias Mat3 = Matrix!(float, 3);
        alias Mat4 = Matrix!(float, 4);

        alias Mat23 = Matrix!(float, 2, 3);
        alias Mat32 = Matrix!(float, 3, 2);
        alias Mat34 = Matrix!(float, 3, 4);
        alias Mat43 = Matrix!(float, 4, 3);
    }
}
21 
/**
 * Column-major 2D matrix type.
 *
 * Elements are stored column after column, so the element at column `i`
 * and row `j` lives at `elements[i * columnSize + j]`.
 *
 * Params:
 *   T = Element type.
 *   numColumns = Number of columns.
 *   numRows = Number of rows, defaults to `numColumns` (square matrix).
 */
struct Matrix(T, uint numColumns, uint numRows = numColumns)
if (numColumns > 0 && numRows > 0)
{
    import std.algorithm : min;
    /// Alias for Matrix element type.
    alias ElementType = T;
    /// Number of elements in each row, same as the number of columns.
    enum rowSize = numColumns;
    /// Number of elements in each column, same as the number of rows.
    enum columnSize = numRows;
    /// Minimum dimension between number of rows and number of columns.
    enum minDimension = min(rowSize, columnSize);
    /// Total number of elements.
    enum numElements = rowSize * columnSize;
    /// Whether matrix is square or not.
    enum isSquare = rowSize == columnSize;

    /// Matrix underlying elements, stored in column-major order.
    T[numElements] elements = 0;

    /// Constructs a Matrix specifying all elements in column-major order.
    this()(const auto ref T[numElements] values)
    {
        this.elements = values;
    }
    /// Constructs a Matrix with all diagonal elements equal to `diag`.
    /// Elements outside the diagonal stay 0.
    this(const T diag)
    {
        foreach (i; 0 .. minDimension)
        {
            this[i, i] = diag;
        }
    }

    /// Copy Matrix values into `target` Matrix of any dimensions.
    /// If dimensions are not the same, the values at non-overlapping indices
    /// are ignored and `target` keeps its own values there.
    /// Returns: `target`, by reference.
    auto ref copyInto(uint C, uint R)(ref return Matrix!(T, C, R) target) const
    {
        // Same column size means identical column-major layout, so the
        // overlapping prefix of the underlying array may be copied at once
        static if (this.columnSize == target.columnSize)
        {
            enum copySize = min(this.numElements, target.numElements);
            target.elements[0 .. copySize] = this.elements[0 .. copySize];
        }
        else
        {
            // Otherwise copy the overlapping part of every shared column
            enum columnCopySize = min(this.columnSize, target.columnSize);
            enum rowCopySize = min(this.rowSize, target.rowSize);
            foreach (i; 0 .. rowCopySize)
            {
                target[i][0 .. columnCopySize] = this[i][0 .. columnCopySize];
            }
        }
        return target;
    }

    /// Returns a copy of Matrix, adjusting dimensions as necessary.
    /// Non-overlapping indices will stay initialized to 0.
    U opCast(U : Matrix!(T, C, R), uint C, uint R)() const
    {
        typeof(return) result;
        return copyInto(result);
    }

    /// Returns a Range of all columns.
    /// Each column is a slice of `columnSize` elements into `elements`.
    auto columns()
    {
        import std.range : chunks;
        return elements[].chunks(columnSize);
    }
    /// ditto
    auto columns() const
    {
        import std.range : chunks;
        return elements[].chunks(columnSize);
    }

    /// Random access Range over the rows of a column-major element slice.
    /// Row `j` is the strided view of `data[j .. $]` stepping by `columnSize`,
    /// so iterating rows neither allocates nor copies elements.
    static struct RowRange(E)
    {
        private E[] data;
        private size_t first = 0;
        private size_t last = columnSize;

        // Strided view over row `j`: elements j, j + columnSize, j + 2*columnSize, ...
        private auto row(size_t j)
        {
            import std.range : stride;
            return data[j .. $].stride(columnSize);
        }

        /// Whether every row was consumed.
        @property bool empty() const { return first >= last; }
        /// Number of rows not yet consumed.
        @property size_t length() const { return last - first; }
        /// Copy of the range that can be iterated independently.
        @property RowRange save() { return this; }

        /// First remaining row.
        @property auto front()
        in { assert(!empty, "Range is empty"); }
        do { return row(first); }
        /// Last remaining row.
        @property auto back()
        in { assert(!empty, "Range is empty"); }
        do { return row(last - 1); }
        /// Drops the first remaining row.
        void popFront()
        in { assert(!empty, "Range is empty"); }
        do { ++first; }
        /// Drops the last remaining row.
        void popBack()
        in { assert(!empty, "Range is empty"); }
        do { --last; }
        /// Returns the `i`-th remaining row.
        auto opIndex(size_t i)
        in { assert(i < length, "Index out of bounds"); }
        do { return row(first + i); }
    }

    /// Returns a Range of all rows.
    /// Each row is a lazy strided view of `rowSize` elements into `elements`.
    auto rows()
    {
        return RowRange!T(elements[]);
    }
    /// ditto
    auto rows() const
    {
        return RowRange!(const T)(elements[]);
    }
    ///
    unittest
    {
        alias M = Matrix!(int, 2, 3);
        auto m = M.fromColumns(1, 3, 5,
                               2, 4, 6);
        auto r = m.rows;
        assert(r.length == 3);
        assert(r[0][0] == 1 && r[0][1] == 2);
        assert(r[2][0] == 5 && r[2][1] == 6);
        assert(r.front.length == 2);
        r.popFront();
        assert(r.front[0] == 3 && r.front[1] == 4);
        assert(m.columns.front[2] == 5);

        const cm = m;
        assert(cm.rows()[1][0] == 3);
    }

    /// Index a column.
    inout(T)[] opIndex(size_t i) inout
    in { assert(i < rowSize, "Index out of bounds"); }
    do
    {
        auto initialIndex = i * columnSize;
        return elements[initialIndex .. initialIndex + columnSize];
    }
    /// Index an element directly.
    /// Params:
    ///   i = column index
    ///   j = row index
    ref inout(T) opIndex(size_t i, size_t j) inout
    in { assert(i < rowSize && j < columnSize, "Index out of bounds"); }
    do
    {
        return elements[i*columnSize + j];
    }

    /// Row size (number of columns), used for `$` in the column index.
    enum opDollar(size_t pos : 0) = rowSize;
    /// Column size (number of rows), used for `$` in the row index.
    enum opDollar(size_t pos : 1) = columnSize;

    /// Constructs a Matrix from all elements in column-major format.
    static Matrix fromColumns(Args...)(const auto ref Args args)
    if (Args.length == numElements)
    {
        return Matrix([args]);
    }
    /// Constructs a Matrix from an array of all elements in column-major format.
    static Matrix fromColumns()(const auto ref T[numElements] elements)
    {
        return Matrix(elements);
    }
    /// Constructs a Matrix from a 2D array of columns:
    /// `rowSize` columns, each holding `columnSize` elements.
    static Matrix fromColumns()(const auto ref T[columnSize][rowSize] columns)
    {
        // Columns laid out back to back are exactly the column-major element order
        return Matrix(cast(T[numElements]) columns);
    }

    /// Constructs a Matrix from row-major format: either all elements listed
    /// row by row, or a 2D array of `columnSize` rows with `rowSize` elements each.
    static Matrix fromRows(Args...)(const auto ref Args args)
    {
        // Rows of this matrix are the columns of its transpose
        return Matrix!(T, columnSize, rowSize).fromColumns(args).transposed;
    }
    ///
    unittest
    {
        alias M = Matrix!(int, 2, 3);
        // 2 columns with 3 elements each
        int[3][2] cols = [[1, 3, 5], [2, 4, 6]];
        auto m = M.fromColumns(cols);
        assert(m[0, 2] == 5 && m[1, 0] == 2);
        // 3 rows with 2 elements each
        int[2][3] rowArrays = [[1, 2], [3, 4], [5, 6]];
        assert(M.fromRows(rowArrays) == m);
    }

    /// Constructs a Matrix with all diagonal values equal to `diag` and all others equal to 0.
    static Matrix fromDiagonal(const T diag)
    {
        return Matrix(diag);
    }
    /// Constructs a Matrix with diagonal values from `diag` and all others equal to 0.
    static Matrix fromDiagonal(uint N)(const auto ref T[N] diag)
    if (N <= minDimension)
    {
        Matrix mat;
        foreach (i; 0 .. N)
        {
            mat[i, i] = diag[i];
        }
        return mat;
    }

    /// Returns the result of multiplying `vec` by Matrix.
    /// If matrix is not square, the resulting array dimension will be different from input.
    T[columnSize] opBinary(string op : "*")(const auto ref T[rowSize] vec) const
    {
        typeof(return) result;
        foreach (i; 0 .. columnSize)
        {
            T sum = 0;
            foreach (j; 0 .. rowSize)
            {
                sum += this[j, i] * vec[j];
            }
            result[i] = sum;
        }
        return result;
    }
    ///
    unittest
    {
        auto m1 = Mat23.fromRows(1, 2,
                                 3, 4,
                                 5, 6);
        auto v1 = Vec2(1, 2);
        assert(m1 * v1 == Vec3(1*1 + 2*2,
                               1*3 + 2*4,
                               1*5 + 2*6));
    }

    /// Returns the result of Matrix multiplication.
    /// The result has the columns of `other` and the rows of this Matrix.
    Matrix!(T, OtherColumns, columnSize) opBinary(string op : "*", uint OtherColumns)(
        const auto ref Matrix!(T, OtherColumns, rowSize) other
    ) const
    {
        typeof(return) result = void;
        foreach (i; 0 .. columnSize)
        {
            foreach (j; 0 .. OtherColumns)
            {
                // Dot product of row `i` of this and column `j` of `other`
                T sum = 0;
                foreach (k; 0 .. rowSize)
                {
                    sum += this[k, i] * other[j, k];
                }
                result[j, i] = sum;
            }
        }
        return result;
    }
    ///
    unittest
    {
        alias Mat23 = Matrix!(int, 2, 3);
        alias Mat12 = Matrix!(int, 1, 2);

        Mat23 m1 = Mat23.fromRows(1, 1,
                                  2, 2,
                                  3, 3);
        Mat12 m2 = Mat12.fromRows(4,
                                  5);
        auto result = m1 * m2;
        assert(result.elements == [
            1*4 + 1*5,
            2*4 + 2*5,
            3*4 + 3*5,
        ]);
    }

    static if (isSquare)
    {
        /// Constant Identity matrix (diagonal values 1).
        enum identity = fromDiagonal(1);

        /// Inplace matrix multiplication with "*=" operator, only available for square matrices.
        /// The product is fully computed before being stored, so elements are never
        /// read after being overwritten and `m *= m` works as expected.
        ref Matrix opOpAssign(string op : "*")(const auto ref Matrix other) return
        {
            this = this * other;
            return this;
        }
        ///
        unittest
        {
            alias M = Matrix!(int, 2);
            auto m = M.fromRows(1, 2,
                                3, 4);
            m *= M.fromRows(0, 1,
                            1, 0);
            assert(m == M.fromRows(2, 1,
                                   4, 3));
            m *= m;
            assert(m == M.fromRows(8, 5,
                                   20, 13));
        }

        // TODO: determinant, inverse matrix, at least for 2x2, 3x3 and 4x4
    }


    // Matrix 4x4 projection methods. They use floating point division and
    // trigonometry, so they are only provided for floating point element
    // types; this keeps e.g. `Matrix!(int, 4)` instantiable.
    static if (rowSize == 4 && columnSize == 4 && __traits(isFloating, T))
    {
        /// Returns an orthographic projection matrix.
        /// See_Also: https://www.khronos.org/registry/OpenGL-Refpages/gl2.1/xhtml/glOrtho.xml
        static Matrix orthographic(T left, T right, T bottom, T top, T near = -1, T far = 1)
        {
            Matrix result;

            result[0, 0] = 2.0 / (right - left);
            result[1, 1] = 2.0 / (top - bottom);
            result[2, 2] = 2.0 / (near - far);
            result[3, 3] = 1.0;

            // Translation lives in the last column
            result[3, 0] = (left + right) / (left - right);
            result[3, 1] = (bottom + top) / (bottom - top);
            result[3, 2] = (far + near) / (near - far);

            return result;
        }
        alias ortho = orthographic;

        /// Calls `perspective` converting angle from degrees to radians.
        /// See_Also: perspective
        static auto perspectiveDegrees(T fovDegrees, T aspectRatio, T near, T far)
        {
            import bettercmath.misc : degreesToRadians;
            return perspective(degreesToRadians(fovDegrees), aspectRatio, near, far);
        }
        /// Returns a perspective projection matrix.
        /// See_Also: https://www.khronos.org/registry/OpenGL-Refpages/gl2.1/xhtml/gluPerspective.xml
        static Matrix perspective(T fov, T aspectRatio, T near, T far)
        in { assert(near > 0, "Near clipping pane should be positive"); assert(far > 0, "Far clipping pane should be positive"); }
        do
        {
            Matrix result;

            import bettercmath.cmath : tan;
            T cotangent = 1.0 / tan(fov * 0.5);

            result[0, 0] = cotangent / aspectRatio;
            result[1, 1] = cotangent;
            result[2, 3] = -1.0;
            result[2, 2] = (near + far) / (near - far);
            result[3, 2] = (2.0 * near * far) / (near - far);

            return result;
        }
    }
    ///
    unittest
    {
        // Integer 4x4 matrices can be instantiated; they just lack projection helpers
        auto m = Matrix!(int, 4).identity;
        assert(m[0, 0] == 1 && m[3, 3] == 1 && m[3, 0] == 0);
        static assert(!__traits(hasMember, Matrix!(int, 4), "orthographic"));
    }
}
323 
/// Evaluates to `true` when `T` is, or implicitly converts to, an
/// instantiation of `Matrix` with any element type and dimensions.
enum isMatrix(T) = is(T : Matrix!(E, C, R), E, uint C, uint R);
326 
/// Transposes a square matrix in place by swapping each element above the
/// main diagonal with its mirror element below it.
/// Returns: `mat` itself, so calls can be chained.
ref Matrix!(T, C, C) transpose(T, uint C)(ref return Matrix!(T, C, C) mat)
{
    import std.algorithm : swap;
    // Visit every (column, row) pair with row < column exactly once;
    // diagonal elements are left untouched.
    foreach (column; 1 .. C)
    {
        foreach (row; 0 .. column)
        {
            swap(mat[column, row], mat[row, column]);
        }
    }
    return mat;
}
///
unittest
{
    auto square = Mat2.fromRows(1, 2,
                                3, 4);
    transpose(square);
    assert(square == Mat2.fromRows(1, 3,
                                   2, 4));
}
349 
/// Returns a transposed copy of `mat`: column `i` of the result is row `i`
/// of `mat`. Works for any dimensions, turning `C` columns by `R` rows into
/// `R` columns by `C` rows; `mat` itself is not modified.
Matrix!(T, R, C) transposed(T, uint C, uint R)(const auto ref Matrix!(T, C, R) mat)
{
    typeof(return) flipped = void;
    // `flipped` has `C` rows, so its flat column-major index `k` addresses
    // column `k / C` and row `k % C`, which is element [k % C, k / C] of `mat`.
    foreach (k, ref element; flipped.elements)
    {
        element = mat[k % C, k / C];
    }
    return flipped;
}
///
unittest
{
    float[6] columnMajor = [1, 2, 3, 4, 5, 6];
    float[6] expected = [1, 4, 2, 5, 3, 6];
    auto original = Mat23.fromColumns(columnMajor);
    auto flipped = transposed(original);
    assert(flipped.elements == expected);
    assert(transposed(original.transposed) == original);
}
373