33import matplotlib
44import mlx .core as mx
55import numpy as np
6+ import sympy
7+ import torch
68from time_utils import measure_runtime
79
810matplotlib .use ("Agg" )
@@ -16,40 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
1618 return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
1719
1820
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    """Measure FFT bandwidth for each transform size on the chosen backend.

    For every ``n`` in ``fft_sizes`` a random complex64 array of shape
    ``[system_size // n**dim] + [n] * dim`` is created, transformed over its
    last ``dim`` axes, timed with ``measure_runtime`` and converted to a
    bandwidth figure with ``bandwidth_gb``. Each ``(n, bandwidth)`` pair is
    printed as it is measured.

    Args:
        system_size: Target number of elements per run; the batch dimension
            is ``system_size // n**dim``.
        fft_sizes: Iterable of per-axis transform lengths.
        backend: ``"mlx"`` (``mx.fft``) or ``"mps"`` (``torch.fft`` on the
            ``"mps"`` device).
        dim: Number of transformed axes, 1 or 2.

    Returns:
        ``np.ndarray`` holding one bandwidth value per entry of ``fft_sizes``.

    Raises:
        NotImplementedError: If ``backend`` or ``dim`` is not supported.
    """
    # Fail early instead of hitting an unbound ``out`` inside the timed closure.
    if dim not in (1, 2):
        raise NotImplementedError(f"dim={dim!r} is not supported (expected 1 or 2)")
    if backend not in ("mlx", "mps"):
        raise NotImplementedError(f"backend={backend!r} is not supported")

    def fft_mlx(x):
        out = mx.fft.fft(x) if dim == 1 else mx.fft.fft2(x)
        # MLX is lazy; force the computation so it is included in the timing.
        mx.eval(out)
        return out

    def fft_mps(x):
        out = torch.fft.fft(x) if dim == 1 else torch.fft.fft2(x)
        # Wait for the MPS queue to drain so the timing covers the kernel.
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n] * dim
        # Build the input from ``shape`` so that dim=2 really benchmarks a
        # batch of n x n transforms and matches the element count used below.
        x_np = np.random.uniform(size=shape).astype(np.complex64)
        if backend == "mlx":
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        else:
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps
        runtime_ms = measure_runtime(fft, x=x)
        # Use the actual element count: batch_size * n**dim may be smaller
        # than system_size when n**dim does not divide it.
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)
3560
3661
def time_fft():
    """Benchmark 1-D FFT bandwidth on MLX GPU, PyTorch MPS and MLX CPU.

    Runs ``run_bench`` for every transform size in ``range(2, 512)`` on the
    three targets (2**26 elements for the two GPU runs, 2**20 for the CPU
    run), then:

    * saves one scatter plot per subset of sizes ("All", "Radix 2-13",
      "Bluestein's") to ``<subset name>.png`` in the working directory;
    * prints the mean bandwidth of each target;
    * prints the percentage of sizes for which the MLX GPU bandwidth exceeds
      the MPS bandwidth.

    "Radix 2-13" selects sizes whose prime factors are all <= 13 and
    "Bluestein's" selects the remaining sizes (at least one prime factor
    > 13).
    """
    # NOTE(review): ``plt`` is assumed to be imported at module level outside
    # the visible part of the file -- confirm.
    fft_sizes = np.arange(2, 512)
    system_size = int(2**26)

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=fft_sizes)

    print("MPS GPU")
    mps_bandwidths = run_bench(
        system_size=system_size, fft_sizes=fft_sizes, backend="mps"
    )

    print("CPU")
    system_size = int(2**20)
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=fft_sizes)

    # Boolean masks instead of "value - first value" offsets: they stay
    # correct even if ``fft_sizes`` is not a contiguous range, and each size
    # is factorised only once.
    smooth = np.array(
        [all(p <= 13 for p in sympy.primefactors(int(n))) for n in fft_sizes],
        dtype=bool,
    )
    subsets = [
        (np.ones(len(fft_sizes), dtype=bool), "All"),
        (smooth, "Radix 2-13"),
        (~smooth, "Bluestein's"),
    ]

    for mask, name in subsets:
        # plot bandwidths
        print(name)
        plt.scatter(fft_sizes[mask], gpu_bandwidths[mask], color="green", label="GPU")
        plt.scatter(fft_sizes[mask], mps_bandwidths[mask], color="blue", label="MPS")
        plt.scatter(fft_sizes[mask], cpu_bandwidths[mask], color="red", label="CPU")
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        # Reset the figure so the next subset starts from an empty canvas.
        plt.clf()

    av_gpu_bandwidth = np.mean(gpu_bandwidths)
    av_mps_bandwidth = np.mean(mps_bandwidths)
    av_cpu_bandwidth = np.mean(cpu_bandwidths)
    print("Average bandwidths:")
    print("GPU:", av_gpu_bandwidth)
    print("MPS:", av_mps_bandwidth)
    print("CPU:", av_cpu_bandwidth)

    portion_faster = np.count_nonzero(gpu_bandwidths > mps_bandwidths) / len(fft_sizes)
    print("Percent MLX faster than MPS: ", portion_faster * 100)
53115
54116
55117if __name__ == "__main__" :
0 commit comments