Skip to content

Commit 27d70c7

Browse files
barronalex and Alex Barron authored
Feature complete Metal FFT (#1102)
* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format

---------

Co-authored-by: Alex Barron <[email protected]>
1 parent 0e585b4 commit 27d70c7

File tree

17 files changed: 2612 additions (+2612) and 378 deletions (-378).

benchmarks/python/fft_bench.py

Lines changed: 85 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
import matplotlib
44
import mlx.core as mx
55
import numpy as np
6+
import sympy
7+
import torch
68
from time_utils import measure_runtime
79

810
matplotlib.use("Agg")
@@ -16,40 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
1618
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
1719

1820

def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    """Measure FFT bandwidth for each transform size in ``fft_sizes``.

    For every ``n`` a random complex64 input of shape
    ``[system_size // n**dim] + [n] * dim`` is created, the transform is
    timed with ``measure_runtime`` and the result is converted with
    ``bandwidth_gb``.  Each ``(n, bandwidth)`` pair is printed as it is
    produced.

    Args:
        system_size: Approximate total number of elements per input; the
            batch dimension is ``system_size // n**dim``.
        fft_sizes: Iterable of transform lengths ``n``.
        backend: ``"mlx"`` (``mx.fft``) or ``"mps"`` (``torch.fft`` on the
            ``"mps"`` device).
        dim: ``1`` for ``fft`` or ``2`` for ``fft2``.

    Returns:
        ``np.ndarray`` of bandwidth values, one per entry of ``fft_sizes``.

    Raises:
        ValueError: If ``dim`` is not 1 or 2.
        NotImplementedError: If ``backend`` is not ``"mlx"`` or ``"mps"``.
    """
    # Reject unsupported dims here; otherwise the closures below would leave
    # `out` unassigned and fail with an UnboundLocalError mid-benchmark.
    if dim not in (1, 2):
        raise ValueError(f"dim must be 1 or 2, got {dim!r}")
    if backend not in ("mlx", "mps"):
        raise NotImplementedError(f"Unsupported backend: {backend!r}")

    def fft_mlx(x):
        out = mx.fft.fft(x) if dim == 1 else mx.fft.fft2(x)
        # Force evaluation so the timed region includes the actual compute.
        mx.eval(out)
        return out

    def fft_mps(x):
        out = torch.fft.fft(x) if dim == 1 else torch.fft.fft2(x)
        # Wait for the MPS queue to drain before the timer stops.
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n] * dim
        # Build the input from `shape` so that a 2D run really transforms
        # (batch, n, n) and matches the element count used for the
        # bandwidth figure below.  For dim == 1 this equals the previous
        # (system_size // n, n) layout.
        x_np = np.random.uniform(size=shape).astype(np.complex64)
        if backend == "mlx":
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        else:
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps

        runtime_ms = measure_runtime(fft, x=x)
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)
3560

3661

3762
def time_fft():
    """Benchmark FFT sizes 2..511 on MLX GPU, PyTorch MPS and MLX CPU.

    Runs ``run_bench`` once per backend, writes one scatter plot per size
    category ("All", "Radix 2-13", "Bluestein's") to ``<name>.png``, and
    prints the mean bandwidth per backend together with the percentage of
    sizes for which the MLX GPU bandwidth exceeds the MPS one.
    """
    fft_sizes = np.array(range(2, 512))
    gpu_system_size = int(2**26)
    cpu_system_size = int(2**20)

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=gpu_system_size, fft_sizes=fft_sizes)

    print("MPS GPU")
    mps_bandwidths = run_bench(
        system_size=gpu_system_size, fft_sizes=fft_sizes, backend="mps"
    )

    print("CPU")
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=cpu_system_size, fft_sizes=fft_sizes)

    # Sizes are consecutive, so (size - first size) is the index of that
    # size in each bandwidth array.
    first = fft_sizes[0]
    only_small_primes = [
        all(p <= 13 for p in sympy.primefactors(n)) for n in fft_sizes
    ]
    categories = [
        (fft_sizes - first, "All"),
        (
            np.array([n for n, ok in zip(fft_sizes, only_small_primes) if ok]) - first,
            "Radix 2-13",
        ),
        (
            np.array([n for n, ok in zip(fft_sizes, only_small_primes) if not ok])
            - first,
            "Bluestein's",
        ),
    ]
    series = [
        (gpu_bandwidths, "green", "GPU"),
        (mps_bandwidths, "blue", "MPS"),
        (cpu_bandwidths, "red", "CPU"),
    ]

    for indices, name in categories:
        # plot bandwidths
        print(name)
        for values, color, label in series:
            plt.scatter(fft_sizes[indices], values[indices], color=color, label=label)
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        # Clear the figure so the next category starts from an empty plot.
        plt.clf()

    print("Average bandwidths:")
    print("GPU:", np.mean(gpu_bandwidths))
    print("MPS:", np.mean(mps_bandwidths))
    print("CPU:", np.mean(cpu_bandwidths))

    faster_count = len(np.where(gpu_bandwidths > mps_bandwidths)[0])
    print("Percent MLX faster than MPS: ", faster_count / len(fft_sizes) * 100)
53115

54116

55117
if __name__ == "__main__":

mlx/backend/metal/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,11 @@ if (MLX_METAL_JIT)
6464
make_jit_source(unary)
6565
make_jit_source(binary)
6666
make_jit_source(binary_two)
67+
make_jit_source(
68+
fft
69+
kernels/fft/radix.h
70+
kernels/fft/readwrite.h
71+
)
6772
make_jit_source(ternary)
6873
make_jit_source(softmax)
6974
make_jit_source(scan)

0 commit comments

Comments
 (0)