-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathops_manifest.yaml
More file actions
77 lines (72 loc) · 3.07 KB
/
ops_manifest.yaml
File metadata and controls
77 lines (72 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# ops_manifest.yaml -- Spec-driven op registry for TileOPs
#
# Schema
# ------
# ops:
#   <op_name>:                                # Unique op identifier (e.g. rmsnorm_fwd)
#     signature:
#       inputs:  [{name, dtype, shape}]       # Input tensors
#       outputs: [{name, dtype, shape}]       # Output tensors
#       params:  [{name, type, default}]      # Scalar / config parameters
#     shape_rules: [str]                      # Optional Python expressions relating dimensions
#     workloads: [{x_shape, dtypes, label?}]  # Representative shapes for benchmarking
#     roofline:                               # Analytical cost model
#       flops: <expr>                         # Inline Python expression -OR-
#       bytes: <expr>                         # both flops and bytes expressions
#       func: <module:function>               # Alternative: reference to Python function
#     source:
#       kernel: <path>                        # Path to kernel implementation
#       op: <path>                            # Path to op wrapper
#       test: <path>                          # Path to test file
#       bench: <path>                         # Path to benchmark file
#     family: <str>                           # Op family for grouping (e.g. norm, attention)
#
# Notes:
# - Backward ops are registered as independent entries (e.g. rmsnorm_bwd).
# - shape_rules use Python expression syntax and are optional.
# - roofline supports two modes: inline expressions (flops/bytes) or func.
ops:
  rmsnorm_fwd:
    family: norm
    signature:
      inputs:
        - name: x
          dtype: "{float16, bfloat16}"
          shape: "[M, N]"
        - name: weight
          dtype: "{float16, bfloat16}"
          shape: "[N]"
      outputs:
        - name: y
          dtype: "{float16, bfloat16}"
          shape: "[M, N]"
      params:
        - name: dim
          type: int
          default: -1
        - name: eps
          type: float
          default: 1.0e-6
    shape_rules:
      - "weight.shape == (x.shape[-1],)"
      - "y.shape == x.shape"
    # Representative shapes drawn from the Llama-3.1 family; prefill uses a
    # 2048-token batch, decode a single token.
    workloads:
      # Llama-3.1-8B (hidden_dim=4096)
      - x_shape: [2048, 4096]
        dtypes: [float16, bfloat16]
        label: "llama-3.1-8b-prefill"
      - x_shape: [1, 4096]
        dtypes: [bfloat16]
        label: "llama-3.1-8b-decode"
      # Llama-3.1-70B (hidden_dim=8192)
      - x_shape: [2048, 8192]
        dtypes: [float16, bfloat16]
        label: "llama-3.1-70b-prefill"
      - x_shape: [1, 8192]
        dtypes: [bfloat16]
        label: "llama-3.1-70b-decode"
      # Llama-3.1-405B (hidden_dim=16384)
      - x_shape: [2048, 16384]
        dtypes: [float16, bfloat16]
        label: "llama-3.1-405b-prefill"
      - x_shape: [1, 16384]
        dtypes: [bfloat16]
        label: "llama-3.1-405b-decode"
    roofline:
      # Per row: N squares + (N-1) adds + div + add + rsqrt + N muls (normalize)
      # + N muls (weight) ≈ 4N flops.
      flops: "4 * M * N"
      # Traffic: read x (M*N) + read weight (N) + write y (M*N); factor 2 is the
      # element size in bytes for fp16/bf16.
      bytes: "2 * (M * N + N + M * N)"
    source:
      kernel: tileops/kernels/norm/rms_norm.py
      op: tileops/ops/norm/rms_norm.py
      test: tests/ops/test_rms_norm.py
      bench: benchmarks/ops/bench_rms_norm.py