Skip to content

Commit 8048f77

Browse files
authored
patch macOS wheels for <1.0.0 to specifiy CPU backend (#35)
* patch macOS wheels for `<1.0.0` to specifiy CPU backend * add test
1 parent 73e8fc8 commit 8048f77

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

light_the_torch/_pip/find.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def from_candidate_preferences(
213213

214214

215215
class PytorchCandidateEvaluator(CandidateEvaluator):
216+
_MACOS_PLATFORM_PATTERN = re.compile(r"macosx_\d+_\d+_x86_64")
217+
216218
def __init__(
217219
self,
218220
*args: Any,
@@ -245,22 +247,38 @@ def from_candidate_evaluator(
245247
def _sort_key(
246248
self, candidate: InstallationCandidate
247249
) -> Tuple[cb.ComputationBackend, Version]:
248-
version = Version(
249-
f"{candidate.version.major}"
250-
f".{candidate.version.minor}"
251-
f".{candidate.version.micro}"
250+
return (
251+
cb.ComputationBackend.from_str(candidate.version.local),
252+
candidate.version.base_version,
252253
)
253-
computation_backend = cb.ComputationBackend.from_str(candidate.version.local)
254-
return computation_backend, version
255254

256255
def get_applicable_candidates(
257256
self, candidates: List[InstallationCandidate]
258257
) -> List[InstallationCandidate]:
259-
return [
258+
applicable_candidates = [
260259
candidate
261260
for candidate in super().get_applicable_candidates(candidates)
262261
if candidate.version.local in self.computation_backends
263262
]
263+
if self._is_macos:
264+
self._patch_mac_lt_1_0_0_local(applicable_candidates)
265+
return applicable_candidates
266+
267+
@property
268+
def _is_macos(self) -> bool:
269+
return any(
270+
self._MACOS_PLATFORM_PATTERN.match(tag.platform)
271+
for tag in self._supported_tags
272+
)
273+
274+
def _patch_mac_lt_1_0_0_local(
275+
self, candidates: List[InstallationCandidate]
276+
) -> None:
277+
for candidate in candidates:
278+
if candidate.version.major >= 1:
279+
continue
280+
281+
candidate.version = Version(str(candidate.version).replace("any", "cpu"))
264282

265283

266284
class PytorchLinkCollector(LinkCollector):

tests/unit/_pip/test_find.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22

33
import pytest
4+
from pip._internal.models.wheel import Wheel
45
from pip._vendor.packaging.version import Version
56

67
import light_the_torch as ltt
@@ -178,3 +179,19 @@ def test_find_links_channel_smoke(channel):
178179
assert ltt.find_links(
179180
["torch"], computation_backends={cb.CPUBackend()}, channel=channel
180181
)
182+
183+
184+
@pytest.mark.parametrize("python_version", PYTHON_VERSIONS)
185+
def test_patch_mac_local_specifier_lt_1_0_0(
186+
patch_extract_dists, patch_run, python_version
187+
):
188+
# See https://github.com/pmeier/light-the-torch/issues/34
189+
dists = ["torch"]
190+
patch_extract_dists(return_value=dists)
191+
192+
links = ltt.find_links(
193+
dists, python_version=python_version, platform="macosx_10_9_x86_64"
194+
)
195+
version = Version(Wheel(links[0]).version)
196+
197+
assert version >= Version("1.0.0")

0 commit comments

Comments
 (0)