Skip to content

Commit 42ebd1c

Browse files
authored
Add local patch prefix kpack-split host wheel extras (#4870)
Applies ROCm/rocm-systems#5514 locally until landed upstream to have the same extras straight from the start even though it adds a patch to TheRock.
1 parent 45c793b commit 42ebd1c

1 file changed

Lines changed: 250 additions & 0 deletions

File tree

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
From f7cdfcf9fd6e319a5add6e30e86bffa2ec076af8 Mon Sep 17 00:00:00 2001
2+
From: Marius Brehler <marius.brehler@amd.com>
3+
Date: Mon, 27 Apr 2026 23:58:48 +0200
4+
Subject: [PATCH] Prefix kpack-split host wheel extras with `device-`
5+
6+
The kpack-split host wheel METADATA today emits per-target extras as
7+
the bare bundle key, so a single find-links URL exposes two different
8+
conventions for the same SDK + PyTorch stack:
9+
10+
pip install rocm[device-gfx1201] # rocm-sdk-core
11+
pip install torch[gfx1201] # kpack-split host wheel
12+
13+
This aligns the kpack splitter to emit the `device-<arch>` pattern used
14+
in rocm-sdk-core so both halves share one install command.
15+
16+
Apply the prefix uniformly: the per-target extras
17+
(`device-gfx1201`, ...) and the aggregate (`device-all`) all start
18+
with `device-` after this commit and widens the regex in
19+
`_add_variant_markers_to_metadata` to match the new `[\w-]+` extra
20+
names, and strips the `device-` prefix before substituting the bare arch
21+
into the PEP 817 `"amd :: gfx_arch :: <arch>"` marker.
22+
23+
Co-Authored-By: Claude <noreply@anthropic.com>
24+
---
25+
.../kpack/python/rocm_kpack/wheel_splitter.py | 47 ++++++++++++-------
26+
shared/kpack/tests/test_wheel_splitter.py | 42 ++++++++---------
27+
2 files changed, 51 insertions(+), 38 deletions(-)
28+
29+
diff --git a/shared/kpack/python/rocm_kpack/wheel_splitter.py b/shared/kpack/python/rocm_kpack/wheel_splitter.py
30+
index c96b327c20..ecbb1dbb5a 100644
31+
--- a/shared/kpack/python/rocm_kpack/wheel_splitter.py
32+
+++ b/shared/kpack/python/rocm_kpack/wheel_splitter.py
33+
@@ -869,7 +869,7 @@ class WheelSplitter:
34+
Inserts into the existing METADATA header section:
35+
- Requires-Dist: rocm-bootstrap
36+
- Per-target Provides-Extra + Requires-Dist (extras-based)
37+
- - Provides-Extra: all (with all device wheels)
38+
+ - Provides-Extra: device-all (with all device wheels)
39+
- If include_variant_markers: also adds variant_properties markers
40+
"""
41+
metadata_path = host_staging / identity.dist_info_name / "METADATA"
42+
@@ -889,11 +889,15 @@ class WheelSplitter:
43+
if bundle.level == rocm_bootstrap.PackagingLevel.TARGET:
44+
target_keys.append(key)
45+
46+
- # For each target, compute packaging chain and generate extras
47+
+ # For each target, compute packaging chain and generate extras.
48+
+ # Per-target extras are prefixed with "device-" to mirror the
49+
+ # convention used by rocm-sdk-core (rocm[device-gfx1201], ...) so
50+
+ # one find-links URL exposes a uniform install command across the
51+
+ # SDK + PyTorch stack.
52+
all_device_dist_names: set[str] = set()
53+
- for target_name in target_keys:
54+
- chain = rocm_bootstrap.packaging_chain(target_name)
55+
- lines.append(f"Provides-Extra: {target_name}")
56+
+ for amdgpu_target in target_keys:
57+
+ chain = rocm_bootstrap.packaging_chain(amdgpu_target)
58+
+ lines.append(f"Provides-Extra: device-{amdgpu_target}")
59+
60+
for chain_bundle in chain:
61+
if chain_bundle.key not in bundle_keys:
62+
@@ -907,15 +911,15 @@ class WheelSplitter:
63+
# update the regex there too.
64+
lines.append(
65+
f"Requires-Dist: {dist_name} == {version}; "
66+
- f'extra == "{target_name}"'
67+
+ f'extra == "device-{amdgpu_target}"'
68+
)
69+
if include_variant_markers:
70+
lines.append(
71+
f"Requires-Dist: {dist_name} == {version}; "
72+
- f'"amd :: gfx_arch :: {target_name}" in variant_properties'
73+
+ f'"amd :: gfx_arch :: {amdgpu_target}" in variant_properties'
74+
)
75+
76+
- # Also add non-target bundle keys (family, sub-family) to the "all" set
77+
+ # Also add non-target bundle keys (family, sub-family) to the "device-all" set
78+
for key in sorted(bundle_keys):
79+
bundle = rocm_bootstrap.lookup_bundle(key)
80+
if bundle.level != rocm_bootstrap.PackagingLevel.TARGET:
81+
@@ -924,10 +928,14 @@ class WheelSplitter:
82+
)
83+
all_device_dist_names.add(dist_name)
84+
85+
- # "all" extra: includes every device wheel
86+
- lines.append("Provides-Extra: all")
87+
+ # "device-all" extra: includes every device wheel. The prefix
88+
+ # matches the per-target extras above so every kpack-emitted extra
89+
+ # starts with "device-".
90+
+ lines.append("Provides-Extra: device-all")
91+
for dist_name in sorted(all_device_dist_names):
92+
- lines.append(f"Requires-Dist: {dist_name} == {version}; " f'extra == "all"')
93+
+ lines.append(
94+
+ f"Requires-Dist: {dist_name} == {version}; " f'extra == "device-all"'
95+
+ )
96+
97+
# Insert new headers before the body (after the last header line).
98+
# METADATA uses RFC 822 format: headers, blank line, body.
99+
@@ -946,7 +954,7 @@ class WheelSplitter:
100+
metadata_path.write_text(new_content, encoding="utf-8")
101+
102+
if self.verbose:
103+
- n_extras = len(target_keys) + 1 # +1 for "all"
104+
+ n_extras = len(target_keys) + 1 # +1 for "device-all"
105+
n_device = len(all_device_dist_names)
106+
label = "variant" if include_variant_markers else "classic"
107+
print(
108+
@@ -1047,20 +1055,25 @@ class WheelSplitter:
109+
# add the variant_properties version
110+
if not line.startswith("Requires-Dist:") or "extra ==" not in line:
111+
continue
112+
- if 'extra == "all"' in line:
113+
+ # The "device-all" extra aggregates every device wheel and
114+
+ # must not get a variant marker — variants are per-target.
115+
+ if 'extra == "device-all"' in line:
116+
continue
117+
# Extract the dist requirement and target name.
118+
# This regex must match the format generated by _rewrite_host_metadata():
119+
# f"Requires-Dist: {dist_name} == {version}; "
120+
- # f'extra == "{target_name}"'
121+
+ # f'extra == "device-{amdgpu_target}"'
122+
# If that format changes, this regex must be updated to match.
123+
- match = re.match(r'(Requires-Dist: .+ == .+); extra == "(\w+)"', line)
124+
+ # The captured extra is the user-facing "device-<arch>" form;
125+
+ # strip the "device-" prefix to recover the bare gfx target
126+
+ # used in the variant_properties marker.
127+
+ match = re.match(r'(Requires-Dist: .+ == .+); extra == "([\w-]+)"', line)
128+
if match:
129+
req_part = match.group(1)
130+
- target_name = match.group(2)
131+
+ amdgpu_target = match.group(2).removeprefix("device-")
132+
new_lines.append(
133+
f"{req_part}; "
134+
- f'"amd :: gfx_arch :: {target_name}" in variant_properties'
135+
+ f'"amd :: gfx_arch :: {amdgpu_target}" in variant_properties'
136+
)
137+
138+
metadata_path.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
139+
diff --git a/shared/kpack/tests/test_wheel_splitter.py b/shared/kpack/tests/test_wheel_splitter.py
140+
index 8c4f5686cb..4b098e9899 100644
141+
--- a/shared/kpack/tests/test_wheel_splitter.py
142+
+++ b/shared/kpack/tests/test_wheel_splitter.py
143+
@@ -386,12 +386,12 @@ class TestRewriteHostMetadata:
144+
splitter._rewrite_host_metadata(staging, identity, {"gfx942"})
145+
146+
metadata = (staging / identity.dist_info_name / "METADATA").read_text()
147+
- assert "Provides-Extra: gfx942" in metadata
148+
+ assert "Provides-Extra: device-gfx942" in metadata
149+
assert (
150+
- 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "gfx942"'
151+
+ 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "device-gfx942"'
152+
in metadata
153+
)
154+
- assert "Provides-Extra: all" in metadata
155+
+ assert "Provides-Extra: device-all" in metadata
156+
157+
def test_classic_has_no_variant_markers(self, tmp_path: Path):
158+
staging, identity = self._make_host_staging(tmp_path)
159+
@@ -424,20 +424,20 @@ class TestRewriteHostMetadata:
160+
metadata = (staging / identity.dist_info_name / "METADATA").read_text()
161+
# Both should appear as deps under the gfx1100 extra
162+
assert (
163+
- 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "gfx1100"'
164+
+ 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "device-gfx1100"'
165+
in metadata
166+
)
167+
assert (
168+
- 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "gfx1100"'
169+
+ 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "device-gfx1100"'
170+
in metadata
171+
)
172+
- # "all" should include both
173+
+ # "device-all" should include both
174+
assert (
175+
- 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "all"'
176+
+ 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "device-all"'
177+
in metadata
178+
)
179+
assert (
180+
- 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "all"'
181+
+ 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "device-all"'
182+
in metadata
183+
)
184+
185+
@@ -450,11 +450,11 @@ class TestRewriteHostMetadata:
186+
187+
metadata = (staging / identity.dist_info_name / "METADATA").read_text()
188+
# No target-level extras since gfx11 is a family
189+
- assert "Provides-Extra: gfx11" not in metadata
190+
- # But "all" should still include it
191+
- assert "Provides-Extra: all" in metadata
192+
+ assert "Provides-Extra: device-gfx11" not in metadata
193+
+ # But "device-all" should still include it
194+
+ assert "Provides-Extra: device-all" in metadata
195+
assert (
196+
- 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "all"'
197+
+ 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "device-all"'
198+
in metadata
199+
)
200+
201+
@@ -481,8 +481,8 @@ class TestRewriteHostMetadata:
202+
header_section, body_section = header_body
203+
# All injected headers must be in the header section, not the body
204+
assert "Requires-Dist: rocm-bootstrap" in header_section
205+
- assert "Provides-Extra: gfx942" in header_section
206+
- assert "Provides-Extra: all" in header_section
207+
+ assert "Provides-Extra: device-gfx942" in header_section
208+
+ assert "Provides-Extra: device-all" in header_section
209+
# Body should still contain the description
210+
assert "description body" in body_section
211+
212+
@@ -501,10 +501,10 @@ class TestVariantWheel:
213+
"Name: torch\n"
214+
"Version: 2.10.0+rocm7.1\n"
215+
"Requires-Dist: rocm-bootstrap\n"
216+
- "Provides-Extra: gfx942\n"
217+
- 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "gfx942"\n'
218+
- "Provides-Extra: all\n"
219+
- 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "all"\n'
220+
+ "Provides-Extra: device-gfx942\n"
221+
+ 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "device-gfx942"\n'
222+
+ "Provides-Extra: device-all\n"
223+
+ 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "device-all"\n'
224+
"\n"
225+
"Description body.\n"
226+
)
227+
@@ -542,9 +542,9 @@ class TestVariantWheel:
228+
229+
metadata = (staging / identity.dist_info_name / "METADATA").read_text()
230+
# Should have the extras line AND the variant marker line
231+
- assert 'extra == "gfx942"' in metadata
232+
+ assert 'extra == "device-gfx942"' in metadata
233+
assert ('"amd :: gfx_arch :: gfx942" in variant_properties') in metadata
234+
- # "all" extra should NOT get variant markers
235+
+ # "device-all" extra should NOT get variant markers
236+
assert '"amd :: gfx_arch :: all"' not in metadata
237+
238+
def test_variant_json(self, tmp_path: Path):
239+
@@ -582,7 +582,7 @@ class TestVariantWheel:
240+
# Check METADATA has variant markers
241+
metadata = (variant_path / identity.dist_info_name / "METADATA").read_text()
242+
assert "variant_properties" in metadata
243+
- assert 'extra == "gfx942"' in metadata
244+
+ assert 'extra == "device-gfx942"' in metadata
245+
246+
# Check variant.json exists
247+
variant_json = variant_path / identity.dist_info_name / "variant.json"
248+
--
249+
2.51.0
250+

0 commit comments

Comments
 (0)