Skip to content

Commit dc8eac5

Browse files
committed
Feat : Added a wrapper function for transpose
1 parent 2ecbc1a commit dc8eac5

File tree

2 files changed

+111
-4
lines changed

2 files changed

+111
-4
lines changed

include/Core/matrix_multiply.h

+35-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,37 @@
66
/**
77
* @brief Performs a matrix multiplication using SIMD instructions (AVX).
88
*
9+
* Note: The second matrix (B_T) must be transposed before calling this function.
10+
*
11+
* C = A * (B_T)^T * scale, i.e. C[i][j] = scale * dot(row i of A, row j of B_T)
12+
*
13+
* @param A Pointer to the first matrix (M x K).
14+
* @param B_T Pointer to the transposed second matrix (N x K).
15+
* @param C Pointer to the result matrix (M x N).
16+
* @param M Number of rows in matrix A.
17+
* @param N Number of columns in the original matrix B (i.e. rows of B_T).
18+
* @param K Number of columns in matrix A and rows in the original matrix B (i.e. columns of B_T).
19+
* @param scale Scaling factor to apply to the result.
20+
*/
21+
void matrix_multiply_simd(const float *A, const float *B_T, float *C,
22+
int M, int N, int K, float scale);
23+
24+
/**
25+
* @brief Transposes a row-major matrix into a row-major matrix with swapped dimensions.
26+
* Input: B (K x N), Output: B_T (N x K)
27+
*
28+
* Note: This function validates input pointers and dimensions.
29+
*
30+
* @param B Original matrix
31+
* @param B_T Transposed matrix
32+
* @param K Rows of original B
33+
* @param N Columns of original B
34+
*/
35+
void transpose_matrix(const float *B, float *B_T, int K, int N);
36+
37+
/**
38+
* @brief Performs a matrix multiplication, checking if the second matrix is already transposed.
39+
*
940
* C = A * B * scale
1041
*
1142
* @param A Pointer to the first matrix (M x K).
@@ -15,8 +46,9 @@
1546
* @param N Number of columns in matrix B.
1647
* @param K Number of columns in matrix A and rows in matrix B.
1748
* @param scale Scaling factor to apply to the result.
49+
* @param is_transposed Flag indicating if B is already transposed (non-zero if true, 0 otherwise); when set, B must be laid out as N x K.
1850
*/
19-
void matrix_multiply_simd(const float *A, const float *B, float *C,
20-
int M, int N, int K, float scale);
51+
void matrix_multiply(const float *A, const float *B, float *C,
52+
int M, int N, int K, float scale, int is_transposed);
2153

22-
#endif
54+
#endif

src/Core/matrix_multiply.c

+76-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <immintrin.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include "../../include/Core/matrix_multiply.h"
45

56
/**
67
* @brief Performs a matrix multiplication using SIMD instructions (AVX).
@@ -18,6 +19,18 @@
1819
void matrix_multiply_simd(const float *A, const float *B_T, float *C,
1920
int M, int N, int K, float scale)
2021
{
22+
if (A == NULL || B_T == NULL || C == NULL)
23+
{
24+
fprintf(stderr, "Error: Null pointer passed to matrix_multiply_simd.\n");
25+
return;
26+
}
27+
28+
if (M <= 0 || N <= 0 || K <= 0)
29+
{
30+
fprintf(stderr, "Error: Invalid matrix dimensions passed to matrix_multiply_simd.\n");
31+
return;
32+
}
33+
2134
#pragma omp parallel for collapse(2)
2235
for (int i = 0; i < M; i++)
2336
{
@@ -29,7 +42,7 @@ void matrix_multiply_simd(const float *A, const float *B_T, float *C,
2942
for (k = 0; k <= K - 8; k += 8)
3043
{
3144
__m256 a = _mm256_loadu_ps(&A[i * K + k]);
32-
__m256 b = _mm256_loadu_ps(&B_T[j * K + k]); // Access row in transposed B
45+
__m256 b = _mm256_loadu_ps(&B_T[j * K + k]);
3346
sum = _mm256_add_ps(sum, _mm256_mul_ps(a, b));
3447
}
3548

@@ -62,8 +75,70 @@ void matrix_multiply_simd(const float *A, const float *B_T, float *C,
6275
*/
6376
void transpose_matrix(const float *B, float *B_T, int K, int N)
{
    /* Reject unusable arguments up front; B_T is left untouched on error. */
    if (B == NULL || B_T == NULL)
    {
        fprintf(stderr, "Error: Null pointer passed to transpose_matrix.\n");
        return;
    }

    if (K <= 0 || N <= 0)
    {
        fprintf(stderr, "Error: Invalid matrix dimensions passed to transpose_matrix.\n");
        return;
    }

    /*
     * Element offsets are computed in size_t: with int arithmetic the
     * products i * N and j * K overflow (undefined behaviour) as soon as
     * K * N exceeds INT_MAX.
     */
    const size_t src_stride = (size_t)N; /* row length of B   (K x N) */
    const size_t dst_stride = (size_t)K; /* row length of B_T (N x K) */

    /* Element (i, j) of B becomes element (j, i) of B_T. */
#pragma omp parallel for collapse(2)
    for (int i = 0; i < K; ++i)
        for (int j = 0; j < N; ++j)
            B_T[(size_t)j * dst_stride + (size_t)i] = B[(size_t)i * src_stride + (size_t)j];
}
95+
96+
/**
 * @brief Computes C = A * B * scale, transposing B first when the caller has not done so.
 *
 * The multiplication itself is delegated to matrix_multiply_simd, which takes
 * its second operand in transposed (N x K) layout. When is_transposed is 0 a
 * temporary transposed copy of B is allocated, used and freed inside this call.
 *
 * On a null pointer, a non-positive dimension, or a failed allocation an error
 * message is written to stderr and the function returns without calling
 * matrix_multiply_simd.
 *
 * @param A Pointer to the first matrix (M x K).
 * @param B Pointer to the second matrix: K x N when is_transposed is 0,
 *          N x K (already transposed) when is_transposed is non-zero.
 * @param C Pointer to the result matrix (M x N).
 * @param M Number of rows in matrix A.
 * @param N Number of columns in the (untransposed) matrix B.
 * @param K Number of columns in matrix A and rows in the (untransposed) matrix B.
 * @param scale Scaling factor to apply to the result.
 * @param is_transposed Non-zero if B is already transposed, 0 otherwise.
 */
void matrix_multiply(const float *A, const float *B, float *C,
                     int M, int N, int K, float scale, int is_transposed)
{
    /* Alignment requested for the temporary transposed buffer (one AVX register). */
    enum { TRANSPOSE_ALIGNMENT = 32 };

    if (A == NULL || B == NULL || C == NULL)
    {
        fprintf(stderr, "Error: Null pointer passed to matrix_multiply.\n");
        return;
    }

    if (M <= 0 || N <= 0 || K <= 0)
    {
        fprintf(stderr, "Error: Invalid matrix dimensions passed to matrix_multiply.\n");
        return;
    }

    /* B is already in the layout the kernel wants: no copy needed. */
    if (is_transposed)
    {
        matrix_multiply_simd(A, B, C, M, N, K, scale);
        return;
    }

    /*
     * Size the scratch buffer in size_t (N * K in int can overflow) and round
     * it up to a multiple of the alignment, as C11 aligned_alloc requires.
     * If the padded byte count would not fit in size_t, skip the allocation
     * and report it as an allocation failure.
     */
    const size_t rows = (size_t)N;
    const size_t cols = (size_t)K;
    const size_t pad = (size_t)TRANSPOSE_ALIGNMENT - 1;
    const size_t max_elems = ((size_t)-1 - pad) / sizeof(float);

    float *B_T = NULL;
    if (rows <= max_elems / cols)
    {
        size_t bytes = rows * cols * sizeof(float);
        bytes = (bytes + pad) / TRANSPOSE_ALIGNMENT * TRANSPOSE_ALIGNMENT;
        B_T = aligned_alloc(TRANSPOSE_ALIGNMENT, bytes);
    }

    if (B_T == NULL)
    {
        fprintf(stderr, "Error: Memory allocation failed for transposed matrix in matrix_multiply.\n");
        return;
    }

    transpose_matrix(B, B_T, K, N);

    matrix_multiply_simd(A, B_T, C, M, N, K, scale);

    free(B_T);
}

0 commit comments

Comments
 (0)