Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/mlp-mpi/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
config.excludes = ["mlp_weight_stationary.py"]
135 changes: 71 additions & 64 deletions examples/mlp-mpi/mlp-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import argparse
import ctypes
from pathlib import Path
from contextlib import contextmanager
from typing import Optional

Expand All @@ -33,6 +32,8 @@
from lighthouse.utils.mlir import apply_registered_pass, match
from lighthouse.workload import Workload, execute

from mlp_weight_stationary import generate_mlp_payload

from mpi4py import MPI


Expand Down Expand Up @@ -146,17 +147,26 @@ def allocate_inputs(self, execution_engine: ExecutionEngine):
):
yield self.input_memrefs

def _gather(
    self,
    memref: ctypes.Structure,
    execution_engine: ExecutionEngine,
    gather_func: str,
) -> ctypes.Structure:
    """Gather a distributed 2-D memref into a single full-size memref.

    Invokes the compiled MLIR entry point named by ``gather_func`` (e.g.
    ``"gather_act"``) through ``execution_engine``, passing a freshly
    created 2-D memref descriptor as the output argument and ``memref``
    as this rank's local shard.

    Args:
        memref: ctypes memref descriptor holding the local (sharded) data.
        execution_engine: MLIR ``ExecutionEngine`` in which the gather
            function has been compiled and registered.
        gather_func: Name of the compiled gather function to invoke.

    Returns:
        A new 2-D memref descriptor populated by the gather call.
        NOTE(review): the gathered buffer appears to be allocated by the
        callee — callers are expected to free it explicitly (elsewhere it
        is wrapped in ``deallocate_memrefs_on_exit(..., "dealloc_2d")``);
        confirm against the generated MLIR.
    """
    # Fresh rank-2 descriptor for the gathered result; element type is
    # taken from the workload's dtype via as_ctype().
    gathered_memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))()
    # Calling convention: output descriptor first, then the input shard.
    execution_engine.invoke(
        gather_func,
        memref_to_ctype(gathered_memref),
        memref_to_ctype(memref),
    )
    return gathered_memref

def _reference_solution(self, execution_engine: ExecutionEngine) -> np.ndarray:
rprint(" * Gathering input data...")
gathered = []
for i, v in enumerate(["act", "win", "wout"]):
memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))()
execution_engine.invoke(
f"gather_{v}",
memref_to_ctype(memref),
memref_to_ctype(self.input_memrefs[i + 1]),
)
gathered.append(memref)
gathered = [
self._gather(self.input_memrefs[i + 1], execution_engine, f"gather_{v}")
for i, v in enumerate(["act", "win", "wout"])
]

rprint(" * Computing reference solution...")

Expand All @@ -171,14 +181,16 @@ def sigmoid(z):
def check_correctness(
self, execution_engine: ExecutionEngine, verbose: int = 0
) -> bool:
R = ranked_memref_to_numpy([self.input_memrefs[0]])
R_ref = self._reference_solution(execution_engine)
if verbose > 1:
rprint("Reference solution:")
rprint(R_ref)
rprint("Computed solution:")
rprint(R)
success = np.allclose(R, R_ref)
gathered = self._gather(self.input_memrefs[0], execution_engine, "gather_act")
with deallocate_memrefs_on_exit([gathered], execution_engine, "dealloc_2d"):
R = ranked_memref_to_numpy([gathered])
R_ref = self._reference_solution(execution_engine)
if verbose > 1:
rprint("Reference solution:")
rprint(R_ref)
rprint("Computed solution:")
rprint(R)
success = np.allclose(R, R_ref)
success = MPI.COMM_WORLD.allreduce(success, op=MPI.LAND)
if success:
rprint("PASSED")
Expand All @@ -201,7 +213,7 @@ def get_complexity(self) -> tuple[int, int, int]:
def payload_module(self) -> ir.Module:
if len(self.grid) == 1:
rprint(f"Using 1D grid of size {self.comm_size}")
grid = self.comm_size
grid = [self.comm_size]
else:
assert len(self.grid) == 2
if all(x != 0 for x in self.grid):
Expand All @@ -216,62 +228,46 @@ def find_factors(n):

p1, p2 = find_factors(self.comm_size)
rprint(f"Using 2D grid of size {p1}x{p2}")
grid = f"{p1}x{p2}"

fname = Path(__file__).parent / "mlp_weight_stationary.mlir"
with open(fname, "r") as f:
txt = f.read()

format_values = {
"func_name": self.payload_function_name,
"M": self.M,
"N": self.N,
"K": self.K,
"P": self.comm_size,
"R": self.comm_rank,
"grid": grid,
"split_r": "[[]]",
}
grid = [p1, p2]

common = dict(
func_name=self.payload_function_name,
M=self.M,
N=self.N,
K=self.K,
comm_size=self.comm_size,
comm_rank=self.comm_rank,
grid=grid,
)
if len(self.grid) == 1:
format_values.update(
{
"split_act": "[[], [0]]",
"split_win": "[[], [0]]",
"split_wout": "[[0], []]",
"split_mm0_a": "[[]]",
"split_mm0_b": "[[], [0]]",
"split_mm0_c": "[[], [0]]",
"split_sigmoid": "[[], [0]]",
"split_mm1_a": "[[], [0]]",
"split_mm1_b": "[[0], []]",
"split_mm1_c": "[[]]",
}
mod = generate_mlp_payload(
**common,
split_act=[[], [0]],
split_win=[[], [0]],
split_wout=[[0], []],
split_mm0a_mm1c=[[]],
split_mm0_c=[[], [0]],
split_sigmoid=[[], [0]],
)
else:
format_values.update(
{
"split_act": "[[], [0, 1]]",
"split_win": "[[0], [1]]",
"split_wout": "[[1], [0]]",
"split_mm0_a": "[[], [0]]",
"split_mm0_b": "[[0], [1]]",
"split_mm0_c": "[[], [1]]",
"split_sigmoid": "[[], [1, 0]]",
"split_mm1_a": "[[], [1]]",
"split_mm1_b": "[[1], [0]]",
"split_mm1_c": "[[], [0]]",
}
mod = generate_mlp_payload(
**common,
split_act=[[], [0, 1]],
split_win=[[0], [1]],
split_wout=[[1], [0]],
split_mm0a_mm1c=[[], [0]],
split_mm0_c=[[], [1]],
split_sigmoid=[[], [1, 0]],
)
txt = txt.format_map(format_values)

if self.verbose > 1:
rprint("Payload MLIR:")
count = 1
for line in txt.splitlines():
for line in str(mod).splitlines():
rprint(str(count) + "\t" + line)
count += 1

return ir.Module.parse(txt)
return mod

def schedule_module(
self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
Expand All @@ -290,10 +286,21 @@ def schedule_module(
with ir.InsertionPoint(named_sequence.body):
anytype = transform.AnyOpType.get()
func = match(named_sequence.bodyTarget, ops={"func.func"})
func = apply_registered_pass(
func,
"sharding-propagation",
options={"traversal": "forward-backward"},
)
if self.verbose > 0:
transform.PrintOp(target=func)
func = apply_registered_pass(func, "shard-partition")
func = apply_registered_pass(func, "canonicalize")
if self.verbose > 0:
transform.PrintOp(target=func)
func = apply_registered_pass(func, "convert-shard-to-mpi")
func = apply_registered_pass(func, "canonicalize")
if self.verbose > 0:
transform.PrintOp(target=func)
func = apply_registered_pass(func, "tosa-to-linalg")
mod = transform.get_parent_op(
anytype, func, op_name="builtin.module", deduplicate=True
Expand Down
113 changes: 0 additions & 113 deletions examples/mlp-mpi/mlp_weight_stationary.mlir

This file was deleted.

Loading