Skip to content

Commit b7a26de

Browse files
authored
Refactor variant handling and add CUDA fallback (#339)
* Refactor variant handling and add CUDA fallback Build variants were stringly-typed throughout kernels, with custom parsing and serialization sprinkled everywhere. This change adds proper/strong typing to variants by adding a `Variant` class. This also centralizes parsing/serialization in one place and allows code to easily query various parts of a variant. This also fundamentally changes how we deal with getting variants from the Hub. Rather than casting a wide net with all possible variants and using allow patterns based on that, we query the Hub for variants of a kernel, parse them and can decide if there is an applicable variant ahead of time. If there are multiple applicable variants, we can select the best one (e.g. arch before noarch or recent CUDA version before older versions). * Refactor variant handling and add CUDA fallback Build variants were stringly-typed throughout kernels, with custom parsing and serialization sprinkled everywhere. This change adds proper/strong typing to variants by adding a `Variant` class. This also centralizes parsing/serialization in one place and allows code to easily query various parts of a variant. This also fundamentally changes how we deal with getting variants from the Hub. Rather than casting a wide net with all possible variants and using allow patterns based on that, we query the Hub for variants of a kernel, parse them and can decide if there is an applicable variant ahead of time. If there are multiple applicable variants, we can select the best one (e.g. arch before noarch or recent CUDA version before older versions). * Switch around cxx11 condition to support tagless build variant * Make `kernels versions` more informative * Move backend/variant regexes to their classes * Type fixes * Improve error handling * Add tvm-ffi variant strings for testing * Formatting
1 parent d603813 commit b7a26de

File tree

7 files changed

+851
-133
lines changed

7 files changed

+851
-133
lines changed

docs/source/cli-versions.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# kernels versions
22

3-
Use `kernels versions` to list all available versions of a kernel on the Hub.
3+
Use `kernels versions` to list all available versions of a kernel on the Hub
4+
and mark compatible build variants.
45

56
## Usage
67

@@ -19,7 +20,7 @@ kernels versions kernels-community/activation
1920
## Example Output
2021

2122
```text
22-
Version 1: torch210-cu128-x86_64-windows, torch210-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch210-metal-aarch64-darwin ✅, torch27-cxx11-cu118-x86_64-linux, torch27-cxx11-cu126-x86_64-linux, torch27-cxx11-cu128-aarch64-linux, torch27-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu126-x86_64-linux, torch28-cxx11-cu128-aarch64-linux, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu129-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, torch29-cxx11-cu126-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu128-x86_64-linux, torch29-cxx11-cu130-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch29-metal-aarch64-darwin
23+
Version 1: torch210-metal-aarch64-darwin, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch27-cxx11-cu118-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-aarch64-linux, torch27-cxx11-cu126-x86_64-linux, ✅ torch29-cxx11-cu126-x86_64-linux (compatible), torch27-cxx11-cu128-x86_64-linux, torch210-cxx11-cu126-x86_64-linux, torch29-metal-aarch64-darwin, torch27-cxx11-cu128-aarch64-linux, torch210-cu128-x86_64-windows, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, ✅ torch29-cxx11-cu128-x86_64-linux (compatible, preferred), torch28-cxx11-cu129-x86_64-linux
2324
```
2425

2526
## Use Cases

kernels/src/kernels/backends.py

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import ctypes
22
import ctypes.util
3+
import re
34
import warnings
45
from dataclasses import dataclass
5-
from typing import Optional, Protocol
6+
from typing import ClassVar, Optional, Protocol
67

78
from packaging.version import Version
89

@@ -18,98 +19,172 @@ def name(self) -> str:
1819
...
1920

2021
@property
21-
def variant(self) -> str:
22+
def variant_str(self) -> str:
2223
"""
2324
The name of the backend as used in a build variant, e.g. `cu128`
2425
for CUDA 12.8.
2526
"""
2627
...
2728

2829

29-
@dataclass
30+
@dataclass(unsafe_hash=True)
3031
class CANN:
32+
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)")
33+
3134
version: Version
3235

3336
@property
3437
def name(self) -> str:
3538
return "cann"
3639

3740
@property
38-
def variant(self) -> str:
41+
def variant_str(self) -> str:
3942
return f"cann{self.version.major}{self.version.minor}"
4043

44+
@staticmethod
45+
def parse(s: str) -> "CANN":
46+
m = CANN._VARIANT_REGEX.fullmatch(s)
47+
if not m:
48+
raise ValueError(f"Invalid CANN variant string: {s!r}")
49+
return CANN(version=Version(f"{m.group(1)}.{m.group(2)}"))
50+
4151

42-
@dataclass
52+
@dataclass(unsafe_hash=True)
4353
class CPU:
4454
@property
4555
def name(self) -> str:
4656
return "cpu"
4757

4858
@property
49-
def variant(self) -> str:
59+
def variant_str(self) -> str:
5060
return "cpu"
5161

62+
@staticmethod
63+
def parse(s: str) -> "CPU":
64+
if s != "cpu":
65+
raise ValueError(f"Invalid CPU variant string: {s!r}")
66+
return CPU()
67+
5268

53-
@dataclass
69+
@dataclass(unsafe_hash=True)
5470
class CUDA:
71+
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)")
72+
5573
version: Version
5674

5775
@property
5876
def name(self) -> str:
5977
return "cuda"
6078

6179
@property
62-
def variant(self) -> str:
80+
def variant_str(self) -> str:
6381
return f"cu{self.version.major}{self.version.minor}"
6482

83+
@staticmethod
84+
def parse(s: str) -> "CUDA":
85+
m = CUDA._VARIANT_REGEX.fullmatch(s)
86+
if not m:
87+
raise ValueError(f"Invalid CUDA variant string: {s!r}")
88+
return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}"))
89+
6590

66-
@dataclass
91+
@dataclass(unsafe_hash=True)
6792
class Metal:
6893
@property
6994
def name(self) -> str:
7095
return "metal"
7196

7297
@property
73-
def variant(self) -> str:
98+
def variant_str(self) -> str:
7499
return "metal"
75100

101+
@staticmethod
102+
def parse(s: str) -> "Metal":
103+
if s != "metal":
104+
raise ValueError(f"Invalid Metal variant string: {s!r}")
105+
return Metal()
106+
76107

77-
@dataclass
108+
@dataclass(unsafe_hash=True)
78109
class Neuron:
79110
@property
80111
def name(self) -> str:
81112
return "neuron"
82113

83114
@property
84-
def variant(self) -> str:
115+
def variant_str(self) -> str:
85116
return "neuron"
86117

118+
@staticmethod
119+
def parse(s: str) -> "Neuron":
120+
if s != "neuron":
121+
raise ValueError(f"Invalid Neuron variant string: {s!r}")
122+
return Neuron()
87123

88-
@dataclass
124+
125+
@dataclass(unsafe_hash=True)
89126
class ROCm:
127+
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)")
128+
90129
version: Version
91130

92131
@property
93132
def name(self) -> str:
94133
return "rocm"
95134

96135
@property
97-
def variant(self) -> str:
136+
def variant_str(self) -> str:
98137
return f"rocm{self.version.major}{self.version.minor}"
99138

139+
@staticmethod
140+
def parse(s: str) -> "ROCm":
141+
m = ROCm._VARIANT_REGEX.fullmatch(s)
142+
if not m:
143+
raise ValueError(f"Invalid ROCm variant string: {s!r}")
144+
return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}"))
145+
100146

101-
@dataclass
147+
@dataclass(unsafe_hash=True)
102148
class XPU:
149+
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)")
150+
103151
version: Version
104152

105153
@property
106154
def name(self) -> str:
107155
return "xpu"
108156

109157
@property
110-
def variant(self) -> str:
158+
def variant_str(self) -> str:
111159
return f"xpu{self.version.major}{self.version.minor}"
112160

161+
@staticmethod
162+
def parse(s: str) -> "XPU":
163+
m = XPU._VARIANT_REGEX.fullmatch(s)
164+
if not m:
165+
raise ValueError(f"Invalid XPU variant string: {s!r}")
166+
return XPU(version=Version(f"{m.group(1)}.{m.group(2)}"))
167+
168+
169+
def parse_backend(s: str) -> Backend:
170+
"""Parse a backend variant string (e.g. 'cu128', 'rocm61', 'cpu') into a Backend."""
171+
if s == "cpu":
172+
return CPU.parse(s)
173+
elif s == "metal":
174+
return Metal.parse(s)
175+
elif s == "neuron":
176+
return Neuron.parse(s)
177+
elif s.startswith("cu"):
178+
return CUDA.parse(s)
179+
elif s.startswith("rocm"):
180+
return ROCm.parse(s)
181+
elif s.startswith("xpu"):
182+
return XPU.parse(s)
183+
elif s.startswith("cann"):
184+
return CANN.parse(s)
185+
else:
186+
raise ValueError(f"Unknown backend variant string: {s!r}")
187+
113188

114189
def _backend() -> Backend:
115190
if has_torch:
Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,31 @@
1-
from importlib.util import find_spec
2-
from pathlib import Path
3-
4-
from huggingface_hub import HfApi
5-
61
from kernels._versions import _get_available_versions
7-
from kernels.utils import _build_variants, _get_hf_api
8-
from kernels.variants import BUILD_VARIANT_REGEX
2+
from kernels.utils import _get_hf_api
3+
from kernels.variants import (
4+
get_variants,
5+
resolve_variants,
6+
)
97

108

119
def print_kernel_versions(repo_id: str):
1210
api = _get_hf_api()
1311

14-
if find_spec("torch") is None:
15-
# Do not mark compatible variants when Torch is not available.
16-
compatible_variants = set()
17-
else:
18-
compatible_variants = set(_build_variants(None))
19-
2012
versions = _get_available_versions(repo_id).items()
2113
if not versions:
2214
print(f"Repository does not support kernel versions: {repo_id}")
2315
return
2416

2517
for version, ref in sorted(versions, key=lambda x: x[0]):
18+
variants = get_variants(api, repo_id=repo_id, revision=ref.ref)
19+
resolved = resolve_variants(variants, None)
20+
best = resolved[0] if resolved else None
21+
resolved_set = set(resolved)
2622
print(f"Version {version}: ", end="")
27-
variants = [
28-
f"{variant} ✅" if variant in compatible_variants else f"{variant}"
29-
for variant in _get_build_variants(api, repo_id, ref.ref)
23+
variant_strs = [
24+
(
25+
f"✅ {variant.variant_str} ({'compatible, preferred' if variant == best else 'compatible'})"
26+
if variant in resolved_set
27+
else f"{variant.variant_str}"
28+
)
29+
for variant in variants
3030
]
31-
print(", ".join(variants))
32-
33-
34-
def _get_build_variants(api: "HfApi", repo_id: str, revision: str) -> list[str]:
35-
variants = set()
36-
for filename in api.list_repo_files(repo_id, revision=revision):
37-
path = Path(filename)
38-
if len(path.parts) < 2 or path.parts[0] != "build":
39-
continue
40-
41-
match = BUILD_VARIANT_REGEX.match(path.parts[1])
42-
if match:
43-
variants.add(path.parts[1])
44-
return sorted(variants)
31+
print(", ".join(variant_strs))

0 commit comments

Comments
 (0)