Skip to content

Commit dbece48

Browse files
committed
Add multi-pass reduction support with rmsnorm+matmul ops, renderer, and simulator extensions
1 parent c4673ff commit dbece48

30 files changed

+1994
-183
lines changed

examples/gym.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,20 @@ def matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
3333
def parse_args() -> argparse.Namespace:
    """Build and run the command-line argument parser."""
    arg_parser = argparse.ArgumentParser(description="NKI Gym search example")
    arg_parser.add_argument(
        "--cache-dir",
        type=Path,
        required=True,
        help="Directory for storing output artifacts",
    )
    return arg_parser.parse_args()
4338

4439

4540
def main() -> None:
46-
"""Run schedule search on a 2048x2048 matmul workload."""
41+
"""Run schedule search on a 1024x1024 matmul workload."""
4742
logging.basicConfig(level=logging.INFO, format="%(message)s")
4843

4944
args = parse_args()
5045
cache_dir = args.cache_dir
5146

5247
rng = np.random.default_rng(42)
53-
a = rng.standard_normal((2048, 2048)).astype(np.float16)
54-
b = rng.standard_normal((2048, 2048)).astype(np.float16)
48+
a = rng.standard_normal((1024, 1024)).astype(np.float16)
49+
b = rng.standard_normal((1024, 1024)).astype(np.float16)
5550

5651
search(func=matmul, num_targets=99999, seed=42, save_cache=cache_dir, kernel_kwargs={"a": a, "b": b})
5752

examples/rmsnorm_matmul.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""NKI Gym search: rmsnorm + matmul two-pass reduction kernel.
2+
3+
Demonstrates multi-pass schedule search: RMSNorm (activation+reduce
4+
over K, then normalize) followed by matrix multiply, producing two
5+
sequential reduction passes over the same dimension.
6+
"""
7+
8+
import argparse
9+
import logging
10+
from pathlib import Path
11+
12+
import numpy as np
13+
14+
import nkigym
15+
from nkigym.search import search
16+
17+
18+
def rmsnorm_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """RMSNorm(a) @ b: normalize rows of a then multiply by b.

    Args:
        a: Input tensor of shape [M, K].
        b: Weight tensor of shape [K, N].

    Returns:
        Output tensor of shape [M, N].
    """
    # Per-row sum of squares over the K dimension.
    row_sq_sum = nkigym.activation(a, op="square", reduce_op=np.add)
    # Mean of squares plus epsilon; 1/K is a literal (K=1024) because the
    # schedule tracer only folds constant kwarg expressions.
    mean_sq = nkigym.tensor_scalar(row_sq_sum, op0=np.multiply, operand0=1 / 1024, op1=np.add, operand1=1e-6)
    inv_rms = nkigym.activation(mean_sq, op="rsqrt")
    normalized = nkigym.tensor_scalar(a, inv_rms, op0=np.multiply)
    # nc_matmul transposes its stationary (first) operand internally,
    # so hand it the transpose of the normalized input.
    stationary = nkigym.transpose(normalized)
    out = nkigym.nc_matmul(stationary, b)
    return out
35+
36+
37+
def parse_args() -> argparse.Namespace:
    """Build and run the command-line argument parser."""
    arg_parser = argparse.ArgumentParser(description="NKI Gym rmsnorm+matmul search")
    arg_parser.add_argument(
        "--cache-dir",
        type=Path,
        required=True,
        help="Directory for storing output artifacts",
    )
    return arg_parser.parse_args()
42+
43+
44+
def main() -> None:
    """Run schedule search on a 1024x1024 rmsnorm+matmul workload."""
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    cli_args = parse_args()

    # Fixed seed keeps the generated workload reproducible across runs.
    rng = np.random.default_rng(42)
    lhs = rng.standard_normal((1024, 1024)).astype(np.float16)
    rhs = rng.standard_normal((1024, 1024)).astype(np.float16)

    search(func=rmsnorm_matmul, num_targets=99999, seed=42, save_cache=cli_args.cache_dir, kernel_kwargs={"a": lhs, "b": rhs})


if __name__ == "__main__":
    main()

nkigym/src/nkigym/__init__.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import numpy as np
1616

1717
from nkigym.ops.activation import NKIActivation
18+
from nkigym.ops.activation_1d import NKIActivation1D
19+
from nkigym.ops.activation_reduce import NKIActivationReduce
1820
from nkigym.ops.add import NKIAdd
1921
from nkigym.ops.base import NKIOp
2022
from nkigym.ops.dma_copy import NKIDmaCopy
@@ -23,6 +25,8 @@
2325
from nkigym.ops.tensor_copy import NKITensorCopy
2426
from nkigym.ops.tensor_reduce import NKITensorReduce
2527
from nkigym.ops.tensor_scalar import NKITensorScalar
28+
from nkigym.ops.tensor_scalar_const import NKITensorScalarConst
29+
from nkigym.ops.transpose import NKITranspose
2630

2731

2832
def nc_matmul(*args: Any, **kwargs: Any) -> Any:
@@ -35,15 +39,27 @@ def nc_matmul(*args: Any, **kwargs: Any) -> Any:
3539
return np.matmul(stationary.T, moving)
3640

3741

42+
def _rsqrt(x: Any) -> Any:
43+
"""Reciprocal square root: 1 / sqrt(x)."""
44+
return 1.0 / np.sqrt(x)
45+
46+
47+
_STR_OPS: dict[str, Any] = {"square": np.square, "rsqrt": _rsqrt}
48+
49+
3850
def activation(*args: Any, **kwargs: Any) -> Any:
    """Apply element-wise activation, optionally with reduction.

    Returns:
        Activated numpy array, or reduced 1D array if reduce_op given.
    """
    data = args[0]
    op_fn = kwargs.get("op")
    # String op names are looked up in the module-level registry.
    if isinstance(op_fn, str):
        op_fn = _STR_OPS[op_fn]
    activated = data if op_fn is None else op_fn(data)
    reduce_op = kwargs.get("reduce_op")
    if reduce_op is None:
        return activated
    # reduce_op is a numpy ufunc; collapse the trailing axis.
    return reduce_op.reduce(activated, axis=-1)
4864

4965

@@ -76,15 +92,34 @@ def tensor_reduce(*args: Any, **kwargs: Any) -> Any:
7692
return op_fn.reduce(data, axis=-1)
7793

7894

95+
def _expand_operand(data: Any, operand0: Any) -> Any:
96+
"""Expand operand0 for broadcasting against data if needed."""
97+
result = operand0
98+
if isinstance(operand0, np.ndarray) and data.ndim > operand0.ndim:
99+
pad = data.ndim - operand0.ndim
100+
result = operand0.reshape(operand0.shape + (1,) * pad)
101+
return result
102+
103+
79104
def tensor_scalar(*args: Any, **kwargs: Any) -> Any:
    """Element-wise op between a tensor and a scalar/column vector.

    Supports two modes:
    - 2D broadcast: ``tensor_scalar(data, tensor_operand, op0=...)``
    - 1D compound: ``tensor_scalar(data, op0=..., operand0=literal, ...)``

    Returns:
        Result numpy array.
    """
    data = args[0]
    # Positional operand wins; otherwise the keyword form is required.
    operand0 = args[1] if len(args) > 1 else kwargs["operand0"]
    # Pad trailing singleton axes so a lower-rank ndarray operand broadcasts.
    if isinstance(operand0, np.ndarray) and data.ndim > operand0.ndim:
        operand0 = operand0.reshape(operand0.shape + (1,) * (data.ndim - operand0.ndim))
    result = kwargs.get("op0", np.add)(data, operand0)
    op1 = kwargs.get("op1")
    # Optional second op chains a scalar operand onto the first result.
    return result if op1 is None else op1(result, kwargs["operand1"])
88123

89124

90125
def transpose(x: Any) -> Any:
@@ -110,12 +145,16 @@ def ndarray(shape: tuple[int, ...], **kwargs: Any) -> np.ndarray:
110145
"NKIOp",
111146
"NKIMatmul",
112147
"NKIActivation",
148+
"NKIActivation1D",
149+
"NKIActivationReduce",
113150
"NKIAdd",
114151
"NKIDmaCopy",
115152
"NKIMultiply",
116153
"NKITensorCopy",
117154
"NKITensorReduce",
118155
"NKITensorScalar",
156+
"NKITensorScalarConst",
157+
"NKITranspose",
119158
"nc_matmul",
120159
"activation",
121160
"add",

nkigym/src/nkigym/codegen/parse.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,22 @@
55
"""
66

77
import ast
8+
import operator
89

910
import numpy as np
1011

1112
from nkigym.codegen.analysis import _OpCall
13+
from nkigym.ops.activation import NKIActivation
14+
from nkigym.ops.activation_1d import NKIActivation1D
1215
from nkigym.ops.base import NKIOp
1316

17+
_BINOP_FNS: dict[type, object] = {
18+
ast.Add: operator.add,
19+
ast.Sub: operator.sub,
20+
ast.Mult: operator.mul,
21+
ast.Div: operator.truediv,
22+
}
23+
1424

1525
def find_func_def(source: str) -> ast.FunctionDef:
1626
"""Find the first FunctionDef in parsed source.
@@ -41,10 +51,28 @@ def _is_nkigym_call(call: ast.Call) -> bool:
4151
return isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name) and func.value.id == "nkigym"
4252

4353

54+
def _eval_binop(node: ast.BinOp) -> object:
    """Evaluate a binary operation on constant operands.

    Args:
        node: AST BinOp node.

    Returns:
        Result of the binary operation.

    Raises:
        ValueError: If the operator type has no entry in _BINOP_FNS.
    """
    # Resolve operands first so an unsupported sub-expression raises
    # before the operator itself is validated.
    lhs = _eval_expr(node.left)
    rhs = _eval_expr(node.right)
    fn = _BINOP_FNS.get(type(node.op))
    if fn is None:
        raise ValueError(f"Unsupported binary op: {ast.dump(node)}")
    return fn(lhs, rhs)
69+
70+
4471
def _eval_expr(node: ast.expr) -> object:
4572
"""Evaluate an AST expression to a Python object.
4673
47-
Resolves ``np.X`` attribute accesses and literal constants.
74+
Resolves ``np.X`` attribute accesses, literal constants,
75+
binary operations, and unary negation.
4876
4977
Args:
5078
node: AST expression node.
@@ -53,11 +81,14 @@ def _eval_expr(node: ast.expr) -> object:
5381
The resolved Python object.
5482
"""
5583
result = None
56-
if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
57-
if node.value.id == "np":
58-
result = getattr(np, node.attr)
84+
if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "np":
85+
result = getattr(np, node.attr)
5986
elif isinstance(node, ast.Constant):
6087
result = node.value
88+
elif isinstance(node, ast.BinOp):
89+
result = _eval_binop(node)
90+
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
91+
result = -_eval_expr(node.operand)
6192
if result is None:
6293
raise ValueError(f"Unsupported kwarg expression: {ast.dump(node)}")
6394
return result
@@ -77,6 +108,45 @@ def _arg_name(node: ast.expr) -> str:
77108
return node.id
78109

79110

111+
def _maybe_reclassify_activation(op: _OpCall, output_axes_map: dict[str, tuple[str, ...]]) -> _OpCall:
    """Reclassify NKIActivation to NKIActivation1D if input is 1D.

    Args:
        op: Parsed op call to check.
        output_axes_map: Maps variable name to output axes of its producer op.

    Returns:
        Original or reclassified op call.
    """
    if op.stmt_type is not NKIActivation:
        return op
    source_var = op.input_vars[0]
    if source_var not in output_axes_map:
        return op
    if len(output_axes_map[source_var]) != 1:
        return op
    # Producer emits a single axis, so this activation operates on 1D data.
    return op._replace(stmt_type=NKIActivation1D)
127+
128+
129+
def _resolve_op_variants(op_calls: list[_OpCall]) -> list[_OpCall]:
    """Post-parse pass to reclassify ops based on producer output shapes.

    Traces the SSA chain to determine operand dimensionality and
    reclassifies NKIActivation to NKIActivation1D when input is 1D.

    Args:
        op_calls: Parsed op calls from the function body.

    Returns:
        Op calls with reclassified types where appropriate.
    """
    axes_by_var: dict[str, tuple[str, ...]] = {}
    resolved_calls: list[_OpCall] = []
    for call in op_calls:
        call = _maybe_reclassify_activation(call, axes_by_var)
        # Record this op's output axes so downstream consumers can be resolved.
        axes_by_var[call.output_var] = getattr(call.stmt_type, "OUTPUT_AXES", ())
        resolved_calls.append(call)
    return resolved_calls
148+
149+
80150
def parse_body(func_def: ast.FunctionDef) -> list[_OpCall]:
81151
"""Parse function body into a list of _OpCall.
82152
@@ -91,7 +161,7 @@ def parse_body(func_def: ast.FunctionDef) -> list[_OpCall]:
91161
for node in func_def.body:
92162
if not _try_parse_node(node, op_calls, counter):
93163
raise ValueError(f"Unsupported statement: {ast.dump(node)}")
94-
return op_calls
164+
return _resolve_op_variants(op_calls)
95165

96166

97167
def _try_parse_node(node: ast.stmt, op_calls: list[_OpCall], counter: list[int]) -> bool:
@@ -157,6 +227,28 @@ def _try_parse_return(node: ast.Return, op_calls: list[_OpCall], counter: list[i
157227
return result
158228

159229

230+
def _disambiguate_op(op_name: str, call: ast.Call) -> str:
231+
"""Disambiguate user function name to internal op registry key.
232+
233+
- ``activation`` with ``reduce_op`` kwarg → ``activation_reduce``
234+
- ``tensor_scalar`` with < 2 positional args → ``tensor_scalar_const``
235+
236+
Args:
237+
op_name: User-facing function name from AST.
238+
call: AST Call node with keyword arguments.
239+
240+
Returns:
241+
Internal op registry key.
242+
"""
243+
kwarg_names = {kw.arg for kw in call.keywords}
244+
result = op_name
245+
if op_name == "activation" and "reduce_op" in kwarg_names:
246+
result = "activation_reduce"
247+
elif op_name == "tensor_scalar" and len(call.args) < 2:
248+
result = "tensor_scalar_const"
249+
return result
250+
251+
160252
def _flatten_call(call: ast.Call, output: str, op_calls: list[_OpCall], counter: list[int]) -> None:
161253
"""Flatten a nkigym call (possibly nested) into _OpCall entries.
162254
@@ -167,7 +259,7 @@ def _flatten_call(call: ast.Call, output: str, op_calls: list[_OpCall], counter:
167259
counter: Mutable counter for intermediate variable names.
168260
"""
169261
assert isinstance(call.func, ast.Attribute)
170-
op_name = call.func.attr
262+
op_name = _disambiguate_op(call.func.attr, call)
171263
registry = NKIOp.all_ops()
172264
if op_name not in registry:
173265
raise ValueError(f"Unknown op: {op_name!r}")

0 commit comments

Comments
 (0)