26 changes: 11 additions & 15 deletions docs/architecture.md
@@ -7,26 +7,22 @@ TileOPs is a spec-driven GPU operator platform built on TileLang. Every operator
The platform consists of 8 modules:

```mermaid
graph TD
graph LR
M1["M1: Spec<br/>ops_manifest.yaml"]
M2["M2: Op + Kernel<br/>user-facing API + TileLang GPU kernels"]
M3["M3: Correctness<br/>tests/"]
M4["M4: Benchmark<br/>raw time"]
M5["M5: Roofline<br/>efficiency"]
M6["M6: HW Profile<br/>GPU parameters"]
M7["M7: CI Gate<br/>correctness + perf regression"]
M8["M8: Docs<br/>auto-generated"]
M2["M2: Op + Kernel"]
M3["M3: Correctness"]
M4["M4: Benchmark"]
M5["M5: Roofline"]
M6["M6: HW Profile"]
M7["M7: CI Gate"]
M8["M8: Docs"]

M1 -- defines --> M2
M2 --> M3
M2 --> M4
M2 -- docstring --> M8
M2 --> M3 & M4
M3 & M4 --> M7
M4 -- raw time --> M5
M6 --> M5
M5 --> M8
M1 --> M8
M3 --> M7
M4 --> M7
M1 & M2 & M5 & M7 --> M8
```

| Module | Responsibility | Key Artifact |
93 changes: 51 additions & 42 deletions docs/manifest.md
@@ -14,81 +14,90 @@ ops_manifest.yaml (spec)

## Fields

Each manifest entry has four sections: signature, workloads, roofline, and source.
Each manifest entry lives under the top-level `ops:` key and has a `family` field plus four sections: signature, workloads, roofline, and source.

`family` is an explicit, machine-readable field for doc generation, API grouping, and tooling — not derived from source paths, which can change.
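To make that tooling angle concrete, grouping manifest entries by `family` for doc generation might look like the sketch below. The `ops` dict shape follows the manifest; the `group_by_family` helper itself is hypothetical, not part of TileOPs:

```python
from collections import defaultdict


def group_by_family(ops: dict) -> dict:
    """Group op names by their `family` field, e.g. for doc generation."""
    groups = defaultdict(list)
    for name, entry in ops.items():
        groups[entry["family"]].append(name)
    return dict(groups)


ops = {
    "rmsnorm_fwd": {"family": "norm"},
    "rmsnorm_bwd": {"family": "norm"},
    "gqa_prefill_fwd": {"family": "attention"},
}
print(group_by_family(ops))
# {'norm': ['rmsnorm_fwd', 'rmsnorm_bwd'], 'attention': ['gqa_prefill_fwd']}
```

Because `family` is explicit rather than parsed from paths, this grouping survives file moves.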

### Signature

Declares inputs, outputs, parameters, and optional shape rules:

```yaml
rmsnorm_fwd:
  family: norm

  signature:
    inputs:
      x: {dtype: "float16 | bfloat16"}
      weight: {dtype: "same_as(x)"}
    outputs:
      y: {dtype: "same_as(x)"}
    params:
      dim: {type: int, default: -1}
      eps: {type: float, default: 1e-6}
    shape_rules:
      - "weight.shape == (x.shape[dim],)"
      - "y.shape == x.shape"
ops:
  rmsnorm_fwd:
    family: norm

    signature:
      inputs:
        x: {dtype: "float16 | bfloat16"}
        weight: {dtype: "same_as(x)"}
      outputs:
        y: {dtype: "same_as(x)"}
      params:
        dim: {type: int, default: -1}
        eps: {type: float, default: 1e-6}
      shape_rules:
        - "weight.shape == (x.shape[dim],)"
        - "y.shape == x.shape"
```

Conventions:

- **Tensor rank is unconstrained** — DNN tensors can be 1D, 2D, 3D, etc.
- **Signature uses dict, not list.** Name is identity — making it a key (`x: {dtype: ...}`) is more concise than list items (`- name: x`).
- **No per-tensor shape.** Tensor rank is intentionally unconstrained (DNN ops accept 1D, 2D, 3D, etc.). Shape relationships are expressed through `shape_rules`, not by fixing shape on each tensor.
- **`dtype`** uses `|` for alternatives, `same_as(x)` for dependent types. Concrete entries may list dtypes explicitly.
- **`shape_rules`** use Python expression syntax; they are optional and best-effort.
- **`dtype`** uses `|` for alternatives, `same_as(x)` for dependent types.
- **Params declare the full interface.** If an op mathematically supports a parameter (e.g., `dim` for norm), it belongs in the manifest even if the current kernel only supports the default value.
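As an illustration of how `shape_rules` could be checked, here is a minimal, hypothetical validator: it evaluates each rule as a Python expression against objects exposing `.shape`, plus the declared params. `check_shape_rules` and the `SimpleNamespace` stand-ins are assumptions for this sketch, not TileOPs API:

```python
from types import SimpleNamespace


def check_shape_rules(rules, shapes, params):
    """Evaluate each rule expression against named tensor shapes and params."""
    # Tensor names from the signature become variables exposing only `.shape`.
    env = {name: SimpleNamespace(shape=tuple(s)) for name, s in shapes.items()}
    env.update(params)  # params like `dim` are visible to rules as well
    return all(eval(rule, {"__builtins__": {}}, env) for rule in rules)


rules = ["weight.shape == (x.shape[dim],)", "y.shape == x.shape"]
shapes = {"x": [2048, 4096], "weight": [4096], "y": [2048, 4096]}
print(check_shape_rules(rules, shapes, {"dim": -1}))  # True
```

Since rules are best-effort, a real checker would likely warn rather than fail on evaluation errors.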

### Workloads

Concrete shape/dtype combinations for benchmarking, based on real model architectures:
Representative shape/dtype combinations for benchmarking:

```yaml
workloads:
  # Llama-3.1-8B (hidden=4096)
  - {x_shape: [1, 4096, 4096], dim: -1, dtypes: [float16, bfloat16]}
  - {x_shape: [32, 1, 4096], dim: -1, dtypes: [bfloat16]}
  # Llama-3.1-70B (hidden=8192)
  - {x_shape: [1, 4096, 8192], dim: -1, dtypes: [float16, bfloat16]}
  - {x_shape: [32, 1, 8192], dim: -1, dtypes: [bfloat16]}
  # Llama-3.1-405B (hidden=16384)
  - {x_shape: [1, 2048, 16384], dim: -1, dtypes: [float16, bfloat16]}
  - {x_shape: [32, 1, 16384], dim: -1, dtypes: [bfloat16]}
workloads:
  - x_shape: [2048, 4096]
    dtypes: [float16, bfloat16]
    label: "llama-3.1-8b-prefill"
  - x_shape: [1, 4096]
    dtypes: [bfloat16]
    label: "llama-3.1-8b-decode"
```

Shapes are chosen by the op developer based on target model architectures. No centralized shape source is mandated.
- `x_shape` and `dtypes` are required — they drive benchmark execution and code generation. Kernel-level parameters (e.g., `M`/`N`) are derivable from shapes and should not be repeated.
- `label` is optional — a human-readable tag for reports and dashboards. Tools auto-generate from shape + dtype when omitted.
- Op-specific parameters (e.g., `dim` for norm, `causal` for attention) can be added per workload entry.

Shapes are chosen by the op developer based on target model architectures.
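The expansion from workload entries to concrete benchmark cases can be sketched as follows. The entry fields (`x_shape`, `dtypes`, `label`) come from the manifest; `expand_workloads` and the exact auto-label format are assumptions:

```python
def expand_workloads(workloads):
    """Expand each workload entry into one concrete benchmark case per dtype."""
    cases = []
    for w in workloads:
        for dtype in w["dtypes"]:
            # Auto-generate a label from shape + dtype when none is given.
            label = w.get("label") or "x".join(map(str, w["x_shape"])) + f"-{dtype}"
            cases.append({"x_shape": w["x_shape"], "dtype": dtype, "label": label})
    return cases


workloads = [
    {"x_shape": [2048, 4096], "dtypes": ["float16", "bfloat16"],
     "label": "llama-3.1-8b-prefill"},
    {"x_shape": [1, 4096], "dtypes": ["bfloat16"]},
]
for case in expand_workloads(workloads):
    print(case["label"], case["dtype"])
```

The two-entry input above expands to three cases, the last labeled from its shape and dtype.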

### Roofline

Two modes — inline expression for simple ops, Python function reference for complex ops:

```yaml
# Simple op: inline
roofline:
  flops: "2 * M * N"
  bytes: "(2 * M * N + N) * sizeof(dtype)"

# Complex op (e.g., flash attention): function reference
roofline:
  func: "tileops.perf.formulas.gqa_prefill_fwd"
roofline:
  flops: "4 * M * N"
  bytes: "2 * (M * N + N + M * N)"
```

```yaml
roofline:
  func: "tileops.perf.formulas.gqa_prefill_fwd"
```

Referenced functions live in `tileops/perf/formulas.py` and return `{"flops": int, "bytes": int}`.
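For illustration, a formula function following that contract might look like the sketch below. It mirrors the inline `flops`/`bytes` expressions for `rmsnorm_fwd`; the function name, signature, and `elem_size` default are assumptions, and the real `gqa_prefill_fwd` signature is not shown here:

```python
def rmsnorm_fwd(M: int, N: int, elem_size: int = 2) -> dict:
    """Analytical cost of RMSNorm forward for an [M, N] input."""
    return {
        # ~4 flops per element: squares, adds, normalize muls, weight muls
        "flops": 4 * M * N,
        # read x (M*N) + read weight (N) + write y (M*N), scaled by element size
        "bytes": elem_size * (M * N + N + M * N),
    }


print(rmsnorm_fwd(2048, 4096))
```

Returning plain ints keeps the contract identical between inline expressions and function references.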

The field is `bytes` (total bytes moved), not `memory`; it maps directly to `bytes_moved` in the roofline formula `memory_time = bytes_moved / hbm_bandwidth`.
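Putting the pieces together, evaluating an inline `bytes` expression into a memory-bound time estimate could look like this. The expression is the one from the manifest; the helper and the ~3.35 TB/s HBM bandwidth figure are illustrative assumptions:

```python
def roofline_memory_time(bytes_expr: str, dims: dict, hbm_bandwidth: float) -> float:
    """memory_time = bytes_moved / hbm_bandwidth, with bytes_moved
    evaluated from the manifest's inline expression."""
    bytes_moved = eval(bytes_expr, {"__builtins__": {}}, dict(dims))
    return bytes_moved / hbm_bandwidth


t = roofline_memory_time("2 * (M * N + N + M * N)",
                         {"M": 2048, "N": 4096}, 3.35e12)
print(f"{t * 1e6:.1f} us")  # ~10 us memory-bound lower bound for this workload
```

Comparing measured kernel time against this bound gives the efficiency number M5 reports.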

### Source

Pointers to implementation files for navigation and CI validation:

```yaml
source:
  kernel: "tileops/kernels/norm/rms_norm.py"
  op: "tileops/ops/norm/rms_norm.py"
  test: "tests/ops/test_rms_norm.py"
source:
  kernel: tileops/kernels/norm/rms_norm.py
  op: tileops/ops/norm/rms_norm.py
  test: tests/ops/test_rms_norm.py
  bench: benchmarks/ops/bench_rms_norm.py
```

## What Is NOT in the Manifest
77 changes: 77 additions & 0 deletions ops_manifest.yaml
@@ -0,0 +1,77 @@
# ops_manifest.yaml -- Spec-driven op registry for TileOPs
#
# Schema
# ------
# ops:
#   <op_name>:                              # Unique op identifier (e.g. rmsnorm_fwd)
#     family: <str>                         # Op family for grouping (e.g. norm, attention)
#     signature:
#       inputs: [{name, dtype, shape}]      # Input tensors
#       outputs: [{name, dtype, shape}]     # Output tensors
#       params: [{name, type, default}]     # Scalar / config parameters
#       shape_rules: [str]                  # Optional Python expressions relating dimensions
#     workloads: [{x_shape, dtypes, label?}]  # Representative shapes for benchmarking
#     roofline:                             # Analytical cost model
#       flops: <expr>                       # Inline Python expression (paired with bytes) -OR-
#       bytes: <expr>                       # inline expression for total bytes moved
#       func: <module.function>             # Alternative: dotted path to a Python function
#     source:
#       kernel: <path>                      # Path to kernel implementation
#       op: <path>                          # Path to op wrapper
#       test: <path>                        # Path to test file
#       bench: <path>                       # Path to benchmark file
#
# Notes:
# - Backward ops are registered as independent entries (e.g. rmsnorm_bwd).
# - shape_rules use Python expression syntax and are optional.
# - roofline supports two modes: inline expressions (flops/bytes) or func.

ops:
  rmsnorm_fwd:
    family: norm

    signature:
      inputs:
        - name: x
          dtype: "{float16, bfloat16}"
          shape: "[M, N]"
        - name: weight
          dtype: "{float16, bfloat16}"
          shape: "[N]"
      outputs:
        - name: y
          dtype: "{float16, bfloat16}"
          shape: "[M, N]"
      params:
        - name: dim
          type: int
          default: -1
        - name: eps
          type: float
          default: 1.0e-6
      shape_rules:
        - "weight.shape == (x.shape[-1],)"
        - "y.shape == x.shape"

    workloads:
      # Llama-3.1-8B (hidden_dim=4096)
      - {x_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "llama-3.1-8b-prefill"}
      - {x_shape: [1, 4096], dtypes: [bfloat16], label: "llama-3.1-8b-decode"}
      # Llama-3.1-70B (hidden_dim=8192)
      - {x_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "llama-3.1-70b-prefill"}
      - {x_shape: [1, 8192], dtypes: [bfloat16], label: "llama-3.1-70b-decode"}
      # Llama-3.1-405B (hidden_dim=16384)
      - {x_shape: [2048, 16384], dtypes: [float16, bfloat16], label: "llama-3.1-405b-prefill"}
      - {x_shape: [1, 16384], dtypes: [bfloat16], label: "llama-3.1-405b-decode"}

    roofline:
      # Per row: N squares + (N-1) adds + div + add + rsqrt + N muls (normalize) + N muls (weight) ≈ 4N
      flops: "4 * M * N"
      # Bytes: read x (M*N) + read weight (N) + write y (M*N), x2 for fp16/bf16 elem_size
      bytes: "2 * (M * N + N + M * N)"

    source:
      kernel: tileops/kernels/norm/rms_norm.py
      op: tileops/ops/norm/rms_norm.py
      test: tests/ops/test_rms_norm.py
      bench: benchmarks/ops/bench_rms_norm.py
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ dev = [
"codespell==2.4.1",
"pytest>=8.0",
"pytest-xdist>=3.0",
"pyyaml>=6.0",
]

[build-system]
102 changes: 102 additions & 0 deletions tests/test_ops_manifest.py
@@ -0,0 +1,102 @@
"""Schema validation for ops_manifest.yaml.

Validates structural invariants across all ops in the manifest.
Not op-specific — tests apply to every entry.
"""

from pathlib import Path

import pytest
import yaml

pytestmark = pytest.mark.smoke

REPO_ROOT = Path(__file__).resolve().parent.parent
MANIFEST_PATH = REPO_ROOT / "ops_manifest.yaml"


@pytest.fixture(scope="module")
def manifest():
    """Load and parse the ops manifest."""
    assert MANIFEST_PATH.exists(), f"ops_manifest.yaml not found at {MANIFEST_PATH}"
    with open(MANIFEST_PATH) as f:
        data = yaml.safe_load(f)
    assert isinstance(data, dict), "Manifest root must be a YAML mapping"
    return data


@pytest.fixture(scope="module")
def all_ops(manifest):
    """Return the ops dict from the manifest."""
    assert "ops" in manifest, "Manifest must have top-level 'ops' key"
    assert isinstance(manifest["ops"], dict)
    return manifest["ops"]


class TestManifestStructure:
    """Manifest file exists, parses, and has the expected top-level structure."""

    def test_manifest_exists(self):
        assert MANIFEST_PATH.exists()

    def test_manifest_is_valid_yaml(self, manifest):
        assert manifest is not None

    def test_has_ops_key(self, manifest):
        assert "ops" in manifest
        assert isinstance(manifest["ops"], dict)


class TestOpSchema:
    """Every op entry has the required fields and valid sub-structure."""

    REQUIRED_TOP_FIELDS = {"family", "signature", "workloads", "roofline", "source"}

    def test_every_op_has_required_fields(self, all_ops):
        for op_name, entry in all_ops.items():
            missing = self.REQUIRED_TOP_FIELDS - set(entry.keys())
            assert not missing, f"{op_name} missing fields: {missing}"

    def test_every_signature_has_inputs_and_outputs(self, all_ops):
        for op_name, entry in all_ops.items():
            sig = entry["signature"]
            assert "inputs" in sig, f"{op_name}: signature missing 'inputs'"
            assert isinstance(sig["inputs"], list), f"{op_name}: inputs must be a list"
            assert len(sig["inputs"]) >= 1, f"{op_name}: must have at least 1 input"
            assert "outputs" in sig, f"{op_name}: signature missing 'outputs'"
            assert isinstance(sig["outputs"], list), f"{op_name}: outputs must be a list"

    def test_every_roofline_has_valid_mode(self, all_ops):
        for op_name, entry in all_ops.items():
            roofline = entry["roofline"]
            has_inline = "flops" in roofline and "bytes" in roofline
            has_func = "func" in roofline
            assert has_inline or has_func, (
                f"{op_name}: roofline must have (flops + bytes) or func"
            )

    def test_shape_rules_are_valid_expressions(self, all_ops):
        for op_name, entry in all_ops.items():
            sig = entry["signature"]
            if "shape_rules" not in sig:
                continue
            for rule in sig["shape_rules"]:
                try:
                    compile(rule, "<shape_rule>", "eval")
                except SyntaxError as exc:
                    pytest.fail(
                        f"{op_name}: invalid shape_rule: {rule!r} ({exc})"
                    )


class TestSourcePaths:
    """All source paths point to existing files."""

    def test_all_source_paths_exist(self, all_ops):
        for op_name, entry in all_ops.items():
            source = entry["source"]
            for key, rel_path in source.items():
                full_path = REPO_ROOT / rel_path
                assert full_path.exists(), (
                    f"{op_name}: source.{key} not found: {rel_path}"
                )