Skip to content

Commit a3c2873

Browse files
authored
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
1 parent 03cf033 commit a3c2873

File tree

22 files changed

+878
-11
lines changed

22 files changed

+878
-11
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import argparse
2+
3+
import matplotlib
4+
import mlx.core as mx
5+
import numpy as np
6+
from time_utils import measure_runtime
7+
8+
matplotlib.use("Agg")
9+
import matplotlib.pyplot as plt
10+
11+
12+
def had(x):
13+
y = mx.hadamard_transform(x)
14+
mx.eval(y)
15+
16+
17+
def copy(x):
18+
y = x + 1.0
19+
mx.eval(y)
20+
21+
22+
def run(dtype):
23+
system_size = 2**26
24+
outputs = {}
25+
for test_fn in (had, copy):
26+
for m in [1, 12, 20, 28]:
27+
if test_fn == copy:
28+
key = "copy"
29+
elif m == 1:
30+
key = "had_2^k"
31+
else:
32+
key = "had_m*2^k"
33+
outputs.setdefault(key, {})
34+
for k in range(7, 14):
35+
n = m * 2**k
36+
if n > 2**15:
37+
continue
38+
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
39+
x = mx.array(x_np)
40+
runtime_ms = measure_runtime(test_fn, x=x)
41+
bytes_per_gb = 1e9
42+
ms_per_s = 1e3
43+
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
44+
bandwidth_gb = (
45+
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
46+
)
47+
print(n, bandwidth_gb)
48+
outputs[key][n] = bandwidth_gb
49+
50+
colors = {
51+
"copy": "black",
52+
"had_2^k": "steelblue",
53+
"had_m*2^k": "skyblue",
54+
}
55+
for key, output in outputs.items():
56+
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
57+
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
58+
plt.xlabel("N")
59+
plt.ylabel("Bandwidth (GB/s)")
60+
plt.legend()
61+
plt.savefig(f"bench_{dtype.__name__}.png")
62+
plt.clf()
63+
64+
65+
if __name__ == "__main__":
66+
parser = argparse.ArgumentParser()
67+
parser.add_argument("--fp16", action="store_true")
68+
args = parser.parse_args()
69+
dtype = np.float16 if args.fp16 else np.float32
70+
run(dtype)

docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Operations
7272
gather_qmm
7373
greater
7474
greater_equal
75+
hadamard_transform
7576
identity
7677
inner
7778
isclose

mlx/backend/accelerate/primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ DEFAULT(GatherMM)
5050
DEFAULT(GatherQMM)
5151
DEFAULT(Greater)
5252
DEFAULT(GreaterEqual)
53+
DEFAULT(Hadamard)
5354
DEFAULT(Less)
5455
DEFAULT(LessEqual)
5556
DEFAULT(Load)

mlx/backend/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ target_sources(
4242
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
4343
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
4444
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
45+
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
4546
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
4647
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
4748
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp

mlx/backend/common/default_primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ DEFAULT(Full)
6868
DEFAULT(Gather)
6969
DEFAULT(Greater)
7070
DEFAULT(GreaterEqual)
71+
DEFAULT(Hadamard)
7172
DEFAULT(Less)
7273
DEFAULT(LessEqual)
7374
DEFAULT(Load)

mlx/backend/common/hadamard.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#include <cassert>
4+
5+
#include "mlx/backend/common/copy.h"
6+
#include "mlx/backend/common/hadamard.h"
7+
#include "mlx/primitives.h"
8+
9+
namespace mlx::core {
10+
11+
// n = 2^k component
12+
template <typename T>
13+
void hadamard_n(array& out, int n, int m, float scale) {
14+
for (int b = 0; b < out.size() / n; b++) {
15+
size_t loc = b * n;
16+
T* data_ptr = out.data<T>() + loc;
17+
int h = 1;
18+
int n_over_2 = n / 2;
19+
while (h < n) {
20+
for (int i = 0; i < n / 2; i++) {
21+
int k = i & (h - 1);
22+
int j = ((i - k) << 1) + k;
23+
float x = *(data_ptr + j);
24+
float y = *(data_ptr + j + h);
25+
*(data_ptr + j) = x + y;
26+
*(data_ptr + j + h) = x - y;
27+
if (h == n_over_2) {
28+
*(data_ptr + j) *= scale;
29+
*(data_ptr + j + h) *= scale;
30+
}
31+
}
32+
h <<= 1;
33+
}
34+
}
35+
}
36+
37+
// m component
38+
template <typename T>
39+
void hadamard_m(array& out, int n, int m, float scale) {
40+
auto h_matrices = hadamard_matrices();
41+
auto& matrix = h_matrices[m];
42+
auto start = 1;
43+
auto end = matrix.find('\n', start);
44+
std::vector<bool> hmat_vec;
45+
while (end != std::string_view::npos) {
46+
auto row = matrix.substr(start, end - start);
47+
for (int i = 0; i < row.length(); i++) {
48+
hmat_vec.push_back(row[i] == '+');
49+
}
50+
start = end + 1;
51+
end = matrix.find('\n', start);
52+
}
53+
54+
for (int b = 0; b < out.size() / m / n; b++) {
55+
size_t loc = b * n * m;
56+
T* data_ptr = out.data<T>() + loc;
57+
for (int i = 0; i < n; i++) {
58+
std::vector<float> out(m);
59+
for (int j = 0; j < m; j++) {
60+
for (int k = 0; k < m; k++) {
61+
float x = *(data_ptr + i + k * n);
62+
if (hmat_vec[k + j * m]) {
63+
out[j] += x;
64+
} else {
65+
out[j] -= x;
66+
}
67+
}
68+
}
69+
for (int j = 0; j < m; j++) {
70+
*(data_ptr + i + j * n) = out[j] * scale;
71+
}
72+
}
73+
}
74+
}
75+
76+
template <typename T>
77+
void hadamard(array& out, int n, int m, float scale) {
78+
float n_scale = m > 1 ? 1.0 : scale;
79+
hadamard_n<T>(out, n, m, n_scale);
80+
if (m > 1) {
81+
hadamard_m<T>(out, n, m, scale);
82+
}
83+
}
84+
85+
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
86+
assert(inputs.size() == 1);
87+
auto& in = inputs[0];
88+
89+
// Copy input to output
90+
copy(in, out, CopyType::General);
91+
92+
int axis = out.ndim() - 1;
93+
auto [n, m] = decompose_hadamard(out.shape(axis));
94+
95+
switch (in.dtype()) {
96+
case float32:
97+
return hadamard<float>(out, n, m, scale_);
98+
case float16:
99+
return hadamard<float16_t>(out, n, m, scale_);
100+
case bfloat16:
101+
return hadamard<bfloat16_t>(out, n, m, scale_);
102+
default:
103+
throw std::invalid_argument("[hadamard] Unsupported type.");
104+
}
105+
}
106+
107+
} // namespace mlx::core

mlx/backend/common/hadamard.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#pragma once
4+
5+
#include <map>
6+
7+
#include "mlx/utils.h"
8+
9+
namespace mlx::core {
10+
11+
// From http://neilsloane.com/hadamard/
12+
constexpr std::string_view h12 = R"(
13+
+-++++++++++
14+
--+-+-+-+-+-
15+
+++-++----++
16+
+---+--+-++-
17+
+++++-++----
18+
+-+---+--+-+
19+
++--+++-++--
20+
+--++---+--+
21+
++----+++-++
22+
+--+-++---+-
23+
++++----+++-
24+
+-+--+-++---
25+
)";
26+
27+
constexpr std::string_view h20 = R"(
28+
+----+----++--++-++-
29+
-+----+---+++---+-++
30+
--+----+---+++-+-+-+
31+
---+----+---+++++-+-
32+
----+----++--++-++-+
33+
-+++++-----+--+++--+
34+
+-+++-+---+-+--+++--
35+
++-++--+---+-+--+++-
36+
+++-+---+---+-+--+++
37+
++++-----++--+-+--++
38+
--++-+-++-+-----++++
39+
---++-+-++-+---+-+++
40+
+---++-+-+--+--++-++
41+
++---++-+----+-+++-+
42+
-++---++-+----+++++-
43+
-+--+--++-+----+----
44+
+-+-----++-+----+---
45+
-+-+-+---+--+----+--
46+
--+-+++------+----+-
47+
+--+--++------+----+
48+
)";
49+
50+
constexpr std::string_view h28 = R"(
51+
+------++----++-+--+-+--++--
52+
-+-----+++-----+-+--+-+--++-
53+
--+-----+++---+-+-+----+--++
54+
---+-----+++---+-+-+-+--+--+
55+
----+-----+++---+-+-+++--+--
56+
-----+-----++++--+-+--++--+-
57+
------++----++-+--+-+--++--+
58+
--++++-+-------++--+++-+--+-
59+
---++++-+-----+-++--+-+-+--+
60+
+---+++--+----++-++--+-+-+--
61+
++---++---+----++-++--+-+-+-
62+
+++---+----+----++-++--+-+-+
63+
++++--------+-+--++-++--+-+-
64+
-++++--------+++--++--+--+-+
65+
-+-++-++--++--+--------++++-
66+
+-+-++--+--++--+--------++++
67+
-+-+-++--+--++--+----+---+++
68+
+-+-+-++--+--+---+---++---++
69+
++-+-+-++--+------+--+++---+
70+
-++-+-+-++--+------+-++++---
71+
+-++-+---++--+------+-++++--
72+
-++--++-+-++-+++----++------
73+
+-++--++-+-++-+++-----+-----
74+
++-++---+-+-++-+++-----+----
75+
-++-++-+-+-+-+--+++-----+---
76+
--++-++++-+-+----+++-----+--
77+
+--++-+-++-+-+----+++-----+-
78+
++--++-+-++-+-+----++------+
79+
)";
80+
81+
inline const std::map<int, std::string_view> hadamard_matrices() {
82+
return {{12, h12}, {20, h20}, {28, h28}};
83+
}
84+
85+
inline std::pair<int, int> decompose_hadamard(int n) {
86+
// n = m*2^k
87+
int m = 1;
88+
if (!is_power_of_2(n)) {
89+
auto h_matrices = hadamard_matrices();
90+
for (auto [factor, _] : h_matrices) {
91+
if (n % factor == 0) {
92+
m = factor;
93+
n /= factor;
94+
break;
95+
}
96+
}
97+
if (m == 1) {
98+
throw std::invalid_argument(
99+
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
100+
}
101+
}
102+
return {n, m};
103+
}
104+
105+
} // namespace mlx::core

mlx/backend/metal/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ make_jit_source(
5252
)
5353
make_jit_source(scatter)
5454
make_jit_source(gather)
55+
make_jit_source(hadamard)
5556

5657
if (MLX_METAL_JIT)
5758
target_sources(
@@ -132,6 +133,7 @@ target_sources(
132133
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
133134
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
134135
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
136+
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
135137
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
136138
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
137139
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp

mlx/backend/metal/fft.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlx/backend/metal/utils.h"
1515
#include "mlx/mlx.h"
1616
#include "mlx/primitives.h"
17+
#include "mlx/utils.h"
1718

1819
namespace mlx::core {
1920

0 commit comments

Comments
 (0)