1- # RUN: %PYTHON %s | FileCheck %s
1+ # REQUIRES: mpi4py
2+ # RUN: mpirun -n 4 %PYTHON %s | FileCheck %s
23# CHECK: PASSED
34"""
45A single MLP that can run on multiple MPI ranks,
3031from mpi4py import MPI
3132
3233
34+ if not MPI .Is_initialized ():
35+ MPI .Init ()
36+ P = MPI .COMM_WORLD .Get_size ()
37+ R = MPI .COMM_WORLD .Get_rank ()
38+
39+
40+ def rprint (* args , ** kwargs ):
41+ if R == 0 :
42+ print (* args , ** kwargs )
43+
44+
3345def parse_cla ():
3446 parser = argparse .ArgumentParser (
3547 description = "MLP on MPI using MLIR" ,
@@ -40,7 +52,7 @@ def parse_cla():
4052 "-s" ,
4153 type = int ,
4254 nargs = 3 ,
43- default = [4096 , 4096 , 4096 ],
55+ default = [64 , 128 , 32 ],
4456 help = "M,N,K matrix sizes (Activations=MxK, WeightsIn=KxN, WeightsOut=MxN, Result=MxK)." ,
4557 )
4658 parser .add_argument (
@@ -97,7 +109,7 @@ def __init__(self, args, P: int, R: int):
97109 self .verbose = args .verbose
98110
99111 def _alloc_inout (self , execution_engine : ExecutionEngine ) -> list [ctypes .Structure ]:
100- print (" * Allocating input/output arrays..." )
112+ rprint (" * Allocating input/output arrays..." )
101113 memrefs = [
102114 make_nd_memref_descriptor (2 , as_ctype (self .dtype ))() for _ in range (4 )
103115 ]
@@ -106,7 +118,7 @@ def _alloc_inout(self, execution_engine: ExecutionEngine) -> list[ctypes.Structu
106118 return memrefs
107119
108120 def _init_inout (self , r : np .ndarray , a : np .ndarray , b : np .ndarray , c : np .ndarray ):
109- print (" * Initializing input arrays..." )
121+ rprint (" * Initializing input arrays..." )
110122 np .random .seed (self .R )
111123 # R = ranked_memref_to_numpy([r])
112124 A = ranked_memref_to_numpy ([a ])
@@ -128,7 +140,7 @@ def allocate_inputs(self, execution_engine: ExecutionEngine):
128140 pass
129141
130142 def _reference_solution (self , execution_engine : ExecutionEngine ) -> np .ndarray :
131- print (" * Gathering input data..." )
143+ rprint (" * Gathering input data..." )
132144 gathered = []
133145 for i , v in enumerate (["act" , "win" , "wout" ]):
134146 memref = make_nd_memref_descriptor (2 , as_ctype (self .dtype ))()
@@ -139,7 +151,7 @@ def _reference_solution(self, execution_engine: ExecutionEngine) -> np.ndarray:
139151 )
140152 gathered .append (ranked_memref_to_numpy ([memref ]))
141153
142- print (" * Computing reference solution..." )
154+ rprint (" * Computing reference solution..." )
143155
144156 def sigmoid (z ):
145157 return 1 / (1 + np .exp (- z ))
@@ -153,15 +165,16 @@ def check_correctness(
153165 R = ranked_memref_to_numpy ([self ._input_arrays [0 ]])
154166 R_ref = self ._reference_solution (execution_engine )
155167 if verbose > 1 :
156- print ("Reference solution:" )
157- print (R_ref )
158- print ("Computed solution:" )
159- print (R )
168+ rprint ("Reference solution:" )
169+ rprint (R_ref )
170+ rprint ("Computed solution:" )
171+ rprint (R )
160172 success = np .allclose (R , R_ref )
173+ success = MPI .COMM_WORLD .allreduce (success , op = MPI .LAND )
161174 if success :
162- print ("PASSED" )
175+ rprint ("PASSED" )
163176 else :
164- print ("FAILED Result mismatch!" )
177+ rprint ("FAILED Result mismatch!" )
165178 return success
166179
167180 def shared_libs (self ) -> list [str ]:
@@ -182,7 +195,7 @@ def get_complexity(self) -> tuple[int, int, int]:
182195
183196 def payload_module (self ) -> ir .Module :
184197 if self .griddims == 1 :
185- print (f"Using 1D grid of size { self .P } " )
198+ rprint (f"Using 1D grid of size { self .P } " )
186199 grid = self .P
187200 elif self .griddims == 2 :
188201 # find two factors of P that are as close as possible
@@ -193,14 +206,14 @@ def find_factors(n):
193206 return (1 , n )
194207
195208 p1 , p2 = find_factors (self .P )
196- print (f"Using 2D grid of size { p1 } x{ p2 } " )
209+ rprint (f"Using 2D grid of size { p1 } x{ p2 } " )
197210 grid = f"{ p1 } x{ p2 } "
198211 else :
199212 raise ValueError (
200213 f"Only 1D and 2D grids are supported (not { self .griddims } d).\n "
201214 )
202215
203- fname = "mlp_weight_stationary.mlir"
216+ fname = Path ( __file__ ). parent / "mlp_weight_stationary.mlir"
204217 with open (fname , "r" ) as f :
205218 txt = f .read ()
206219
@@ -247,10 +260,10 @@ def find_factors(n):
247260 txt = txt .format_map (format_values )
248261
249262 if self .verbose > 1 :
250- print ("Payload MLIR:" )
263+ rprint ("Payload MLIR:" )
251264 count = 1
252265 for line in txt .splitlines ():
253- print (str (count ) + "\t " + line )
266+ rprint (str (count ) + "\t " + line )
254267 count += 1
255268
256269 return ir .Module .parse (txt )
@@ -340,22 +353,22 @@ def schedule_module(
340353 with ir .Context (), ir .Location .unknown ():
341354 wload = DistMLP (args , P , R )
342355
343- print (" Execute" .center (60 , "-" ))
356+ rprint (" Execute" .center (60 , "-" ))
344357 execute (wload , verbose = args .verbose )
345358
346- # print (" Execute 2 ".center(60, "-"))
359+ # rprint (" Execute 2 ".center(60, "-"))
347360 # execute(wload, verbose=1)
348361
349- # print (" Benchmark ".center(60, "-"))
362+ # rprint (" Benchmark ".center(60, "-"))
350363 # times = benchmark(wload)
351364 # times *= 1e6 # convert to microseconds
352365 # compute statistics
353366 # mean = np.mean(times)
354367 # min = np.min(times)
355368 # max = np.max(times)
356369 # std = np.std(times)
357- # print (f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
370+ # rprint (f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
358371 # flop_count = wload.get_complexity()[0]
359372 # gflops = flop_count / (mean * 1e-6) / 1e9
360- # print (f"Throughput: {gflops:.2f} GFLOPS")
373+ # rprint (f"Throughput: {gflops:.2f} GFLOPS")
361374 MPI .Finalize ()
0 commit comments