1
1
#include <immintrin.h>
2
2
#include <omp.h>
3
3
#include <stdio.h>
4
+ #include "../../include/Core/matrix_multiply.h"
4
5
5
6
/**
6
7
* @brief Performs a matrix multiplication using SIMD instructions (AVX).
18
19
void matrix_multiply_simd (const float * A , const float * B_T , float * C ,
19
20
int M , int N , int K , float scale )
20
21
{
22
+ if (A == NULL || B_T == NULL || C == NULL )
23
+ {
24
+ fprintf (stderr , "Error: Null pointer passed to matrix_multiply_simd.\n" );
25
+ return ;
26
+ }
27
+
28
+ if (M <= 0 || N <= 0 || K <= 0 )
29
+ {
30
+ fprintf (stderr , "Error: Invalid matrix dimensions passed to matrix_multiply_simd.\n" );
31
+ return ;
32
+ }
33
+
21
34
#pragma omp parallel for collapse(2)
22
35
for (int i = 0 ; i < M ; i ++ )
23
36
{
@@ -29,7 +42,7 @@ void matrix_multiply_simd(const float *A, const float *B_T, float *C,
29
42
for (k = 0 ; k <= K - 8 ; k += 8 )
30
43
{
31
44
__m256 a = _mm256_loadu_ps (& A [i * K + k ]);
32
- __m256 b = _mm256_loadu_ps (& B_T [j * K + k ]); // Access row in transposed B
45
+ __m256 b = _mm256_loadu_ps (& B_T [j * K + k ]);
33
46
sum = _mm256_add_ps (sum , _mm256_mul_ps (a , b ));
34
47
}
35
48
@@ -62,8 +75,70 @@ void matrix_multiply_simd(const float *A, const float *B_T, float *C,
62
75
*/
63
76
void transpose_matrix (const float * B , float * B_T , int K , int N )
64
77
{
78
+ if (B == NULL || B_T == NULL )
79
+ {
80
+ fprintf (stderr , "Error: Null pointer passed to transpose_matrix.\n" );
81
+ return ;
82
+ }
83
+
84
+ if (K <= 0 || N <= 0 )
85
+ {
86
+ fprintf (stderr , "Error: Invalid matrix dimensions passed to transpose_matrix.\n" );
87
+ return ;
88
+ }
89
+
65
90
#pragma omp parallel for collapse(2)
66
91
for (int i = 0 ; i < K ; ++ i )
67
92
for (int j = 0 ; j < N ; ++ j )
68
93
B_T [j * K + i ] = B [i * N + j ];
69
94
}
95
+
96
+ /**
97
+ * @brief Performs a matrix multiplication, checking if the second matrix is already transposed.
98
+ *
99
+ * C = A * B * scale
100
+ *
101
+ * @param A Pointer to the first matrix (M x K).
102
+ * @param B Pointer to the second matrix (K x N).
103
+ * @param C Pointer to the result matrix (M x N).
104
+ * @param M Number of rows in matrix A.
105
+ * @param N Number of columns in matrix B.
106
+ * @param K Number of columns in matrix A and rows in matrix B.
107
+ * @param scale Scaling factor to apply to the result.
108
+ * @param is_transposed Flag indicating if B is already transposed (1 if true, 0 otherwise).
109
+ */
110
+ void matrix_multiply (const float * A , const float * B , float * C ,
111
+ int M , int N , int K , float scale , int is_transposed )
112
+ {
113
+ if (A == NULL || B == NULL || C == NULL )
114
+ {
115
+ fprintf (stderr , "Error: Null pointer passed to matrix_multiply.\n" );
116
+ return ;
117
+ }
118
+
119
+ if (M <= 0 || N <= 0 || K <= 0 )
120
+ {
121
+ fprintf (stderr , "Error: Invalid matrix dimensions passed to matrix_multiply.\n" );
122
+ return ;
123
+ }
124
+
125
+ if (is_transposed )
126
+ {
127
+ matrix_multiply_simd (A , B , C , M , N , K , scale );
128
+ }
129
+ else
130
+ {
131
+ float * B_T = (float * )aligned_alloc (32 , N * K * sizeof (float ));
132
+ if (B_T == NULL )
133
+ {
134
+ fprintf (stderr , "Error: Memory allocation failed for transposed matrix in matrix_multiply.\n" );
135
+ return ;
136
+ }
137
+
138
+ transpose_matrix (B , B_T , K , N );
139
+
140
+ matrix_multiply_simd (A , B_T , C , M , N , K , scale );
141
+
142
+ free (B_T );
143
+ }
144
+ }
0 commit comments