Skip to content

Commit f41a8fa

Browse files
committed
Unify search, compilation, and hardware benchmark into single pipeline with JSON reporting
1 parent b5c95ba commit f41a8fa

File tree

17 files changed

+995
-322
lines changed

17 files changed

+995
-322
lines changed

examples/gym.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,13 @@
77

88
import argparse
99
import logging
10-
import math
1110
from pathlib import Path
1211

1312
import numpy as np
1413

1514
import nkigym
16-
from nkigym.search import benchmark_variants, search
15+
from nkigym.search import search
1716
from nkigym.transforms import DataReuseTransform, OperandMergeTransform
18-
from nkigym.utils import setup_logging
19-
20-
logger = logging.getLogger(__name__)
2117

2218

2319
def nkigym_matmul(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray:
@@ -37,45 +33,32 @@ def parse_args() -> argparse.Namespace:
3733
"""Parse command-line arguments."""
3834
parser = argparse.ArgumentParser(description="NKI Gym search example")
3935
parser.add_argument(
40-
"--cache-dir", type=Path, default=Path("cache"), help="Directory for storing output logs (default: cache)"
36+
"--cache-dir", type=Path, default=Path("cache"), help="Directory for storing output (default: cache)"
4137
)
4238
return parser.parse_args()
4339

4440

4541
def main() -> None:
4642
"""Run transform search on a tiled matmul workload."""
43+
logging.basicConfig(level=logging.INFO, format="%(message)s")
44+
4745
args = parse_args()
4846
cache_dir = args.cache_dir
49-
cache_dir.mkdir(parents=True, exist_ok=True)
50-
log_path = cache_dir / "gym.log"
51-
setup_logging(str(log_path))
5247

5348
k, m, n = 256, 256, 256
5449
rng = np.random.default_rng(42)
5550
lhs = rng.standard_normal((k, m)).astype(np.float32)
5651
rhs = rng.standard_normal((k, n)).astype(np.float32)
5752

58-
variants = search(
53+
search(
5954
func=nkigym_matmul,
6055
transforms=[DataReuseTransform(), OperandMergeTransform()],
61-
num_targets=math.inf,
56+
num_targets=1000,
6257
seed=42,
6358
min_depth=10,
6459
save_cache=cache_dir,
6560
kernel_kwargs={"lhs": lhs, "rhs": rhs},
6661
)
67-
logger.info("Search produced %d unique variants", len(variants))
68-
69-
results = benchmark_variants(
70-
cache_dir=cache_dir,
71-
func_name="nkigym_matmul",
72-
kernel_kwargs={"lhs": lhs, "rhs": rhs},
73-
output_name="output",
74-
output_shape=(m, n),
75-
warmup=2,
76-
iters=5,
77-
)
78-
results.summary(top_k=5)
7962

8063

8164
if __name__ == "__main__":

nkigym/src/nkigym/codegen/context.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ class _LoweringContext:
1919
params: Input parameter names (all live in HBM).
2020
buffers: Variable name to buffer location string.
2121
aliases: Maps accumulation output names to canonical PSUM variable.
22+
alias_offsets: Maps alias names to their start offsets per axis.
2223
staging_counter: Monotonic counter for staging variable names.
2324
"""
2425

2526
params: tuple[str, ...]
2627
buffers: dict[str, str] = field(default_factory=dict)
2728
aliases: dict[str, str] = field(default_factory=dict)
29+
alias_offsets: dict[str, tuple[int, ...]] = field(default_factory=dict)
2830
staging_counter: int = 0
2931

3032
def resolve(self, name: str) -> str:
@@ -40,6 +42,26 @@ def resolve(self, name: str) -> str:
4042
name = self.aliases[name]
4143
return name
4244

45+
def _resolve_offsets(self, name: str) -> tuple[int, ...]:
    """Accumulate start offsets along the alias chain.

    Follows ``self.aliases`` from ``name`` until a name that is not an
    alias is reached, summing the per-axis entries recorded in
    ``self.alias_offsets`` for every alias visited. An alias with no
    recorded offsets contributes nothing.

    Args:
        name: Variable name, possibly an accumulation alias.

    Returns:
        Tuple of accumulated start offsets per axis. Empty when ``name``
        is not an alias or no visited alias recorded any offsets.
    """
    offsets: list[int] = []
    while name in self.aliases:
        entry_offsets = self.alias_offsets.get(name, ())
        # Links in the chain may record a different number of axes; grow
        # the accumulator with zeros so a longer entry extends the result
        # instead of raising IndexError.
        missing = len(entry_offsets) - len(offsets)
        if missing > 0:
            offsets.extend([0] * missing)
        for axis, start in enumerate(entry_offsets):
            offsets[axis] += start
        name = self.aliases[name]
    return tuple(offsets)
64+
4365
def buffer_of(self, name: str) -> str:
4466
"""Look up the buffer location of a variable, resolving aliases.
4567
@@ -57,8 +79,9 @@ def buffer_of(self, name: str) -> str:
5779
def subscript(self, ref: TensorRef) -> str:
5880
"""Render a TensorRef as ``name[s:e, s:e]``, resolving aliases.
5981
60-
Unconditionally renders slices from the IR. The IR is the
61-
source of truth — no shape comparison or optimization.
82+
When the name resolves through an alias chain, accumulates
83+
start offsets and composes them with the ref slices so the
84+
subscript points at the correct region of the canonical buffer.
6285
6386
Args:
6487
ref: Tensor reference.
@@ -67,25 +90,61 @@ def subscript(self, ref: TensorRef) -> str:
6790
Subscripted string or plain resolved name.
6891
"""
6992
resolved = self.resolve(ref.name)
93+
offsets = self._resolve_offsets(ref.name)
7094
result = resolved
7195
if ref.slices:
72-
parts = ", ".join(f"{s}:{e}" for s, e in ref.slices)
96+
parts = _compose_slices(ref.slices, offsets)
7397
result = f"{resolved}[{parts}]"
7498
return result
7599

76100

101+
def _compose_slices(slices: tuple[tuple[int, int], ...], offsets: tuple[int, ...]) -> str:
    """Render slice bounds shifted by alias offsets as a subscript string.

    Each axis's ``(start, stop)`` pair is shifted by the offset found at the
    same axis index. Axes beyond the end of ``offsets`` are shifted by 0, and
    offsets beyond the end of ``slices`` are ignored.

    Args:
        slices: Per-axis (start, stop) bounds from the TensorRef.
        offsets: Per-axis start offsets from the alias chain.

    Returns:
        Comma-separated ``s:e`` subscript string.
    """
    known_axes = len(offsets)
    shifts = (offsets[axis] if axis < known_axes else 0 for axis in range(len(slices)))
    return ", ".join(f"{lo + shift}:{hi + shift}" for (lo, hi), shift in zip(slices, shifts))
116+
117+
77118
def get_kwarg(stmt: GymStatement, key: str) -> object:
    """Look up a keyword argument of a statement by name.

    The statement is first passed to ``_assert_no_duplicate_kwargs``;
    duplicate keys indicate an IR construction bug upstream.

    Args:
        stmt: GymStatement to search.
        key: Keyword argument name.

    Returns:
        The value paired with the first matching key, or None when no
        entry of ``stmt.kwargs`` has that key.
    """
    _assert_no_duplicate_kwargs(stmt)
    return next((value for name, value in stmt.kwargs if name == key), None)
138+
139+
140+
def _assert_no_duplicate_kwargs(stmt: "GymStatement") -> None:
    """Check that a statement has no duplicate keyword argument names.

    Args:
        stmt: GymStatement to check. Its ``kwargs`` is iterated as
            ``(key, value)`` pairs; ``op`` and ``output.name`` are only
            read to build the error message.

    Raises:
        AssertionError: If duplicate kwarg keys are found.
    """
    keys = [k for k, _ in stmt.kwargs]
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # compiled out under ``python -O``, which would silently disable this check.
    if len(keys) != len(set(keys)):
        raise AssertionError(f"Duplicate kwargs in {stmt.op} stmt '{stmt.output.name}': {keys}")

nkigym/src/nkigym/codegen/loop_rolling.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,41 @@ def _extract_varying(
228228
return (valid, varying)
229229

230230

231+
def _collect_assigned_names(stmts: list[ast.stmt]) -> set[str]:
    """Collect every plain name bound by a list of statements.

    Walks each statement recursively and records every ``ast.Name`` that
    appears in ``Store`` context. This covers simple assignments,
    tuple-unpacking targets, augmented and annotated assignments, ``for``
    loop targets, and bindings nested inside compound statements.
    Subscript or attribute targets such as ``x[0] = ...`` bind no new name
    and are not reported.

    Args:
        stmts: Statements to scan.

    Returns:
        Set of bound identifiers; empty when nothing is bound.
    """
    names: set[str] = set()
    for stmt in stmts:
        for node in ast.walk(stmt):
            # Store context marks a binding; Load and Del are not definitions.
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                names.add(node.id)
    return names
240+
241+
242+
def _collect_referenced_names(stmts: list[ast.stmt]) -> set[str]:
    """Gather the identifiers that a list of statements reads.

    Every statement is walked recursively; each ``ast.Name`` node found in
    ``Load`` context contributes its identifier.

    Args:
        stmts: Statements to scan.

    Returns:
        Set of identifiers loaded anywhere inside ``stmts``.
    """
    return {
        node.id
        for stmt in stmts
        for node in ast.walk(stmt)
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
    }
250+
251+
252+
def _check_scope_safe(working_stmts: list[ast.stmt], start_idx: int, block_size: int, trip_count: int) -> bool:
    """Report whether a candidate run can be rolled without hiding definitions.

    The candidate region spans ``trip_count * block_size`` statements
    beginning at ``start_idx``. The names reported by
    ``_collect_assigned_names`` for that region are compared against the
    names reported by ``_collect_referenced_names`` for every statement
    after it.

    Args:
        working_stmts: Full statement list being considered for rolling.
        start_idx: Index of the first statement of the candidate region.
        block_size: Number of statements in one repeated block.
        trip_count: Number of consecutive repeated blocks.

    Returns:
        False if any variable defined inside the rolled region is
        referenced by statements after it, True otherwise.
    """
    stop_idx = start_idx + block_size * trip_count
    bound_in_region = _collect_assigned_names(working_stmts[start_idx:stop_idx])
    read_afterwards = _collect_referenced_names(working_stmts[stop_idx:])
    return bound_in_region.isdisjoint(read_afterwards)
264+
265+
231266
def _count_matching_blocks(
232267
working_stmts: list[ast.stmt], start: int, block_size: int, n: int, cache: dict[int, str]
233268
) -> int:
@@ -271,7 +306,7 @@ def _find_best_run(working_stmts: list[ast.stmt]) -> _LoopRun:
271306
count = _count_matching_blocks(working_stmts, p, k, n, cache)
272307
if count >= 2 and count * k > best_coverage:
273308
valid, varying = _extract_varying(working_stmts, k, count, p)
274-
if valid:
309+
if valid and _check_scope_safe(working_stmts, p, k, count):
275310
best = _LoopRun(p, k, count, varying)
276311
best_coverage = count * k
277312
p += 1

nkigym/src/nkigym/ops/activation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
7373

7474
data = ctx.subscript(data_ref)
7575
out_name = stmt.output.name
76+
out_sub = ctx.subscript(stmt.output)
7677
ctx.buffers[out_name] = "SBUF"
7778

7879
func_str = "nl.identity"
@@ -82,5 +83,5 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
8283
shape_str = repr(stmt.output.shape)
8384
return [
8485
f"{out_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.sbuf)",
85-
f"nisa.activation(dst={out_name}, op={func_str}, data={data})",
86+
f"nisa.activation(dst={out_sub}, op={func_str}, data={data})",
8687
]

nkigym/src/nkigym/ops/matmul.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,20 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
6969
stat_name = ctx.subscript(stat_ref)
7070
mov_name = ctx.subscript(mov_ref)
7171
out_name = stmt.output.name
72+
out_sub = ctx.subscript(stmt.output)
7273
ctx.buffers[out_name] = "PSUM"
7374

7475
lines: list[str] = []
7576
if isinstance(acc_ref, TensorRef):
7677
canonical = ctx.resolve(acc_ref.name)
78+
acc_sub = ctx.subscript(acc_ref)
7779
ctx.aliases[out_name] = canonical
78-
lines = [f"nisa.nc_matmul(dst={canonical}, stationary={stat_name}, moving={mov_name})"]
80+
ctx.alias_offsets[out_name] = tuple(s for s, _ in acc_ref.slices)
81+
lines = [f"nisa.nc_matmul(dst={acc_sub}, stationary={stat_name}, moving={mov_name})"]
7982
else:
8083
shape_str = repr(stmt.output.shape)
8184
lines = [
8285
f"{out_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.psum)",
83-
f"nisa.nc_matmul(dst={out_name}, stationary={stat_name}, moving={mov_name})",
86+
f"nisa.nc_matmul(dst={out_sub}, stationary={stat_name}, moving={mov_name})",
8487
]
8588
return lines

nkigym/src/nkigym/ops/nc_transpose.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
5353

5454
data = ctx.subscript(data_ref)
5555
out_name = stmt.output.name
56+
out_sub = ctx.subscript(stmt.output)
5657
ctx.buffers[out_name] = "SBUF"
5758

5859
shape_str = repr(stmt.output.shape)
5960
return [
6061
f"{out_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.sbuf)",
61-
f"nisa.nc_transpose(dst={out_name}, data={data})",
62+
f"nisa.nc_transpose(dst={out_sub}, data={data})",
6263
]

nkigym/src/nkigym/ops/tensor_scalar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
6464

6565
data = ctx.subscript(data_ref)
6666
out_name = stmt.output.name
67+
out_sub = ctx.subscript(stmt.output)
6768
ctx.buffers[out_name] = "SBUF"
6869

6970
operand = str(operand_ref)
@@ -78,5 +79,5 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
7879
shape_str = repr(stmt.output.shape)
7980
return [
8081
f"{out_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.sbuf)",
81-
f"nisa.tensor_scalar(dst={out_name}, data={data}{op_kwarg}, operand0={operand})",
82+
f"nisa.tensor_scalar(dst={out_sub}, data={data}{op_kwarg}, operand0={operand})",
8283
]

nkigym/src/nkigym/ops/tensor_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
6565
d1 = ctx.subscript(d1_ref)
6666
d2 = ctx.subscript(d2_ref)
6767
out_name = stmt.output.name
68+
out_sub = ctx.subscript(stmt.output)
6869
ctx.buffers[out_name] = "SBUF"
6970

7071
op_part = ""
@@ -75,5 +76,5 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
7576
shape_str = repr(stmt.output.shape)
7677
return [
7778
f"{out_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.sbuf)",
78-
f"nisa.tensor_tensor(dst={out_name}, data1={d1}, data2={d2}{op_part})",
79+
f"nisa.tensor_tensor(dst={out_sub}, data1={d1}, data2={d2}{op_part})",
7980
]

nkigym/src/nkigym/ops/tiling_ops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
104104
shape_str = repr(stmt.output.shape)
105105
src_subscript = ctx.subscript(src_ref)
106106

107+
out_sub = ctx.subscript(stmt.output)
107108
ctx.buffers[out_name] = "SBUF"
108109
lines = [f"{out_name} = {src_subscript}"]
109110
if src_buffer != "SBUF":
110111
lines = [
111112
f"{out_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.sbuf)",
112-
f"nisa.dma_copy(dst={out_name}, src={src_subscript})",
113+
f"nisa.dma_copy(dst={out_sub}, src={src_subscript})",
113114
]
114115
return lines
115116

@@ -184,9 +185,11 @@ def to_nki(self, stmt: "GymStatement", ctx: "_LoweringContext") -> list[str]:
184185
staging_name = f"_staging_{ctx.staging_counter}"
185186
ctx.staging_counter += 1
186187
shape_str = repr(src_ref.shape)
188+
parts = ", ".join(f"0:{s}" for s in src_ref.shape)
189+
staging_sub = f"{staging_name}[{parts}]"
187190
lines = [
188191
f"{staging_name} = nl.ndarray({shape_str}, dtype=nl.float32, buffer=nl.sbuf)",
189-
f"nisa.tensor_copy(dst={staging_name}, src={src_subscript})",
190-
f"nisa.dma_copy(dst={dst_subscript}, src={staging_name})",
192+
f"nisa.tensor_copy(dst={staging_sub}, src={src_subscript})",
193+
f"nisa.dma_copy(dst={dst_subscript}, src={staging_sub})",
191194
]
192195
return lines

nkigym/src/nkigym/search/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
provides systematic exploration and sampling of that search space.
88
"""
99

10-
from nkigym.search.benchmark import benchmark_variants
10+
from nkigym.search.compile import SearchResults
1111
from nkigym.search.search import search
1212

13-
__all__ = ["benchmark_variants", "search"]
13+
__all__ = ["SearchResults", "search"]

0 commit comments

Comments
 (0)