1 module armos.math.matrix;
2 import armos.math;
3 
/++
    A struct representing a fixed-size RowSize x ColSize matrix whose elements have arithmetic type T.

    The matrix is stored row by row: `elements[r]` is row `r`, held as a
    `Vector!(T, ColSize)`, and `m[r][c]` addresses the element in row `r`, column `c`.
+/
struct Matrix(T, int RowSize, int ColSize)if(__traits(isArithmetic, T) && RowSize > 0 && ColSize > 0){
    alias Matrix!(T, RowSize, ColSize) MatrixType;
    alias armos.math.Vector!(T, ColSize) VectorType;

    /// Type of a single element.
    alias elementType = T;
    unittest{
        static assert(is(Matrix!(float, 3, 3).elementType == float));
    }

    /++
        Number of rows.
    +/
    enum int rowSize = RowSize;

    /++
        Number of columns.
    +/
    enum int colSize = ColSize;

    static if(RowSize == ColSize){
        /++
            Dimension of a square matrix (only defined when rowSize == colSize).
        +/
        enum int size = RowSize;
    }
    unittest{
        static assert(Matrix!(float, 3, 3).size == 3);
    }

    /// Rows of the matrix; every row is initialised with `VectorType()`.
    VectorType[RowSize] elements = VectorType();

    /++
        Builds a matrix from one array per row.

        With no arguments the rows keep their default value. Otherwise exactly
        RowSize row arrays must be given; each is passed to the `VectorType`
        constructor. Any other count halts via `assert(0)`, which stays active
        in release builds.
    +/
    this(T[][] arr ...){
        if(arr.length == 0){
            return;
        }
        if(arr.length != RowSize){
            assert(0, "Matrix: the number of row arrays must equal RowSize");
        }
        foreach(index, ref row; elements){
            row = VectorType(arr[index]);
        }
    }

    /++
        Returns a copy of row `index`.
    +/
    pure VectorType opIndex(in size_t index)const{
        return elements[index];
    }
    unittest{
        auto matrix = Matrix2d.zero;
        assert(matrix[0][0] == 0);
    }
    unittest{
        auto matrix = Matrix2d(
                [1.0, 0.0],
                [0.0, 1.0]
                );
        assert(matrix[0][0] == 1.00);
    }

    /++
        Returns a mutable reference to row `index`, so `m[r][c] = x` writes in place.
    +/
    ref VectorType opIndex(in size_t index){
        return elements[index];
    }
    unittest{
        auto matrix = Matrix2d();
        matrix[1][0] = 1.0;
        assert(matrix[1][0] == 1.0);
    }

    // Equality (==, !=) uses D's built-in member-wise comparison of `elements`;
    // no opEquals is defined here.
    unittest{
        auto matrix1 = Matrix2d(
                [2.0, 1.0],
                [1.0, 2.0]
                );

        auto matrix2 = Matrix2d(
                [2.0, 1.0],
                [1.0, 2.0]
                );
        assert(matrix1 == matrix2);
    }
    unittest{
        auto matrix1 = Matrix2d();
        matrix1[1][0] = 1.0;
        auto matrix2 = Matrix2d();
        matrix2[1][0] = 1.0;
        auto matrix3 = Matrix2d();
        matrix3[1][0] = 2.0;
        assert(matrix1 == matrix2);
        assert(matrix1 != matrix3);
    }
    unittest{
        auto matrix1 = Matrix!(double, 1, 2)(
                [2.0, 1.0]
                );

        auto matrix2 = Matrix!(double, 2, 1)(
                [2.0],
                [1.0]
                );
        // A 1x2 and a 2x1 matrix are distinct types.
        static assert(!is(typeof(matrix1) == typeof(matrix2)));
    }

    /++
        Returns a matrix whose elements are all T(0).
    +/
    static MatrixType zero(){
        auto zeroMatrix = MatrixType();
        foreach(ref row; zeroMatrix.elements){
            foreach(ref n; row.elements){
                n = T(0);
            }
        }
        return zeroMatrix;
    }
    unittest{
        assert(
                Matrix!(float, 3, 3).zero == Matrix!(float, 3, 3)(
                    [0, 0, 0],
                    [0, 0, 0],
                    [0, 0, 0]
                    )
              );
    }

    static if(RowSize == ColSize){
        /++
            Returns the identity matrix (ones on the diagonal, zeros elsewhere).
        +/
        static MatrixType identity(){
            auto identityMatrix = MatrixType.zero;
            foreach(i; 0 .. RowSize){
                identityMatrix[i][i] = T(1);
            }
            return identityMatrix;
        }
    }
    unittest{
        assert(
                Matrix!(float, 3, 3).identity == Matrix!(float, 3, 3)(
                    [1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1]
                    )
              );
    }

    /++
        Unary minus: negates every element.
    +/
    MatrixType opUnary(string op)()const if(op == "-"){
        auto result = MatrixType();
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. ColSize){
                result[r][c] = cast(T)(-elements[r][c]);
            }
        }
        return result;
    }
    unittest{
        auto matrix = Matrix2d();
        matrix[0][0] = 1.0;
        assert((-matrix)[0][0] == -1.0);
    }

    /++
        Element-wise sum or difference of two matrices of the same shape.
    +/
    MatrixType opBinary(string op)(in MatrixType rhs)const if(op == "+" || op == "-"){
        auto result = MatrixType();
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. ColSize){
                result[r][c] = cast(T)(mixin("elements[r][c] " ~ op ~ " rhs.elements[r][c]"));
            }
        }
        return result;
    }
    unittest{
        auto matrix1 = Matrix2d.zero;
        matrix1[0][0] = 1.0;
        auto matrix2 = Matrix2d.zero;
        matrix2[0][0] = 2.0;
        matrix2[0][1] = 1.0;
        auto matrix3 = matrix1 + matrix2;
        assert(matrix3[0][0] == 3.0);
        assert(matrix3[0][1] == 1.0);
    }
    unittest{
        auto matrix1 = Matrix2d.zero;
        matrix1[0][0] = 1.0;
        auto matrix2 = Matrix2d.zero;
        matrix2[0][0] = 2.0;
        matrix2[0][1] = 1.0;
        auto matrix3 = matrix1 - matrix2;
        assert(matrix3[0][0] == -1.0);
        assert(matrix3[0][1] == -1.0);
    }

    /++
        Applies `op` between every element and the scalar `v` (matrix op scalar).
    +/
    MatrixType opBinary(string op)(in T v)const if(op == "+" || op == "-" || op == "*" || op == "/"){
        auto result = MatrixType();
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. ColSize){
                result[r][c] = cast(T)(mixin("elements[r][c] " ~ op ~ " v"));
            }
        }
        return result;
    }

    /++
        Scalar on the left of a commutative operator (`s + m`, `s * m`);
        applies `op` between `v` and every element.
    +/
    MatrixType opBinaryRight(string op)(in T v)const if(op == "+" || op == "*"){
        auto result = MatrixType();
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. ColSize){
                result[r][c] = cast(T)(mixin("v " ~ op ~ " elements[r][c]"));
            }
        }
        return result;
    }
    unittest{
        auto matrix1 = Matrix2d.zero;
        auto matrix2 = matrix1 + 5.0;
        auto matrix3 = 3.0 + matrix1;
        assert(matrix2[1][0] == 5.0);
        assert(matrix3[1][1] == 3.0);
    }
    unittest{
        auto matrix1 = Matrix2d.zero;
        auto matrix2 = matrix1 - 3.0;
        assert(matrix2[1][0] == -3.0);
    }
    unittest{
        auto matrix1 = Matrix2d.identity;
        auto matrix2 = matrix1 * 2.0;
        assert(matrix2[0][0] == 2.0);
        auto matrix3 = 2.0 * matrix2;
        assert(matrix3[1][1] == 4.0);
    }
    unittest{
        auto matrix1 = Matrix2d(
                [2.0, 4.0],
                [3.0, 1.0]
                );
        auto matrix2 = matrix1 / 2.0;

        auto matrixA = Matrix2d(
                [1.0, 2.0],
                [1.5, 0.5]
                );
        assert(matrix2 == matrixA);
    }

    /++
        Matrix product: (RowSize x ColSize) * (ColSize x K) gives (RowSize x K).
    +/
    Matrix!(T, RowSize, K) opBinary(string op, int K)(in Matrix!(T, ColSize, K) rhs)const if(op == "*"){
        auto result = Matrix!(T, RowSize, K)();
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. K){
                T sum = T(0);
                foreach(d; 0 .. ColSize){
                    sum += elements[r][d] * rhs.elements[d][c];
                }
                result[r][c] = sum;
            }
        }
        return result;
    }
    unittest{
        auto matrix1 = Matrix2f(
                [2.0, 0.0],
                [1.0, 1.0]
                );

        auto matrix2 = Matrix2f(
                [1.0, 1.0],
                [0.0, 1.0]
                );

        auto matrix3 = matrix1 * matrix2;

        auto matrix_answer = Matrix2f(
                [2.0, 2.0],
                [1.0, 2.0]
                );

        assert(matrix3 == matrix_answer);
    }
    unittest{
        // Non-square product: (1 x 2) * (2 x 1) -> (1 x 1).
        auto a = Matrix!(double, 1, 2)([1.0, 2.0]);
        auto b = Matrix!(double, 2, 1)([3.0], [4.0]);
        auto ab = a * b;
        static assert(is(typeof(ab) == Matrix!(double, 1, 1)));
        assert(ab[0][0] == 11.0);
    }

    /++
        Matrix-vector product. Entry r of the result is the dot product of
        row r with `v`, so the result has one entry per row.
    +/
    armos.math.Vector!(T, RowSize) opBinary(string op)(in VectorType v)const if(op == "*"){
        auto result = armos.math.Vector!(T, RowSize)();
        foreach(r; 0 .. RowSize){
            T sum = T(0);
            foreach(c; 0 .. ColSize){
                sum += elements[r][c] * v[c];
            }
            result[r] = sum;
        }
        return result;
    }
    unittest{
        auto matrix1 = Matrix2f(
                [2.0, 0.0],
                [1.0, 1.0]
                );
        auto vector1 = armos.math.Vector2f(1.0, 0.0);
        auto vector_answer = armos.math.Vector2f(2.0, 1.0);
        auto vector2 = matrix1 * vector1;
        assert(vector2 == vector_answer);
    }

    // D1-style operator names. The operators are implemented by
    // opUnary/opBinary/opBinaryRight above. These named methods are kept so
    // code that calls them directly still compiles. They invoke the templates
    // explicitly rather than through operator syntax, so they can never
    // dispatch back to themselves.
    MatrixType opNeg()const{ return opUnary!"-"(); }
    MatrixType opAdd(in MatrixType r)const{ return opBinary!"+"(r); }
    MatrixType opAdd(in T v)const{ return opBinary!"+"(v); }
    MatrixType opSub(in MatrixType r)const{ return opBinary!"-"(r); }
    MatrixType opSub(in T v)const{ return opBinary!"-"(v); }
    MatrixType opMul(in T v)const{ return opBinary!"*"(v); }
    MatrixType opDiv(in T v)const{ return opBinary!"/"(v); }
    static if(RowSize == ColSize){
        MatrixType opMul(in MatrixType r)const{ return opBinary!"*"(r); }
        VectorType opMul(in VectorType v)const{ return opBinary!"*"(v); }
    }

    /++
        Copies the elements into a plain static array indexed as `a[r][c]`.
    +/
    private T[ColSize][RowSize] toRowArray()const{
        T[ColSize][RowSize] a;
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. ColSize){
                a[r][c] = elements[r][c];
            }
        }
        return a;
    }

    /++
        Returns the (n-1) x (n-1) array left after deleting row `skipRow`
        and column `skipCol` of `m`.
    +/
    private static T[n-1][n-1] minorOf(size_t n)(in T[n][n] m, in size_t skipRow, in size_t skipCol) if(n > 1){
        T[n-1][n-1] result;
        size_t dstRow = 0;
        foreach(r; 0 .. n){
            if(r == skipRow){
                continue;
            }
            size_t dstCol = 0;
            foreach(c; 0 .. n){
                if(c == skipCol){
                    continue;
                }
                result[dstRow][dstCol] = m[r][c];
                ++dstCol;
            }
            ++dstRow;
        }
        return result;
    }

    /++
        Determinant of the n x n array `m` by cofactor (Laplace) expansion
        along the first row. The cost grows factorially with n, so this is
        intended for small matrices.
    +/
    private static T determinantOf(size_t n)(in T[n][n] m){
        static if(n == 1){
            return m[0][0];
        }else static if(n == 2){
            // cast keeps narrow integral T (byte, short) compiling after promotion to int
            return cast(T)(m[0][0] * m[1][1] - m[0][1] * m[1][0]);
        }else{
            T sum = T(0);
            foreach(c; 0 .. n){
                T term = m[0][c];
                term *= determinantOf!(n - 1)(minorOf!n(m, 0, c));
                if(c % 2 == 0){
                    sum += term;
                }else{
                    sum -= term;
                }
            }
            return sum;
        }
    }

    static if(RowSize == ColSize && (is(T == float) || is(T == double) || is(T == real))){
        /++
            Returns the inverse matrix, computed as adjugate / determinant.

            No singularity check is made: if the determinant is zero the
            result contains infinities or NaNs.
        +/
        MatrixType inverse()const{
            const a = toRowArray();
            immutable T det = determinantOf!RowSize(a);
            auto result = MatrixType();
            static if(RowSize == 1){
                result[0][0] = T(1) / det;
            }else{
                foreach(r; 0 .. RowSize){
                    foreach(c; 0 .. ColSize){
                        // The adjugate is the transposed cofactor matrix, so entry
                        // (r, c) uses the minor obtained by deleting row c and column r.
                        T cofactor = determinantOf!(RowSize - 1)(minorOf!RowSize(a, c, r));
                        if((r + c) % 2 != 0){
                            cofactor = -cofactor;
                        }
                        result[r][c] = cofactor / det;
                    }
                }
            }
            return result;
        }
    }
    unittest{
        auto m= Matrix3f(
                [1, 2, 0], 
                [3, 2, 2], 
                [1, 4, 3]
                );

        auto mInv= m.inverse;

        auto mA= Matrix3f(
                [1, 0, 0], 
                [0, 1, 0], 
                [0, 0, 1]
                );
        assert(mInv*m == mA);
    }
    unittest{
        auto m = Matrix2d([1.0, 2.0], [3.0, 4.0]);
        assert(m.inverse == Matrix2d([-2.0, 1.0], [1.5, -0.5]));
    }

    /++
        Overwrites column `column`: row r receives vec[r].
        NOTE(review): `vec` has ColSize entries but RowSize are read, so this is
        only safe when RowSize <= ColSize.
    +/
    void setColumnVector(in int column, in VectorType vec){
        foreach(r; 0 .. RowSize){
            elements[r][column] = vec[r];
        }
    }
    unittest{
        auto matrix = Matrix2f();
        auto vec0 = armos.math.Vector2f(1, 2);
        auto vec1 = armos.math.Vector2f(3, 4);
        matrix.setColumnVector(0, vec0);
        matrix.setColumnVector(1, vec1);
        assert(matrix == Matrix2f(
                    [1, 3], 
                    [2, 4]
                    ));
    }

    /++
        Overwrites row `row` with `vec`.
    +/
    void setRowVector(in int row, in VectorType vec){
        this[row] = vec;
    }
    unittest{
        auto matrix = Matrix2f();
        auto vec0 = armos.math.Vector2f(1, 2);
        auto vec1 = armos.math.Vector2f(3, 4);
        matrix.setRowVector(0, vec0);
        matrix.setRowVector(1, vec1);
        assert(matrix == Matrix2f(
                    [1, 2], 
                    [3, 4]
                    ));
    }

    /++
        Copies `mat` into this matrix with its top-left element placed at
        (offsetR, offsetC), and returns a copy of the updated matrix.

        `mat` must fit inside this matrix at that position. Its sizes are
        checked at compile time and the offsets by runtime asserts.
    +/
    MatrixType setMatrix(M)(M mat, in int offsetR = 0, in int offsetC = 0){
        static assert(M.rowSize <= rowSize, "setMatrix: source matrix has too many rows");
        static assert(M.colSize <= colSize, "setMatrix: source matrix has too many columns");
        assert(offsetR >= 0 && offsetR + M.rowSize <= rowSize, "setMatrix: row offset out of range");
        assert(offsetC >= 0 && offsetC + M.colSize <= colSize, "setMatrix: column offset out of range");
        foreach(r; 0 .. M.rowSize){
            foreach(c; 0 .. M.colSize){
                this[r + offsetR][c + offsetC] = mat[r][c];
            }
        }
        return this;
    }
    unittest{
        auto mat44 = Matrix!(double, 4, 4)(
                [1, 0, 0, 4],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 2]
                );
        auto mat33 = Matrix!(float, 3, 3)(
                [2, 1, 0],
                [0, 1, 3],
                [0, 0, 3]
                );

        auto mat44A = Matrix!(double, 4, 4)(
                [1, 2, 1, 0],
                [0, 0, 1, 3],
                [0, 0, 0, 3],
                [0, 0, 0, 2]
                );
        assert( mat44.setMatrix(mat33, 0, 1) == mat44A );
    }

    static if(RowSize == ColSize){
        /++
            Returns the determinant of this square matrix, computed by
            cofactor (Laplace) expansion.
        +/
        T determinant()const{
            return determinantOf!RowSize(toRowArray());
        }
    }
    unittest{
        auto matrix = Matrix3f(
                [1, 2, 0], 
                [3, 2, 2], 
                [1, 4, 3]
                );
        assert(matrix.determinant == 6+4+0 - (8+18+0) );
    }
    unittest{
        import std.math : fabs;
        auto matrix = Matrix3f(
                [0.8, 0, 0],
                [0, 1.5, 0],
                [0, 0, 0.8]
                );
        assert( fabs(matrix.determinant - 0.96) < 1e-5 );
    }
    unittest{
        assert(Matrix2i([1, 2], [3, 4]).determinant == -2);
        auto m4 = Matrix4i(
                [1, 0, 2, -1],
                [3, 0, 0, 5],
                [2, 1, 4, -3],
                [1, 0, 5, 0]
                );
        assert(m4.determinant == 30);
    }

    /++
        Returns the elements flattened in column-major order:
        element (r, c) is stored at index r + c * RowSize.
    +/
    T[RowSize*ColSize] array()const{
        T[RowSize*ColSize] flat;
        foreach(r; 0 .. RowSize){
            foreach(c; 0 .. ColSize){
                flat[r + c*RowSize] = elements[r][c];
            }
        }
        return flat;
    }
    unittest{
        auto matrix = Matrix3f(
                [1, 4, 7], 
                [2, 5, 8], 
                [3, 6, 9]
                );
        assert(matrix.array == [1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /++
        Returns this matrix cast to another Matrix type. The target type must
        have the same RowSize and ColSize as this matrix. Each row is converted
        with a cast to the target's VectorType.
    +/
    CastedType opCast(CastedType)()const if(CastedType.rowSize == typeof(this).rowSize && CastedType.colSize == typeof(this).colSize){
        auto mat = CastedType();
        foreach(index, const row; elements){
            mat.elements[index] = cast(CastedType.VectorType)row;
        }
        return mat;
    }
    unittest{
        auto m_f= Matrix3f(
                [1, 4, 7], 
                [2, 5, 8], 
                [3, 6, 9]
        );

        auto m_i= Matrix3i.zero;
        
        m_i = cast(Matrix3i)m_f;
        
        assert(m_i[0][0] == 1);

        //Invalid cast. Different size.
        assert(!__traits(compiles, {
            auto m_f2= Matrix3f(
                    [1, 4, 7], 
                    [2, 5, 8], 
                    [3, 6, 9]
            );
            auto m_i2 = cast(Matrix4i)m_f2;
        }));
    }
}
567 
/// Shorthand names for the common square matrix types (int, float, double; 2x2 to 4x4).
alias Matrix2i = Matrix!(int, 2, 2);
alias Matrix3i = Matrix!(int, 3, 3);
alias Matrix4i = Matrix!(int, 4, 4);
alias Matrix2f = Matrix!(float, 2, 2);
alias Matrix3f = Matrix!(float, 3, 3);
alias Matrix4f = Matrix!(float, 4, 4);
alias Matrix2d = Matrix!(double, 2, 2);
alias Matrix3d = Matrix!(double, 3, 3);
alias Matrix4d = Matrix!(double, 4, 4);
577 
/++
    The D x D square Matrix type with element type T.
+/
alias SquareMatrix(T, int D) = Matrix!(T, D, D);
unittest{
    static assert(isSquareMatrix!(SquareMatrix!(float, 4)));
}
584 
/++
    True when `M` is an instantiation of `Matrix` (any element type and size).

    Implemented as an `is` pattern match on the template instance, so it does
    not need to construct or index a value of type `M`.
+/
template isMatrix(M) {
    enum bool isMatrix = is(M == Matrix!(E, R, C), E, int R, int C);
}//template isMatrix
unittest{
    static assert(isMatrix!(Matrix!(float, 3, 3)));
    static assert(isMatrix!(Matrix!(double, 2, 4)));
    static assert(!isMatrix!(float));
}
598 
/++
    True when `M` is a Matrix type with as many rows as columns.
+/
template isSquareMatrix(M){
    static if(isMatrix!M){
        enum bool isSquareMatrix = (M.rowSize == M.colSize);
    }else{
        enum bool isSquareMatrix = false;
    }
}//template isSquareMatrix
unittest{
    static assert(isSquareMatrix!(Matrix!(float, 3, 3)));
    static assert(!isSquareMatrix!(Matrix!(float, 2, 3)));
    static assert(!isSquareMatrix!(float));
}