Skip to content

Commit 3f41e3c

Browse files
committed
enabling CI for MPI tests, cleanup
1 parent fa096b1 commit 3f41e3c

File tree

5 files changed

+49
-30
lines changed

5 files changed

+49
-30
lines changed

.github/workflows/examples.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,9 @@ jobs:
3333
run: |
3434
export FILECHECK=FileCheck-18 # Ubuntu's llvm-dev appends a version number.
3535
uv run lit examples --verbose # Makes sure to substitute FileCheck for $FILECHECK
36+
37+
- name: Run lit-enabled examples which use mpi as tests
38+
run: |
39+
export FILECHECK=FileCheck-18 # Ubuntu's llvm-dev appends a version number.
40+
uv sync --extra runtime_mpi
41+
uv run lit examples/mlp-mpi --verbose # Makes sure to substitute FileCheck for $FILECHECK

examples/mlp-mpi/mlp-mpi.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# RUN: %PYTHON %s | FileCheck %s
1+
# REQUIRES: mpi4py
2+
# RUN: mpirun -n 4 %PYTHON %s | FileCheck %s
23
# CHECK: PASSED
34
"""
45
A single MLP that can run on multiple MPI ranks,
@@ -30,6 +31,17 @@
3031
from mpi4py import MPI
3132

3233

34+
if not MPI.Is_initialized():
35+
MPI.Init()
36+
P = MPI.COMM_WORLD.Get_size()
37+
R = MPI.COMM_WORLD.Get_rank()
38+
39+
40+
def rprint(*args, **kwargs):
41+
if R == 0:
42+
print(*args, **kwargs)
43+
44+
3345
def parse_cla():
3446
parser = argparse.ArgumentParser(
3547
description="MLP on MPI using MLIR",
@@ -40,7 +52,7 @@ def parse_cla():
4052
"-s",
4153
type=int,
4254
nargs=3,
43-
default=[4096, 4096, 4096],
55+
default=[64, 128, 32],
4456
help="M,N,K matrix sizes (Activations=MxK, WeightsIn=KxN, WeightsOut=MxN, Result=MxK).",
4557
)
4658
parser.add_argument(
@@ -97,7 +109,7 @@ def __init__(self, args, P: int, R: int):
97109
self.verbose = args.verbose
98110

99111
def _alloc_inout(self, execution_engine: ExecutionEngine) -> list[ctypes.Structure]:
100-
print(" * Allocating input/output arrays...")
112+
rprint(" * Allocating input/output arrays...")
101113
memrefs = [
102114
make_nd_memref_descriptor(2, as_ctype(self.dtype))() for _ in range(4)
103115
]
@@ -106,7 +118,7 @@ def _alloc_inout(self, execution_engine: ExecutionEngine) -> list[ctypes.Structu
106118
return memrefs
107119

108120
def _init_inout(self, r: np.ndarray, a: np.ndarray, b: np.ndarray, c: np.ndarray):
109-
print(" * Initializing input arrays...")
121+
rprint(" * Initializing input arrays...")
110122
np.random.seed(self.R)
111123
# R = ranked_memref_to_numpy([r])
112124
A = ranked_memref_to_numpy([a])
@@ -128,7 +140,7 @@ def allocate_inputs(self, execution_engine: ExecutionEngine):
128140
pass
129141

130142
def _reference_solution(self, execution_engine: ExecutionEngine) -> np.ndarray:
131-
print(" * Gathering input data...")
143+
rprint(" * Gathering input data...")
132144
gathered = []
133145
for i, v in enumerate(["act", "win", "wout"]):
134146
memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))()
@@ -139,7 +151,7 @@ def _reference_solution(self, execution_engine: ExecutionEngine) -> np.ndarray:
139151
)
140152
gathered.append(ranked_memref_to_numpy([memref]))
141153

142-
print(" * Computing reference solution...")
154+
rprint(" * Computing reference solution...")
143155

144156
def sigmoid(z):
145157
return 1 / (1 + np.exp(-z))
@@ -153,15 +165,16 @@ def check_correctness(
153165
R = ranked_memref_to_numpy([self._input_arrays[0]])
154166
R_ref = self._reference_solution(execution_engine)
155167
if verbose > 1:
156-
print("Reference solution:")
157-
print(R_ref)
158-
print("Computed solution:")
159-
print(R)
168+
rprint("Reference solution:")
169+
rprint(R_ref)
170+
rprint("Computed solution:")
171+
rprint(R)
160172
success = np.allclose(R, R_ref)
173+
success = MPI.COMM_WORLD.allreduce(success, op=MPI.LAND)
161174
if success:
162-
print("PASSED")
175+
rprint("PASSED")
163176
else:
164-
print("FAILED Result mismatch!")
177+
rprint("FAILED Result mismatch!")
165178
return success
166179

167180
def shared_libs(self) -> list[str]:
@@ -182,7 +195,7 @@ def get_complexity(self) -> tuple[int, int, int]:
182195

183196
def payload_module(self) -> ir.Module:
184197
if self.griddims == 1:
185-
print(f"Using 1D grid of size {self.P}")
198+
rprint(f"Using 1D grid of size {self.P}")
186199
grid = self.P
187200
elif self.griddims == 2:
188201
# find two factors of P that are as close as possible
@@ -193,14 +206,14 @@ def find_factors(n):
193206
return (1, n)
194207

195208
p1, p2 = find_factors(self.P)
196-
print(f"Using 2D grid of size {p1}x{p2}")
209+
rprint(f"Using 2D grid of size {p1}x{p2}")
197210
grid = f"{p1}x{p2}"
198211
else:
199212
raise ValueError(
200213
f"Only 1D and 2D grids are supported (not {self.griddims}d).\n"
201214
)
202215

203-
fname = "mlp_weight_stationary.mlir"
216+
fname = Path(__file__).parent / "mlp_weight_stationary.mlir"
204217
with open(fname, "r") as f:
205218
txt = f.read()
206219

@@ -247,10 +260,10 @@ def find_factors(n):
247260
txt = txt.format_map(format_values)
248261

249262
if self.verbose > 1:
250-
print("Payload MLIR:")
263+
rprint("Payload MLIR:")
251264
count = 1
252265
for line in txt.splitlines():
253-
print(str(count) + "\t" + line)
266+
rprint(str(count) + "\t" + line)
254267
count += 1
255268

256269
return ir.Module.parse(txt)
@@ -340,22 +353,22 @@ def schedule_module(
340353
with ir.Context(), ir.Location.unknown():
341354
wload = DistMLP(args, P, R)
342355

343-
print(" Execute".center(60, "-"))
356+
rprint(" Execute".center(60, "-"))
344357
execute(wload, verbose=args.verbose)
345358

346-
# print(" Execute 2 ".center(60, "-"))
359+
# rprint(" Execute 2 ".center(60, "-"))
347360
# execute(wload, verbose=1)
348361

349-
# print(" Benchmark ".center(60, "-"))
362+
# rprint(" Benchmark ".center(60, "-"))
350363
# times = benchmark(wload)
351364
# times *= 1e6 # convert to microseconds
352365
# compute statistics
353366
# mean = np.mean(times)
354367
# min = np.min(times)
355368
# max = np.max(times)
356369
# std = np.std(times)
357-
# print(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
370+
# rprint(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
358371
# flop_count = wload.get_complexity()[0]
359372
# gflops = flop_count / (mean * 1e-6) / 1e9
360-
# print(f"Throughput: {gflops:.2f} GFLOPS")
373+
# rprint(f"Throughput: {gflops:.2f} GFLOPS")
361374
MPI.Finalize()

examples/mlp-mpi/mlp_weight_stationary.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,6 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co
8484
return %ret_a : tensor<{M}x{K}xf32>
8585
}}
8686

87-
// func.func @gather(%t:tensor<5x3xi32>) -> tensor<5x12xi32> attributes {{llvm.emit_c_interface}} {{
88-
// %r = shard.all_gather %t on @grid0 grid_axes = [0] gather_axis = 1 : tensor<5x3xi32> -> tensor<5x12xi32>
89-
// return %r : tensor<5x12xi32>
90-
// }}
91-
9287
func.func @gather_act(%arg0: tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32> attributes {{llvm.emit_c_interface}} {{
9388
%sharding = shard.sharding @grid0 split_axes = {split_act} : !shard.sharding
9489
%sharding_g = shard.sharding @grid0 split_axes = [[]] : !shard.sharding

lit.cfg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
if filecheck_path := os.environ.get("FILECHECK"):
2121
config.substitutions.append(("FileCheck", filecheck_path))
2222

23-
if importlib.util.find_spec("torch"):
24-
config.available_features.add("torch")
23+
for pkg in ["torch", "mpi4py", "mpich", "openmpi", "impi-rt"]:
24+
if importlib.util.find_spec(pkg):
25+
config.available_features.add(pkg)
2526

2627
torch_kernels_dir = project_root + "/third_party/KernelBench/KernelBench"
2728
if os.path.isdir(torch_kernels_dir):

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dev = [
1717

1818
[project.optional-dependencies]
1919
ingress_torch_mlir = [
20-
"torch-mlir==20260125.703",
20+
"torch-mlir==20260209.718",
2121
"ml_dtypes",
2222
]
2323
# Additional "targets" which pull in optional dependencies -- use `uv sync --extra TARGET`
@@ -39,6 +39,10 @@ ingress_torch_xpu = [
3939
"pytorch_triton_xpu", # Transitive dependency listed explicitly so that we can state which package repository it is supposed to come from
4040
"lighthouse[ingress_torch_mlir]"
4141
]
42+
runtime_mpi = [
43+
"mpi4py",
44+
"impi-rt"
45+
]
4246

4347
[tool.uv]
4448
# Declare that the following "targets" are mutually exclusive of one another

0 commit comments

Comments
 (0)