Skip to content

Commit 3bd6c96

Browse files
weiT1993 and claude committed
Refactor kernelgen.py to use nkipy_kernelgen trace decorator with NumPy ops
Extract original ComputeGraph-based implementation to compute_graph.py. Update README with usage example and additional references. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 0fc4948 commit 3bd6c96

3 files changed

Lines changed: 151 additions & 90 deletions

File tree

compute_graph/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1+
## Usage Example
2+
3+
See [examples/kernelgen.py](../examples/kernelgen.py) for a NumPy workload spec (RMSNorm + Matmul).
4+
15
## References
26

37
- [NKI Gym Proposal](../../documents/nkigym/NKI_Gym.docx) - Design and architecture
48
- [NKI Documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/nki/index.html)
9+
- [NKIPy KernelGen](../../NKIPyKernelGen/examples/07_full_pipeline.py)
10+
11+
## Python Venv
12+
```
13+
/home/ubuntu/venvs/kernel-env/bin/python
14+
```

examples/compute_graph.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import logging
2+
import os
3+
4+
import numpy as np
5+
6+
from autotune.core.benchmark import Benchmark
7+
from autotune.core.job import ProfileJobs
8+
from autotune.core.metrics import check_correctness
9+
from autotune.typing import INPUT_TENSORS_DTYPE, KERNEL_KWARGS_DTYPE, OUTPUT_TENSORS_DTYPE
10+
from compute_graph.codegen import NKICodegen
11+
from compute_graph.graph import ComputeGraph
12+
from compute_graph.node.compute import Activation, Matmul, TensorScalar
13+
from compute_graph.visualize import save_graph, setup_logging
14+
15+
# Root directory for generated kernels, graphs, and logs; override with the
# NKI_CACHE_ROOT environment variable (defaults to a shared FSx path).
cache_root = os.environ.get("NKI_CACHE_ROOT", "/fsx/weittang/kernelgen_cache")
# Project helper: routes this module's log output into the cache directory.
setup_logging(f"{cache_root}/debug.log")
logger = logging.getLogger(__name__)

# Epsilon added inside the RMSNorm sqrt for numerical stability.
RMSNORM_EPSILON = 1e-6
22+
def rmsnorm_gemm_correctness(
    input_tensors: INPUT_TENSORS_DTYPE, kernel_kwargs: KERNEL_KWARGS_DTYPE, kernel_outputs: OUTPUT_TENSORS_DTYPE
) -> None:
    """Verify RMSNorm + GEMM kernel output against a NumPy golden reference.

    Golden model: output = RMSNorm(lhs) @ rhs, where
    RMSNorm(x) = x / sqrt(sum(x^2) / K + epsilon).
    """
    lhs, rhs = input_tensors.values()
    contraction_dim = lhs.shape[-1]

    # Per-row inverse RMS scale: 1 / sqrt(mean(x^2) + eps).
    sum_sq = np.sum(np.square(lhs), axis=-1, keepdims=True)
    inv_rms = 1.0 / np.sqrt(sum_sq / contraction_dim + RMSNORM_EPSILON)
    golden = np.matmul(lhs * inv_rms, rhs).astype(np.float32)

    # Compare in float32 with tolerances sized for accumulated rounding error.
    check_correctness(golden, kernel_outputs[0].astype(np.float32), atol=1e-4, rtol=1e-2)
41+
42+
43+
def generate_kernel(M: int, K: int, N: int) -> tuple[str, str, dict[str, tuple[int, ...]]]:
    """Generate RMSNorm + Matmul kernel code via the ComputeGraph backend.

    Args:
        M: Number of rows in LHS matrix
        K: Number of columns in LHS / rows in RHS matrix
        N: Number of columns in RHS matrix

    Returns:
        Tuple of (kernel_path, kernel_name, input_tensor_shapes)
    """
    shapes_by_input = {"lhs_hbm": (M, K), "rhs_hbm": (K, N)}

    # Operator pipeline: square + reduce -> scale by 1/K and add eps -> rsqrt
    # -> normalize lhs -> matmul against rhs.
    ops = [
        Activation(dest="lhs_square", op="np.square", data="lhs", reduce_op="np.add", reduce_res="lhs_sum_square"),
        TensorScalar(
            dest="rmsnorm_factor",
            data="lhs_sum_square",
            op0="np.multiply",
            operand0=1 / K,
            op1="np.add",
            operand1=RMSNORM_EPSILON,
        ),
        Activation(dest="rmsnorm_factor", op="nl.rsqrt", data="rmsnorm_factor"),
        TensorScalar(dest="lhs_norm", data="lhs", op0="np.multiply", operand0="rmsnorm_factor"),
        Matmul(dest="output", lhs="lhs_norm", rhs="rhs", lhs_transposed=False),
    ]
    graph = ComputeGraph(operators=ops, input_shapes={"lhs": (M, K), "rhs": (K, N)}, output="output")

    # Persist a visualization of the graph next to the generated kernel.
    save_graph(graph, output_dir=f"{cache_root}", title="RMSNorm + Matmul")

    kernel_name = "rmsnorm_matmul_kernel"
    kernel_path = f"{cache_root}/{kernel_name}.py"
    with open(kernel_path, "w") as f:
        f.write(NKICodegen(graph, kernel_name).code)

    return kernel_path, kernel_name, shapes_by_input
82+
83+
84+
def run_benchmark(
    kernel_path: str, kernel_name: str, input_tensor_shapes: dict[str, tuple[int, ...]], mac_count: int
) -> None:
    """Profile the generated kernel on trn2 and verify its numerics.

    Args:
        kernel_path: Path to the generated kernel file
        kernel_name: Name of the kernel function
        input_tensor_shapes: Dict mapping input names to shapes
        mac_count: Number of multiply-accumulate operations for MFU calculation
    """
    profile_jobs = ProfileJobs(cache_root_dir=cache_root, target_instance_family="trn2")
    profile_jobs.add_job(
        kernel=(kernel_path, kernel_name),
        input_tensor_shapes=input_tensor_shapes,
        data_type=np.float32,
        kernel_kwargs={},
        compiler_flags="--auto-cast=none --internal-tensorizer-opt-level=nki",
        # Correctness check runs as a postprocessing step on the profiled outputs.
        postprocessing=rmsnorm_gemm_correctness,
        mac_count=mac_count,
    )
    Benchmark(jobs=profile_jobs)()
107+
108+
109+
if __name__ == "__main__":
    M, K, N = 256, 128, 128
    path, name, shapes = generate_kernel(M, K, N)
    # Hardware profiling is opt-in; enable when a trn2 instance is available.
    # run_benchmark(path, name, shapes, mac_count=M * N * K)

examples/kernelgen.py

Lines changed: 29 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,51 @@
1+
"""NumPy workload specification for RMSNorm + Matmul fusion.
2+
3+
This demonstrates how to define kernels using standard NumPy operations
4+
and lower them to MLIR using NKIPy KernelGen.
5+
6+
See compute_graph/README.md for installation instructions.
7+
"""
8+
19
import logging
210
import os
311

412
import numpy as np
13+
from nkipy_kernelgen import trace
514

6-
from autotune.core.benchmark import Benchmark
7-
from autotune.core.job import ProfileJobs
8-
from autotune.core.metrics import check_correctness
9-
from autotune.typing import INPUT_TENSORS_DTYPE, KERNEL_KWARGS_DTYPE, OUTPUT_TENSORS_DTYPE
10-
from compute_graph.codegen import NKICodegen
11-
from compute_graph.graph import ComputeGraph
12-
from compute_graph.node.compute import Activation, Matmul, TensorScalar
13-
from compute_graph.visualize import save_graph, setup_logging
15+
from compute_graph.visualize import setup_logging
1416

15-
# Root directory for generated artifacts and logs. Respect the NKI_CACHE_ROOT
# environment variable (the refactor had hardcoded the path, silently dropping
# the override that examples/compute_graph.py still honors).
cache_root = os.environ.get("NKI_CACHE_ROOT", "/fsx/weittang/kernelgen_cache")
os.makedirs(cache_root, exist_ok=True)
# Project helper: routes this module's log output into the cache directory.
setup_logging(f"{cache_root}/debug.log")
logger = logging.getLogger(__name__)

# Epsilon added inside the RMSNorm sqrt for numerical stability.
RMSNORM_EPSILON = 1e-6

# Matrix dimensions used by the traced workload's input specs.
M, K, N = 256, 128, 128
2126

22-
def rmsnorm_gemm_correctness(
23-
input_tensors: INPUT_TENSORS_DTYPE, kernel_kwargs: KERNEL_KWARGS_DTYPE, kernel_outputs: OUTPUT_TENSORS_DTYPE
24-
) -> None:
25-
"""Postprocessing function to verify RMSNorm + GEMM correctness.
26-
27-
Computes golden reference: output = RMSNorm(lhs) @ rhs
28-
where RMSNorm(x) = x / sqrt(sum(x^2) / K + epsilon)
29-
"""
30-
lhs, rhs = input_tensors.values()
31-
K = lhs.shape[-1]
32-
33-
lhs_square = np.square(lhs)
34-
lhs_sum_square = np.sum(lhs_square, axis=-1, keepdims=True)
35-
rmsnorm_factor = 1.0 / np.sqrt(lhs_sum_square / K + RMSNORM_EPSILON)
36-
lhs_norm = lhs * rmsnorm_factor
37-
golden = np.matmul(lhs_norm, rhs).astype(np.float32)
38-
39-
nki_out = kernel_outputs[0].astype(np.float32)
40-
check_correctness(golden, nki_out, atol=1e-4, rtol=1e-2)
4127

28+
@trace(input_specs=[((M, K), "f32"), ((K, N), "f32")])
def rmsnorm_matmul(lhs, rhs):
    """Fused RMSNorm + Matmul: output = RMSNorm(lhs) @ rhs

    RMSNorm(x) = x / sqrt(mean(x^2) + epsilon)

    Args:
        lhs: Input tensor of shape (M, K)
        rhs: Weight tensor of shape (K, N)

    Returns:
        Output tensor of shape (M, N)
    """
    feature_dim = lhs.shape[-1]
    # Per-row inverse RMS: 1 / sqrt(mean(lhs^2) + eps).
    mean_sq = np.sum(np.square(lhs), axis=-1, keepdims=True) / feature_dim
    scale = 1.0 / np.sqrt(mean_sq + RMSNORM_EPSILON)
    return np.matmul(lhs * scale, rhs)
10747

10848

10949
if __name__ == "__main__":
    # Lower the traced NumPy workload to MLIR and log the resulting module.
    logger.info(rmsnorm_matmul.to_mlir())

0 commit comments

Comments
 (0)