1+ // Copyright © 2024 Apple Inc.
2+
3+ #include < cassert>
4+
5+ #include " mlx/backend/common/copy.h"
6+ #include " mlx/backend/common/hadamard.h"
7+ #include " mlx/primitives.h"
8+
9+ namespace mlx ::core {
10+
11+ // n = 2^k component
12+ template <typename T>
13+ void hadamard_n (array& out, int n, int m, float scale) {
14+ for (int b = 0 ; b < out.size () / n; b++) {
15+ size_t loc = b * n;
16+ T* data_ptr = out.data <T>() + loc;
17+ int h = 1 ;
18+ int n_over_2 = n / 2 ;
19+ while (h < n) {
20+ for (int i = 0 ; i < n / 2 ; i++) {
21+ int k = i & (h - 1 );
22+ int j = ((i - k) << 1 ) + k;
23+ float x = *(data_ptr + j);
24+ float y = *(data_ptr + j + h);
25+ *(data_ptr + j) = x + y;
26+ *(data_ptr + j + h) = x - y;
27+ if (h == n_over_2) {
28+ *(data_ptr + j) *= scale;
29+ *(data_ptr + j + h) *= scale;
30+ }
31+ }
32+ h <<= 1 ;
33+ }
34+ }
35+ }
36+
37+ // m component
38+ template <typename T>
39+ void hadamard_m (array& out, int n, int m, float scale) {
40+ auto h_matrices = hadamard_matrices ();
41+ auto & matrix = h_matrices[m];
42+ auto start = 1 ;
43+ auto end = matrix.find (' \n ' , start);
44+ std::vector<bool > hmat_vec;
45+ while (end != std::string_view::npos) {
46+ auto row = matrix.substr (start, end - start);
47+ for (int i = 0 ; i < row.length (); i++) {
48+ hmat_vec.push_back (row[i] == ' +' );
49+ }
50+ start = end + 1 ;
51+ end = matrix.find (' \n ' , start);
52+ }
53+
54+ for (int b = 0 ; b < out.size () / m / n; b++) {
55+ size_t loc = b * n * m;
56+ T* data_ptr = out.data <T>() + loc;
57+ for (int i = 0 ; i < n; i++) {
58+ std::vector<float > out (m);
59+ for (int j = 0 ; j < m; j++) {
60+ for (int k = 0 ; k < m; k++) {
61+ float x = *(data_ptr + i + k * n);
62+ if (hmat_vec[k + j * m]) {
63+ out[j] += x;
64+ } else {
65+ out[j] -= x;
66+ }
67+ }
68+ }
69+ for (int j = 0 ; j < m; j++) {
70+ *(data_ptr + i + j * n) = out[j] * scale;
71+ }
72+ }
73+ }
74+ }
75+
76+ template <typename T>
77+ void hadamard (array& out, int n, int m, float scale) {
78+ float n_scale = m > 1 ? 1.0 : scale;
79+ hadamard_n<T>(out, n, m, n_scale);
80+ if (m > 1 ) {
81+ hadamard_m<T>(out, n, m, scale);
82+ }
83+ }
84+
85+ void Hadamard::eval (const std::vector<array>& inputs, array& out) {
86+ assert (inputs.size () == 1 );
87+ auto & in = inputs[0 ];
88+
89+ // Copy input to output
90+ copy (in, out, CopyType::General);
91+
92+ int axis = out.ndim () - 1 ;
93+ auto [n, m] = decompose_hadamard (out.shape (axis));
94+
95+ switch (in.dtype ()) {
96+ case float32:
97+ return hadamard<float >(out, n, m, scale_);
98+ case float16:
99+ return hadamard<float16_t >(out, n, m, scale_);
100+ case bfloat16:
101+ return hadamard<bfloat16_t >(out, n, m, scale_);
102+ default :
103+ throw std::invalid_argument (" [hadamard] Unsupported type." );
104+ }
105+ }
106+
107+ } // namespace mlx::core
0 commit comments