| 1 | +"""NumPy workload specification for RMSNorm + Matmul fusion. |
| 2 | +
| 3 | +This demonstrates how to define kernels using standard NumPy operations |
| 4 | +and lower them to MLIR using NKIPy KernelGen. |
| 5 | +
|
| 6 | +See compute_graph/README.md for installation instructions. |
| 7 | +""" |
| 8 | + |
1 | 9 | import logging |
2 | 10 | import os |
3 | 11 |
4 | 12 | import numpy as np |
| 13 | +from nkipy_kernelgen import trace |
5 | 14 |
6 | | -from autotune.core.benchmark import Benchmark |
7 | | -from autotune.core.job import ProfileJobs |
8 | | -from autotune.core.metrics import check_correctness |
9 | | -from autotune.typing import INPUT_TENSORS_DTYPE, KERNEL_KWARGS_DTYPE, OUTPUT_TENSORS_DTYPE |
10 | | -from compute_graph.codegen import NKICodegen |
11 | | -from compute_graph.graph import ComputeGraph |
12 | | -from compute_graph.node.compute import Activation, Matmul, TensorScalar |
13 | | -from compute_graph.visualize import save_graph, setup_logging |
| 15 | +from compute_graph.visualize import setup_logging |
14 | 16 |
15 | | -cache_root = os.environ.get("NKI_CACHE_ROOT", "/fsx/weittang/kernelgen_cache") |
| 17 | +cache_root = "/fsx/weittang/kernelgen_cache" |
| 18 | +os.makedirs(cache_root, exist_ok=True) |
16 | 19 | setup_logging(f"{cache_root}/debug.log") |
17 | 20 | logger = logging.getLogger(__name__) |
18 | 21 |
19 | 22 | RMSNORM_EPSILON = 1e-6 |
20 | 23 |
| 24 | +# Matrix dimensions: lhs is (M, K), rhs is (K, N), output is (M, N)
| 25 | +M, K, N = 256, 128, 128 |
21 | 26 |
22 | | -def rmsnorm_gemm_correctness( |
23 | | - input_tensors: INPUT_TENSORS_DTYPE, kernel_kwargs: KERNEL_KWARGS_DTYPE, kernel_outputs: OUTPUT_TENSORS_DTYPE |
24 | | -) -> None: |
25 | | - """Postprocessing function to verify RMSNorm + GEMM correctness. |
26 | | -
27 | | - Computes golden reference: output = RMSNorm(lhs) @ rhs |
28 | | - where RMSNorm(x) = x / sqrt(sum(x^2) / K + epsilon) |
29 | | - """ |
30 | | - lhs, rhs = input_tensors.values() |
31 | | - K = lhs.shape[-1] |
32 | | - |
33 | | - lhs_square = np.square(lhs) |
34 | | - lhs_sum_square = np.sum(lhs_square, axis=-1, keepdims=True) |
35 | | - rmsnorm_factor = 1.0 / np.sqrt(lhs_sum_square / K + RMSNORM_EPSILON) |
36 | | - lhs_norm = lhs * rmsnorm_factor |
37 | | - golden = np.matmul(lhs_norm, rhs).astype(np.float32) |
38 | | - |
39 | | - nki_out = kernel_outputs[0].astype(np.float32) |
40 | | - check_correctness(golden, nki_out, atol=1e-4, rtol=1e-2) |
41 | 27 |
| 28 | +@trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) |
| 29 | +def rmsnorm_matmul(lhs, rhs): |
| 30 | + """Fused RMSNorm + Matmul: output = RMSNorm(lhs) @ rhs |
42 | 31 |
43 | | -def generate_kernel(M: int, K: int, N: int) -> tuple[str, str, dict[str, tuple[int, ...]]]: |
44 | | - """Generate RMSNorm + Matmul kernel code. |
| 32 | + RMSNorm(x) = x / sqrt(mean(x^2) + epsilon) |
45 | 33 |
46 | 34 | Args: |
47 | | - M: Number of rows in LHS matrix |
48 | | - K: Number of columns in LHS / rows in RHS matrix |
49 | | - N: Number of columns in RHS matrix |
| 35 | + lhs: Input tensor of shape (M, K) |
| 36 | + rhs: Weight tensor of shape (K, N) |
50 | 37 |
51 | 38 | Returns: |
52 | | - Tuple of (kernel_path, kernel_name, input_tensor_shapes) |
| 39 | + Output tensor of shape (M, N) |
53 | 40 | """ |
54 | | - input_tensor_shapes = {"lhs_hbm": (M, K), "rhs_hbm": (K, N)} |
55 | | - |
56 | | - rmsnorm_matmul_graph = ComputeGraph( |
57 | | - operators=[ |
58 | | - Activation(dest="lhs_square", op="np.square", data="lhs", reduce_op="np.add", reduce_res="lhs_sum_square"), |
59 | | - TensorScalar( |
60 | | - dest="rmsnorm_factor", |
61 | | - data="lhs_sum_square", |
62 | | - op0="np.multiply", |
63 | | - operand0=1 / K, |
64 | | - op1="np.add", |
65 | | - operand1=RMSNORM_EPSILON, |
66 | | - ), |
67 | | - Activation(dest="rmsnorm_factor", op="nl.rsqrt", data="rmsnorm_factor"), |
68 | | - TensorScalar(dest="lhs_norm", data="lhs", op0="np.multiply", operand0="rmsnorm_factor"), |
69 | | - Matmul(dest="output", lhs="lhs_norm", rhs="rhs", lhs_transposed=False), |
70 | | - ], |
71 | | - input_shapes={"lhs": (M, K), "rhs": (K, N)}, |
72 | | - output="output", |
73 | | - ) |
74 | | - save_graph(rmsnorm_matmul_graph, output_dir=f"{cache_root}", title="RMSNorm + Matmul") |
75 | | - kernel_name = "rmsnorm_matmul_kernel" |
76 | | - kernel_path = f"{cache_root}/{kernel_name}.py" |
77 | | - kernel_code = NKICodegen(rmsnorm_matmul_graph, kernel_name).code |
78 | | - with open(kernel_path, "w") as f: |
79 | | - f.write(kernel_code) |
80 | | - |
81 | | - return kernel_path, kernel_name, input_tensor_shapes |
82 | | - |
83 | | - |
84 | | -def run_benchmark( |
85 | | - kernel_path: str, kernel_name: str, input_tensor_shapes: dict[str, tuple[int, ...]], mac_count: int |
86 | | -) -> None: |
87 | | - """Benchmark the generated kernel. |
88 | | -
89 | | - Args: |
90 | | - kernel_path: Path to the generated kernel file |
91 | | - kernel_name: Name of the kernel function |
92 | | - input_tensor_shapes: Dict mapping input names to shapes |
93 | | - mac_count: Number of multiply-accumulate operations for MFU calculation |
94 | | - """ |
95 | | - jobs = ProfileJobs(cache_root_dir=cache_root, target_instance_family="trn2") |
96 | | - jobs.add_job( |
97 | | - kernel=(kernel_path, kernel_name), |
98 | | - input_tensor_shapes=input_tensor_shapes, |
99 | | - data_type=np.float32, |
100 | | - kernel_kwargs={}, |
101 | | - compiler_flags="--auto-cast=none --internal-tensorizer-opt-level=nki", |
102 | | - postprocessing=rmsnorm_gemm_correctness, |
103 | | - mac_count=mac_count, |
104 | | - ) |
105 | | - benchmark = Benchmark(jobs=jobs) |
106 | | - benchmark() |
| 41 | + K_dim = lhs.shape[-1] |
| 42 | + lhs_square = np.square(lhs) |
| 43 | + lhs_sum_square = np.sum(lhs_square, axis=-1, keepdims=True) |
| 44 | + rmsnorm_factor = 1.0 / np.sqrt(lhs_sum_square / K_dim + RMSNORM_EPSILON) |
| 45 | + lhs_norm = lhs * rmsnorm_factor |
| 46 | + return np.matmul(lhs_norm, rhs) |
107 | 47 |
108 | 48 |
109 | 49 | if __name__ == "__main__": |
110 | | - M, K, N = 256, 128, 128 |
111 | | - kernel_path, kernel_name, input_tensor_shapes = generate_kernel(M, K, N) |
112 | | - # run_benchmark(kernel_path, kernel_name, input_tensor_shapes, mac_count=M * N * K) |
| 50 | + mlir_module = rmsnorm_matmul.to_mlir() |
| 51 | + logger.info(mlir_module) |
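
For a quick numerical sanity check, the same math can be evaluated eagerly with plain NumPy and compared against the output of the lowered kernel. A minimal sketch, assuming random test inputs; the tolerances follow the atol=1e-4 / rtol=1e-2 used by the earlier check_correctness call:

import numpy as np

RMSNORM_EPSILON = 1e-6
M, K, N = 256, 128, 128

rng = np.random.default_rng(seed=0)
lhs = rng.standard_normal((M, K)).astype(np.float32)
rhs = rng.standard_normal((K, N)).astype(np.float32)

# Golden reference: RMSNorm(lhs) @ rhs, with RMSNorm(x) = x / sqrt(mean(x^2) + eps)
factor = 1.0 / np.sqrt(np.mean(np.square(lhs), axis=-1, keepdims=True) + RMSNORM_EPSILON)
golden = np.matmul(lhs * factor, rhs)

# kernel_out would come from executing the lowered MLIR module on the target device;
# the comparison itself is just:
# np.testing.assert_allclose(kernel_out, golden, atol=1e-4, rtol=1e-2)

Because the traced function body is ordinary NumPy, the reference and the specification share the same arithmetic, so any mismatch points at the lowering rather than the math.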