Skip to content

Commit 178616a

Browse files
drbhdanieldk
andauthored
Prefer ruff formatting (#462)
* feat: prefer ruff formatting * fix: run formatting * fix: bump local and ci ruff version and reformat * fix: bump python_depends check * fix: run griffe with uvx * feat: add make quality check * nix: use ruff 0.15.10 --------- Co-authored-by: Daniël de Kok <me@danieldk.eu>
1 parent 4d97181 commit 178616a

32 files changed

Lines changed: 231 additions & 559 deletions

.github/workflows/lint.yml

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,39 @@ jobs:
66
runs-on: ubuntu-latest
77
steps:
88
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
9-
- name: Run ruff
9+
- name: Run ruff check
1010
uses: astral-sh/ruff-action@4919ec5cf1f49eff0871dbcea0da843445b837e6 # v3
11-
src: >-
12-
kernels
11+
with:
12+
src: kernels
13+
version: "0.15.10"
14+
- name: Run ruff format check
15+
uses: astral-sh/ruff-action@4919ec5cf1f49eff0871dbcea0da843445b837e6 # v3
16+
with:
17+
src: kernels
18+
version: "0.15.10"
19+
args: format --check
1320

14-
black:
15-
name: Run black check
21+
griffe:
22+
name: Check API compatibility
1623
runs-on: ubuntu-latest
1724
env:
1825
UV_PYTHON_PREFERENCE: only-managed
1926
steps:
2027
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
21-
28+
with:
29+
fetch-depth: 0
2230
- name: Install uv and set the python version
2331
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
2432
with:
25-
python-version: 3.12
26-
27-
- name: Install black
28-
run: uv pip install black
29-
30-
- name: Check formatting
31-
run: |
32-
uv run black --check kernels
33+
python-version: "3.12"
34+
- name: Check for breaking changes
35+
run: uvx griffe check kernels --search kernels/src -a main
3336

3437
validate-dependencies:
3538
name: Validate python_depends.json
3639
runs-on: ubuntu-latest
3740
steps:
3841
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
3942

40-
- name: Set up Python
41-
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
42-
with:
43-
python-version: "3.12"
44-
4543
- name: Validate python_depends.json is up-to-date
46-
run: |
47-
python ( cd kernels && update_python_depends.py --validate ) || {
48-
echo "Error: python_depends.json is out of date."
49-
echo "Please run: python update_python_depends.py"
50-
exit 1
51-
}
44+
run: diff kernels-data/src/python_dependencies.json kernels/src/kernels/python_depends.json

Makefile

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
.PHONY: style kernel-builder-cli-docs
1+
.PHONY: style kernel-builder-cli-docs quality
2+
23

34
export check_dirs := kernels/src kernels/tests
45

56
all: src/kernels/python_depends.json
67

7-
kernels/src/kernels/python_depends.json: kernel-builder/src/python_dependencies.json
8+
kernels/src/kernels/python_depends.json: kernels-data/src/python_dependencies.json
89
cp $< $@
910

1011
style:
11-
black ${check_dirs}
12-
isort ${check_dirs}
12+
ruff format ${check_dirs}
1313
ruff check ${check_dirs} --fix
1414

1515
kernel-builder-cli-docs:
@@ -20,3 +20,7 @@ kernel-builder-cli-docs:
2020
| sed '/`--backends/,/^\*/{/^ Default value:/d;}' \
2121
> docs/source/builder-cli.md
2222
@echo "Generated docs/source/builder-cli.md"
23+
24+
quality:
25+
ruff format --check ${check_dirs}
26+
ruff check ${check_dirs}

kernels/pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,10 @@ kernels = "kernels.cli:main"
5656
[tool.setuptools.package-data]
5757
kernels = ["python_depends.json"]
5858

59-
[tool.isort]
60-
profile = "black"
61-
line_length = 119
62-
6359
[tool.ruff]
60+
# If the version is changed, apply the change in the Nix overlay
61+
# as well.
62+
required-version = "==0.15.10"
6463
exclude = [
6564
".eggs",
6665
".git",
@@ -85,4 +84,6 @@ line-length = 119
8584
# Ignored rules:
8685
# "E501" -> line length violation
8786
lint.ignore = ["E501"]
88-
lint.select = ["E", "F", "W"]
87+
lint.select = ["E", "F", "I", "W"]
88+
89+
[tool.ruff.format]

kernels/src/kernels/_versions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: int | str) -> GitRef
8282
accepted_versions = sorted(requirement.filter(versions_old.keys()))
8383

8484
if len(accepted_versions) == 0:
85-
raise ValueError(
86-
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
87-
)
85+
raise ValueError(f"No version of `{repo_id}` satisfies requirement: {version_spec}")
8886

8987
return versions_old[accepted_versions[-1]]
9088

kernels/src/kernels/backends.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,7 @@ def _select_backend(backend: str | None) -> Backend:
241241
if backend in supported:
242242
return supported[backend]
243243

244-
raise ValueError(
245-
f"Invalid backend '{backend}', system supported backends: {', '.join(sorted(supported.keys()))}"
246-
)
244+
raise ValueError(f"Invalid backend '{backend}', system supported backends: {', '.join(sorted(supported.keys()))}")
247245

248246

249247
def _supported_backends() -> dict[str, Backend]:
@@ -267,9 +265,7 @@ def _get_cuda() -> Optional[CUDA]:
267265
runtime_version = ctypes.c_int(0)
268266
result = libcudart.cudaRuntimeGetVersion(ctypes.byref(runtime_version))
269267
if result != 0:
270-
warnings.warn(
271-
"System has CUDA runtime library, but cannot get runtime version."
272-
)
268+
warnings.warn("System has CUDA runtime library, but cannot get runtime version.")
273269
return None
274270

275271
# cudaRuntimeGetVersion encodes the version as (major * 1000 + minor * 10).

kernels/src/kernels/benchmarks/attention.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@ def _reference_attention(query, key, value, causal=False):
1414
"""Reference implementation using PyTorch SDPA."""
1515
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
1616
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
17-
out = torch.nn.functional.scaled_dot_product_attention(
18-
query, key, value, is_causal=causal
19-
)
17+
out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
2018
return out.transpose(1, 2).contiguous()
2119

2220

2321
def _varlen_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False):
2422
"""Reference implementation for variable length attention."""
2523
batch_size = cu_seqlens_q.shape[0] - 1
2624
total_tokens_q = q.shape[0]
27-
out = torch.zeros(
28-
(total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype
29-
)
25+
out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
3026

3127
for b in range(batch_size):
3228
start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
@@ -54,9 +50,7 @@ def setup_small(self):
5450
self.out = torch.empty(B, S, H, D, device="cuda", dtype=torch.float16)
5551

5652
def benchmark_small(self):
57-
self.out = _extract_output(
58-
self.kernel.flash_attn_func(self.q, self.k, self.v, causal=False)
59-
)
53+
self.out = _extract_output(self.kernel.flash_attn_func(self.q, self.k, self.v, causal=False))
6054

6155
def verify_small(self) -> torch.Tensor:
6256
return _reference_attention(self.q, self.k, self.v, causal=False)
@@ -70,9 +64,7 @@ def setup_medium(self):
7064
self.out = torch.empty(B, S, H, D, device="cuda", dtype=torch.float16)
7165

7266
def benchmark_medium(self):
73-
self.out = _extract_output(
74-
self.kernel.flash_attn_func(self.q, self.k, self.v, causal=False)
75-
)
67+
self.out = _extract_output(self.kernel.flash_attn_func(self.q, self.k, self.v, causal=False))
7668

7769
def verify_medium(self) -> torch.Tensor:
7870
return _reference_attention(self.q, self.k, self.v, causal=False)
@@ -86,9 +78,7 @@ def setup_large(self):
8678
self.out = torch.empty(B, S, H, D, device="cuda", dtype=torch.float16)
8779

8880
def benchmark_large(self):
89-
self.out = _extract_output(
90-
self.kernel.flash_attn_func(self.q, self.k, self.v, causal=False)
91-
)
81+
self.out = _extract_output(self.kernel.flash_attn_func(self.q, self.k, self.v, causal=False))
9282

9383
def verify_large(self) -> torch.Tensor:
9484
return _reference_attention(self.q, self.k, self.v, causal=False)
@@ -106,9 +96,7 @@ def setup_small(self):
10696
self.out = torch.empty(B, S, H, D, device="cuda", dtype=torch.float16)
10797

10898
def benchmark_small(self):
109-
self.out = _extract_output(
110-
self.kernel.flash_attn_func(self.q, self.k, self.v, causal=True)
111-
)
99+
self.out = _extract_output(self.kernel.flash_attn_func(self.q, self.k, self.v, causal=True))
112100

113101
def verify_small(self) -> torch.Tensor:
114102
return _reference_attention(self.q, self.k, self.v, causal=True)
@@ -122,9 +110,7 @@ def setup_medium(self):
122110
self.out = torch.empty(B, S, H, D, device="cuda", dtype=torch.float16)
123111

124112
def benchmark_medium(self):
125-
self.out = _extract_output(
126-
self.kernel.flash_attn_func(self.q, self.k, self.v, causal=True)
127-
)
113+
self.out = _extract_output(self.kernel.flash_attn_func(self.q, self.k, self.v, causal=True))
128114

129115
def verify_medium(self) -> torch.Tensor:
130116
return _reference_attention(self.q, self.k, self.v, causal=True)
@@ -138,9 +124,7 @@ def setup_large(self):
138124
self.out = torch.empty(B, S, H, D, device="cuda", dtype=torch.float16)
139125

140126
def benchmark_large(self):
141-
self.out = _extract_output(
142-
self.kernel.flash_attn_func(self.q, self.k, self.v, causal=True)
143-
)
127+
self.out = _extract_output(self.kernel.flash_attn_func(self.q, self.k, self.v, causal=True))
144128

145129
def verify_large(self) -> torch.Tensor:
146130
return _reference_attention(self.q, self.k, self.v, causal=True)
@@ -180,9 +164,7 @@ def benchmark_small(self):
180164
)
181165

182166
def verify_small(self) -> torch.Tensor:
183-
return _varlen_reference_attention(
184-
self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
185-
)
167+
return _varlen_reference_attention(self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False)
186168

187169
# Workload: medium (5 sequences, max_seqlen=256)
188170
def setup_medium(self):
@@ -214,9 +196,7 @@ def benchmark_medium(self):
214196
)
215197

216198
def verify_medium(self) -> torch.Tensor:
217-
return _varlen_reference_attention(
218-
self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
219-
)
199+
return _varlen_reference_attention(self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False)
220200

221201
# Workload: large (8 sequences, max_seqlen=512)
222202
def setup_large(self):
@@ -248,6 +228,4 @@ def benchmark_large(self):
248228
)
249229

250230
def verify_large(self) -> torch.Tensor:
251-
return _varlen_reference_attention(
252-
self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
253-
)
231+
return _varlen_reference_attention(self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False)

kernels/src/kernels/benchmarks/layer_norm.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,7 @@ def benchmark_small(self):
129129
)[0].view(self.B, self.S, self.D)
130130

131131
def verify_small(self) -> torch.Tensor:
132-
return torch.nn.functional.layer_norm(
133-
self.x, [self.D], self.weight, eps=self.eps
134-
)
132+
return torch.nn.functional.layer_norm(self.x, [self.D], self.weight, eps=self.eps)
135133

136134
# Workload: medium (B=4, S=512, D=2048)
137135
def setup_medium(self):
@@ -160,9 +158,7 @@ def benchmark_medium(self):
160158
)[0].view(self.B, self.S, self.D)
161159

162160
def verify_medium(self) -> torch.Tensor:
163-
return torch.nn.functional.layer_norm(
164-
self.x, [self.D], self.weight, eps=self.eps
165-
)
161+
return torch.nn.functional.layer_norm(self.x, [self.D], self.weight, eps=self.eps)
166162

167163
# Workload: large (B=8, S=1024, D=4096)
168164
def setup_large(self):
@@ -191,6 +187,4 @@ def benchmark_large(self):
191187
)[0].view(self.B, self.S, self.D)
192188

193189
def verify_large(self) -> torch.Tensor:
194-
return torch.nn.functional.layer_norm(
195-
self.x, [self.D], self.weight, eps=self.eps
196-
)
190+
return torch.nn.functional.layer_norm(self.x, [self.D], self.weight, eps=self.eps)

kernels/src/kernels/cli/__init__.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616

1717
def main():
18-
parser = argparse.ArgumentParser(
19-
prog="kernel", description="Manage compute kernels"
20-
)
18+
parser = argparse.ArgumentParser(prog="kernel", description="Manage compute kernels")
2119
subparsers = parser.add_subparsers(required=True)
2220

2321
check_parser = subparsers.add_parser("check", help="Check a kernel for compliance")
@@ -29,12 +27,8 @@ def main():
2927
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
3028
)
3129
check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0")
32-
check_parser.add_argument(
33-
"--manylinux", type=str, help="Manylinux version", default="manylinux_2_28"
34-
)
35-
check_parser.add_argument(
36-
"--python-abi", type=str, help="Python ABI version", default="3.9"
37-
)
30+
check_parser.add_argument("--manylinux", type=str, help="Manylinux version", default="manylinux_2_28")
31+
check_parser.add_argument("--python-abi", type=str, help="Python ABI version", default="3.9")
3832
check_parser.set_defaults(
3933
func=lambda args: check_kernel(
4034
macos=args.macos,
@@ -107,12 +101,8 @@ def main():
107101
type=str,
108102
help="Kernel repo ID (e.g., kernels-community/activation)",
109103
)
110-
benchmark_parser.add_argument(
111-
"--branch", type=str, help="Kernel branch to benchmark"
112-
)
113-
benchmark_parser.add_argument(
114-
"--version", type=int, help="Kernel version to benchmark"
115-
)
104+
benchmark_parser.add_argument("--branch", type=str, help="Kernel branch to benchmark")
105+
benchmark_parser.add_argument("--version", type=int, help="Kernel version to benchmark")
116106
benchmark_parser.add_argument(
117107
"--output",
118108
type=str,
@@ -230,9 +220,7 @@ def default(self, o):
230220
return super().default(o)
231221

232222

233-
def check_kernel(
234-
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
235-
):
223+
def check_kernel(*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str):
236224
try:
237225
from kernels.cli import check
238226
except ImportError:

0 commit comments

Comments
 (0)