|
| 1 | +From f7cdfcf9fd6e319a5add6e30e86bffa2ec076af8 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Marius Brehler <marius.brehler@amd.com> |
| 3 | +Date: Mon, 27 Apr 2026 23:58:48 +0200 |
| 4 | +Subject: [PATCH] Prefix kpack-split host wheel extras with `device-` |
| 5 | + |
| 6 | +The kpack-split host wheel METADATA today emits per-target extras as |
| 7 | +the bare bundle key, so a single find-links URL exposes two different |
| 8 | +conventions for the same SDK + PyTorch stack: |
| 9 | + |
| 10 | + pip install rocm[device-gfx1201] # rocm-sdk-core |
| 11 | + pip install torch[gfx1201] # kpack-split host wheel |
| 12 | + |
| 13 | +This aligns the kpack splitter to emit the `device-<arch>` pattern used |
| 14 | +in rocm-sdk-core so both halves share one install command. |
| 15 | + |
| 16 | +Apply the prefix uniformly: the per-target extras |
| 17 | +(`device-gfx1201`, ...) and the aggregate (`device-all`) all start |
| 18 | +with `device-` after this commit and widens the regex in |
| 19 | +`_add_variant_markers_to_metadata` to match the new `[\w-]+` extra |
| 20 | +names, and strips the `device-` prefix before substituting the bare arch |
| 21 | +into the PEP 817 `"amd :: gfx_arch :: <arch>"` marker. |
| 22 | + |
| 23 | +Co-Authored-By: Claude <noreply@anthropic.com> |
| 24 | +--- |
| 25 | + .../kpack/python/rocm_kpack/wheel_splitter.py | 47 ++++++++++++------- |
| 26 | + shared/kpack/tests/test_wheel_splitter.py | 42 ++++++++--------- |
| 27 | + 2 files changed, 51 insertions(+), 38 deletions(-) |
| 28 | + |
| 29 | +diff --git a/shared/kpack/python/rocm_kpack/wheel_splitter.py b/shared/kpack/python/rocm_kpack/wheel_splitter.py |
| 30 | +index c96b327c20..ecbb1dbb5a 100644 |
| 31 | +--- a/shared/kpack/python/rocm_kpack/wheel_splitter.py |
| 32 | ++++ b/shared/kpack/python/rocm_kpack/wheel_splitter.py |
| 33 | +@@ -869,7 +869,7 @@ class WheelSplitter: |
| 34 | + Inserts into the existing METADATA header section: |
| 35 | + - Requires-Dist: rocm-bootstrap |
| 36 | + - Per-target Provides-Extra + Requires-Dist (extras-based) |
| 37 | +- - Provides-Extra: all (with all device wheels) |
| 38 | ++ - Provides-Extra: device-all (with all device wheels) |
| 39 | + - If include_variant_markers: also adds variant_properties markers |
| 40 | + """ |
| 41 | + metadata_path = host_staging / identity.dist_info_name / "METADATA" |
| 42 | +@@ -889,11 +889,15 @@ class WheelSplitter: |
| 43 | + if bundle.level == rocm_bootstrap.PackagingLevel.TARGET: |
| 44 | + target_keys.append(key) |
| 45 | + |
| 46 | +- # For each target, compute packaging chain and generate extras |
| 47 | ++ # For each target, compute packaging chain and generate extras. |
| 48 | ++ # Per-target extras are prefixed with "device-" to mirror the |
| 49 | ++ # convention used by rocm-sdk-core (rocm[device-gfx1201], ...) so |
| 50 | ++ # one find-links URL exposes a uniform install command across the |
| 51 | ++ # SDK + PyTorch stack. |
| 52 | + all_device_dist_names: set[str] = set() |
| 53 | +- for target_name in target_keys: |
| 54 | +- chain = rocm_bootstrap.packaging_chain(target_name) |
| 55 | +- lines.append(f"Provides-Extra: {target_name}") |
| 56 | ++ for amdgpu_target in target_keys: |
| 57 | ++ chain = rocm_bootstrap.packaging_chain(amdgpu_target) |
| 58 | ++ lines.append(f"Provides-Extra: device-{amdgpu_target}") |
| 59 | + |
| 60 | + for chain_bundle in chain: |
| 61 | + if chain_bundle.key not in bundle_keys: |
| 62 | +@@ -907,15 +911,15 @@ class WheelSplitter: |
| 63 | + # update the regex there too. |
| 64 | + lines.append( |
| 65 | + f"Requires-Dist: {dist_name} == {version}; " |
| 66 | +- f'extra == "{target_name}"' |
| 67 | ++ f'extra == "device-{amdgpu_target}"' |
| 68 | + ) |
| 69 | + if include_variant_markers: |
| 70 | + lines.append( |
| 71 | + f"Requires-Dist: {dist_name} == {version}; " |
| 72 | +- f'"amd :: gfx_arch :: {target_name}" in variant_properties' |
| 73 | ++ f'"amd :: gfx_arch :: {amdgpu_target}" in variant_properties' |
| 74 | + ) |
| 75 | + |
| 76 | +- # Also add non-target bundle keys (family, sub-family) to the "all" set |
| 77 | ++ # Also add non-target bundle keys (family, sub-family) to the "device-all" set |
| 78 | + for key in sorted(bundle_keys): |
| 79 | + bundle = rocm_bootstrap.lookup_bundle(key) |
| 80 | + if bundle.level != rocm_bootstrap.PackagingLevel.TARGET: |
| 81 | +@@ -924,10 +928,14 @@ class WheelSplitter: |
| 82 | + ) |
| 83 | + all_device_dist_names.add(dist_name) |
| 84 | + |
| 85 | +- # "all" extra: includes every device wheel |
| 86 | +- lines.append("Provides-Extra: all") |
| 87 | ++ # "device-all" extra: includes every device wheel. The prefix |
| 88 | ++ # matches the per-target extras above so every kpack-emitted extra |
| 89 | ++ # starts with "device-". |
| 90 | ++ lines.append("Provides-Extra: device-all") |
| 91 | + for dist_name in sorted(all_device_dist_names): |
| 92 | +- lines.append(f"Requires-Dist: {dist_name} == {version}; " f'extra == "all"') |
| 93 | ++ lines.append( |
| 94 | ++ f"Requires-Dist: {dist_name} == {version}; " f'extra == "device-all"' |
| 95 | ++ ) |
| 96 | + |
| 97 | + # Insert new headers before the body (after the last header line). |
| 98 | + # METADATA uses RFC 822 format: headers, blank line, body. |
| 99 | +@@ -946,7 +954,7 @@ class WheelSplitter: |
| 100 | + metadata_path.write_text(new_content, encoding="utf-8") |
| 101 | + |
| 102 | + if self.verbose: |
| 103 | +- n_extras = len(target_keys) + 1 # +1 for "all" |
| 104 | ++ n_extras = len(target_keys) + 1 # +1 for "device-all" |
| 105 | + n_device = len(all_device_dist_names) |
| 106 | + label = "variant" if include_variant_markers else "classic" |
| 107 | + print( |
| 108 | +@@ -1047,20 +1055,25 @@ class WheelSplitter: |
| 109 | + # add the variant_properties version |
| 110 | + if not line.startswith("Requires-Dist:") or "extra ==" not in line: |
| 111 | + continue |
| 112 | +- if 'extra == "all"' in line: |
| 113 | ++ # The "device-all" extra aggregates every device wheel and |
| 114 | ++ # must not get a variant marker — variants are per-target. |
| 115 | ++ if 'extra == "device-all"' in line: |
| 116 | + continue |
| 117 | + # Extract the dist requirement and target name. |
| 118 | + # This regex must match the format generated by _rewrite_host_metadata(): |
| 119 | + # f"Requires-Dist: {dist_name} == {version}; " |
| 120 | +- # f'extra == "{target_name}"' |
| 121 | ++ # f'extra == "device-{amdgpu_target}"' |
| 122 | + # If that format changes, this regex must be updated to match. |
| 123 | +- match = re.match(r'(Requires-Dist: .+ == .+); extra == "(\w+)"', line) |
| 124 | ++ # The captured extra is the user-facing "device-<arch>" form; |
| 125 | ++ # strip the "device-" prefix to recover the bare gfx target |
| 126 | ++ # used in the variant_properties marker. |
| 127 | ++ match = re.match(r'(Requires-Dist: .+ == .+); extra == "([\w-]+)"', line) |
| 128 | + if match: |
| 129 | + req_part = match.group(1) |
| 130 | +- target_name = match.group(2) |
| 131 | ++ amdgpu_target = match.group(2).removeprefix("device-") |
| 132 | + new_lines.append( |
| 133 | + f"{req_part}; " |
| 134 | +- f'"amd :: gfx_arch :: {target_name}" in variant_properties' |
| 135 | ++ f'"amd :: gfx_arch :: {amdgpu_target}" in variant_properties' |
| 136 | + ) |
| 137 | + |
| 138 | + metadata_path.write_text("\n".join(new_lines) + "\n", encoding="utf-8") |
| 139 | +diff --git a/shared/kpack/tests/test_wheel_splitter.py b/shared/kpack/tests/test_wheel_splitter.py |
| 140 | +index 8c4f5686cb..4b098e9899 100644 |
| 141 | +--- a/shared/kpack/tests/test_wheel_splitter.py |
| 142 | ++++ b/shared/kpack/tests/test_wheel_splitter.py |
| 143 | +@@ -386,12 +386,12 @@ class TestRewriteHostMetadata: |
| 144 | + splitter._rewrite_host_metadata(staging, identity, {"gfx942"}) |
| 145 | + |
| 146 | + metadata = (staging / identity.dist_info_name / "METADATA").read_text() |
| 147 | +- assert "Provides-Extra: gfx942" in metadata |
| 148 | ++ assert "Provides-Extra: device-gfx942" in metadata |
| 149 | + assert ( |
| 150 | +- 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "gfx942"' |
| 151 | ++ 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "device-gfx942"' |
| 152 | + in metadata |
| 153 | + ) |
| 154 | +- assert "Provides-Extra: all" in metadata |
| 155 | ++ assert "Provides-Extra: device-all" in metadata |
| 156 | + |
| 157 | + def test_classic_has_no_variant_markers(self, tmp_path: Path): |
| 158 | + staging, identity = self._make_host_staging(tmp_path) |
| 159 | +@@ -424,20 +424,20 @@ class TestRewriteHostMetadata: |
| 160 | + metadata = (staging / identity.dist_info_name / "METADATA").read_text() |
| 161 | + # Both should appear as deps under the gfx1100 extra |
| 162 | + assert ( |
| 163 | +- 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "gfx1100"' |
| 164 | ++ 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "device-gfx1100"' |
| 165 | + in metadata |
| 166 | + ) |
| 167 | + assert ( |
| 168 | +- 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "gfx1100"' |
| 169 | ++ 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "device-gfx1100"' |
| 170 | + in metadata |
| 171 | + ) |
| 172 | +- # "all" should include both |
| 173 | ++ # "device-all" should include both |
| 174 | + assert ( |
| 175 | +- 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "all"' |
| 176 | ++ 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "device-all"' |
| 177 | + in metadata |
| 178 | + ) |
| 179 | + assert ( |
| 180 | +- 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "all"' |
| 181 | ++ 'Requires-Dist: amd-torch-device-gfx1100 == 2.10.0+rocm7.1; extra == "device-all"' |
| 182 | + in metadata |
| 183 | + ) |
| 184 | + |
| 185 | +@@ -450,11 +450,11 @@ class TestRewriteHostMetadata: |
| 186 | + |
| 187 | + metadata = (staging / identity.dist_info_name / "METADATA").read_text() |
| 188 | + # No target-level extras since gfx11 is a family |
| 189 | +- assert "Provides-Extra: gfx11" not in metadata |
| 190 | +- # But "all" should still include it |
| 191 | +- assert "Provides-Extra: all" in metadata |
| 192 | ++ assert "Provides-Extra: device-gfx11" not in metadata |
| 193 | ++ # But "device-all" should still include it |
| 194 | ++ assert "Provides-Extra: device-all" in metadata |
| 195 | + assert ( |
| 196 | +- 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "all"' |
| 197 | ++ 'Requires-Dist: amd-torch-device-gfx11 == 2.10.0+rocm7.1; extra == "device-all"' |
| 198 | + in metadata |
| 199 | + ) |
| 200 | + |
| 201 | +@@ -481,8 +481,8 @@ class TestRewriteHostMetadata: |
| 202 | + header_section, body_section = header_body |
| 203 | + # All injected headers must be in the header section, not the body |
| 204 | + assert "Requires-Dist: rocm-bootstrap" in header_section |
| 205 | +- assert "Provides-Extra: gfx942" in header_section |
| 206 | +- assert "Provides-Extra: all" in header_section |
| 207 | ++ assert "Provides-Extra: device-gfx942" in header_section |
| 208 | ++ assert "Provides-Extra: device-all" in header_section |
| 209 | + # Body should still contain the description |
| 210 | + assert "description body" in body_section |
| 211 | + |
| 212 | +@@ -501,10 +501,10 @@ class TestVariantWheel: |
| 213 | + "Name: torch\n" |
| 214 | + "Version: 2.10.0+rocm7.1\n" |
| 215 | + "Requires-Dist: rocm-bootstrap\n" |
| 216 | +- "Provides-Extra: gfx942\n" |
| 217 | +- 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "gfx942"\n' |
| 218 | +- "Provides-Extra: all\n" |
| 219 | +- 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "all"\n' |
| 220 | ++ "Provides-Extra: device-gfx942\n" |
| 221 | ++ 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "device-gfx942"\n' |
| 222 | ++ "Provides-Extra: device-all\n" |
| 223 | ++ 'Requires-Dist: amd-torch-device-gfx942 == 2.10.0+rocm7.1; extra == "device-all"\n' |
| 224 | + "\n" |
| 225 | + "Description body.\n" |
| 226 | + ) |
| 227 | +@@ -542,9 +542,9 @@ class TestVariantWheel: |
| 228 | + |
| 229 | + metadata = (staging / identity.dist_info_name / "METADATA").read_text() |
| 230 | + # Should have the extras line AND the variant marker line |
| 231 | +- assert 'extra == "gfx942"' in metadata |
| 232 | ++ assert 'extra == "device-gfx942"' in metadata |
| 233 | + assert ('"amd :: gfx_arch :: gfx942" in variant_properties') in metadata |
| 234 | +- # "all" extra should NOT get variant markers |
| 235 | ++ # "device-all" extra should NOT get variant markers |
| 236 | + assert '"amd :: gfx_arch :: all"' not in metadata |
| 237 | + |
| 238 | + def test_variant_json(self, tmp_path: Path): |
| 239 | +@@ -582,7 +582,7 @@ class TestVariantWheel: |
| 240 | + # Check METADATA has variant markers |
| 241 | + metadata = (variant_path / identity.dist_info_name / "METADATA").read_text() |
| 242 | + assert "variant_properties" in metadata |
| 243 | +- assert 'extra == "gfx942"' in metadata |
| 244 | ++ assert 'extra == "device-gfx942"' in metadata |
| 245 | + |
| 246 | + # Check variant.json exists |
| 247 | + variant_json = variant_path / identity.dist_info_name / "variant.json" |
| 248 | +-- |
| 249 | +2.51.0 |
| 250 | + |
0 commit comments