diff --git a/light_the_torch/_cli.py b/light_the_torch/_cli.py index 2515ec7..c3b5227 100644 --- a/light_the_torch/_cli.py +++ b/light_the_torch/_cli.py @@ -1,5 +1,5 @@ -from pip._internal.cli.main import main as pip_main +from pip._internal.cli.main import main -from ._patch import patch +from ._patch import patch_pip_main -main = patch(pip_main) +main = patch_pip_main(main) diff --git a/light_the_torch/_compat.py b/light_the_torch/_compat.py new file mode 100644 index 0000000..7aeb237 --- /dev/null +++ b/light_the_torch/_compat.py @@ -0,0 +1,8 @@ +import sys + +if sys.version_info >= (3, 8): + import importlib.metadata as importlib_metadata +else: + import importlib_metadata + +__all__ = ["importlib_metadata"] diff --git a/light_the_torch/_patch.py b/light_the_torch/_patch.py deleted file mode 100644 index a28b3d8..0000000 --- a/light_the_torch/_patch.py +++ /dev/null @@ -1,375 +0,0 @@ -import contextlib -import dataclasses -import enum -import functools -import itertools -import optparse -import os -import re -import sys -import unittest.mock -from typing import List, Set -from unittest import mock - -import pip._internal.cli.cmdoptions - -from pip._internal.index.collector import CollectedSources -from pip._internal.index.package_finder import CandidateEvaluator -from pip._internal.index.sources import build_source -from pip._internal.models.search_scope import SearchScope - -import light_the_torch as ltt - -from . import _cb as cb - -from ._utils import apply_fn_patch - - -class Channel(enum.Enum): - STABLE = enum.auto() - TEST = enum.auto() - NIGHTLY = enum.auto() - LTS = enum.auto() - - @classmethod - def from_str(cls, string): - return cls[string.upper()] - - -PYTORCH_DISTRIBUTIONS = { - "torch", - "torch_model_archiver", - "torch_tb_profiler", - "torcharrow", - "torchaudio", - "torchcsprng", - "torchdata", - "torchdistx", - "torchserve", - "torchtext", - "torchvision", -} - - -def patch(pip_main): - @functools.wraps(pip_main) - def wrapper(argv=None): - if argv is None: - argv = sys.argv[1:] - - with apply_patches(argv): - return pip_main(argv) - - return wrapper - - -# adapted from https://stackoverflow.com/a/9307174 -class PassThroughOptionParser(optparse.OptionParser): - def __init__(self): - super().__init__(add_help_option=False) - - def _process_args(self, largs, rargs, values): - while rargs: - try: - super()._process_args(largs, rargs, values) - except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error: - largs.append(error.opt_str) - - -@dataclasses.dataclass -class LttOptions: - computation_backends: Set[cb.ComputationBackend] = dataclasses.field( - default_factory=lambda: {cb.CPUBackend()} - ) - channel: Channel = Channel.STABLE - - @staticmethod - def computation_backend_parser_options(): - return [ - optparse.Option( - "--pytorch-computation-backend", - help=( - "Computation backend for compiled PyTorch distributions, " - "e.g. 'cu102', 'cu115', or 'cpu'. " - "Multiple computation backends can be passed as a comma-separated " - "list, e.g 'cu102,cu113,cu116'. " - "If not specified, the computation backend is detected from the " - "available hardware, preferring CUDA over CPU." - ), - ), - optparse.Option( - "--cpuonly", - action="store_true", - help=( - "Shortcut for '--pytorch-computation-backend=cpu'. " - "If '--computation-backend' is used simultaneously, " - "it takes precedence over '--cpuonly'." - ), - ), - ] - - @staticmethod - def channel_parser_option() -> optparse.Option: - return optparse.Option( - "--pytorch-channel", - help=( - "Channel to download PyTorch distributions from, e.g. 'stable' , " - "'test', 'nightly' and 'lts'. " - "If not specified, defaults to 'stable' unless '--pre' is given in " - "which case it defaults to 'test'." - ), - ) - - @staticmethod - def _parse(argv): - parser = PassThroughOptionParser() - - for option in LttOptions.computation_backend_parser_options(): - parser.add_option(option) - parser.add_option(LttOptions.channel_parser_option()) - parser.add_option("--pre", dest="pre", action="store_true") - - opts, _ = parser.parse_args(argv) - return opts - - @classmethod - def from_pip_argv(cls, argv: List[str]): - if not argv or argv[0] != "install": - return cls() - - opts = cls._parse(argv) - - if opts.pytorch_computation_backend is not None: - cbs = { - cb.ComputationBackend.from_str(string.strip()) - for string in opts.pytorch_computation_backend.split(",") - } - elif opts.cpuonly: - cbs = {cb.CPUBackend()} - elif "LTT_PYTORCH_COMPUTATION_BACKEND" in os.environ: - cbs = { - cb.ComputationBackend.from_str(string.strip()) - for string in os.environ["LTT_PYTORCH_COMPUTATION_BACKEND"].split(",") - } - else: - cbs = cb.detect_compatible_computation_backends() - - if opts.pytorch_channel is not None: - channel = Channel.from_str(opts.pytorch_channel) - elif opts.pre: - channel = Channel.TEST - else: - channel = Channel.STABLE - - return cls(cbs, channel) - - -@contextlib.contextmanager -def apply_patches(argv): - options = LttOptions.from_pip_argv(argv) - - patches = [ - patch_cli_version(), - patch_cli_options(), - patch_link_collection(options.computation_backends, options.channel), - patch_candidate_selection(options.computation_backends), - ] - - with contextlib.ExitStack() as stack: - for patch in patches: - stack.enter_context(patch) - - yield stack - - -@contextlib.contextmanager -def patch_cli_version(): - with apply_fn_patch( - "pip", - "_internal", - "cli", - "main_parser", - "get_pip_version", - postprocessing=lambda input, output: f"ltt {ltt.__version__} from {ltt.__path__[0]}\n{output}", - ): - yield - - -@contextlib.contextmanager -def patch_cli_options(): - def postprocessing(input, output): - for option in LttOptions.computation_backend_parser_options(): - input.cmd_opts.add_option(option) - - index_group = pip._internal.cli.cmdoptions.index_group - - with apply_fn_patch( - "pip", - "_internal", - "cli", - "cmdoptions", - "add_target_python_options", - postprocessing=postprocessing, - ): - with unittest.mock.patch.dict(index_group): - options = index_group["options"].copy() - options.append(LttOptions.channel_parser_option) - index_group["options"] = options - yield - - -def get_extra_index_urls(computation_backends, channel): - if channel == Channel.STABLE: - channel_paths = [""] - elif channel == Channel.LTS: - channel_paths = [ - f"lts/{major}.{minor}/" - for major, minor in [ - (1, 8), - ] - ] - else: - channel_paths = [f"{channel.name.lower()}/"] - return [ - f"https://download.pytorch.org/whl/{channel_path}{backend}" - for channel_path, backend in itertools.product( - channel_paths, sorted(computation_backends) - ) - ] - - -@contextlib.contextmanager -def patch_link_collection(computation_backends, channel): - search_scope = SearchScope( - find_links=[], - index_urls=get_extra_index_urls(computation_backends, channel), - no_index=False, - ) - - @contextlib.contextmanager - def context(input): - if input.project_name not in PYTORCH_DISTRIBUTIONS: - yield - return - - with mock.patch.object(input.self, "search_scope", search_scope): - yield - - def postprocessing(input, output): - if input.project_name not in PYTORCH_DISTRIBUTIONS: - return output - - if channel != Channel.STABLE: - return output - - # Some stable binaries are not hosted on the PyTorch indices. We check if this - # is the case for the current distribution. - for remote_file_source in output.index_urls: - candidates = list(remote_file_source.page_candidates()) - - # Cache the candidates, so `pip` doesn't has to retrieve them again later. - remote_file_source.page_candidates = lambda: iter(candidates) - - # If there are any candidates on the PyTorch indices, we continue normally. - if candidates: - return output - - # In case the distribution is not present on the PyTorch indices, we fall back - # to PyPI. - _, pypi_file_source = build_source( - SearchScope( - find_links=[], - index_urls=["https://pypi.org/simple"], - no_index=False, - ).get_index_urls_locations(input.project_name)[0], - candidates_from_page=input.candidates_from_page, - page_validator=input.self.session.is_secure_origin, - expand_dir=False, - cache_link_parsing=False, - ) - - return CollectedSources(find_links=[], index_urls=[pypi_file_source]) - - with apply_fn_patch( - "pip", - "_internal", - "index", - "collector", - "LinkCollector", - "collect_sources", - context=context, - postprocessing=postprocessing, - ): - yield - - -@contextlib.contextmanager -def patch_candidate_selection(computation_backends): - computation_backend_pattern = re.compile( - r"/(?P(cpu|cu\d+|rocm([\d.]+)))/" - ) - - def extract_local_specifier(candidate): - local = candidate.version.local - - if local is None: - match = computation_backend_pattern.search(candidate.link.path) - local = match["computation_backend"] if match else "any" - - # Early PyTorch distributions used the "any" local specifier to indicate a - # pure Python binary. This was changed to no local specifier later. - # Setting this to "cpu" is technically not correct as it will exclude this - # binary if a non-CPU backend is requested. Still, this is probably the - # right thing to do, since the user requested a specific backend and - # although this binary will work with it, it was not compiled against it. - if local == "any": - local = "cpu" - - return local - - def preprocessing(input): - if not input.candidates: - return - - candidates = iter(input.candidates) - candidate = next(candidates) - - if candidate.name not in PYTORCH_DISTRIBUTIONS: - # At this stage all candidates have the same name. Thus, if the first is - # not a PyTorch distribution, we don't need to check the rest and can - # return without changes. - return - - input.candidates = [ - candidate - for candidate in itertools.chain([candidate], candidates) - if extract_local_specifier(candidate) in computation_backends - ] - - vanilla_sort_key = CandidateEvaluator._sort_key - - def patched_sort_key(candidate_evaluator, candidate): - # At this stage all candidates have the same name. Thus, we don't need to - # mirror the exact key structure that the vanilla sort keys have. - return ( - vanilla_sort_key(candidate_evaluator, candidate) - if candidate.name not in PYTORCH_DISTRIBUTIONS - else ( - cb.ComputationBackend.from_str(extract_local_specifier(candidate)), - candidate.version.base_version, - ) - ) - - with apply_fn_patch( - "pip", - "_internal", - "index", - "package_finder", - "CandidateEvaluator", - "get_applicable_candidates", - preprocessing=preprocessing, - ): - with unittest.mock.patch.object( - CandidateEvaluator, "_sort_key", new=patched_sort_key - ): - yield diff --git a/light_the_torch/_patch/__init__.py b/light_the_torch/_patch/__init__.py new file mode 100644 index 0000000..7909415 --- /dev/null +++ b/light_the_torch/_patch/__init__.py @@ -0,0 +1 @@ +from .patch import patch_pip_main diff --git a/light_the_torch/_patch/cli.py b/light_the_torch/_patch/cli.py new file mode 100644 index 0000000..fe86294 --- /dev/null +++ b/light_the_torch/_patch/cli.py @@ -0,0 +1,118 @@ +import dataclasses +import enum +import optparse +import os +from typing import List, Set + +import light_the_torch._cb as cb + + +class Channel(enum.Enum): + STABLE = enum.auto() + TEST = enum.auto() + NIGHTLY = enum.auto() + + @classmethod + def from_str(cls, string): + return cls[string.upper()] + + +# adapted from https://stackoverflow.com/a/9307174 +class PassThroughOptionParser(optparse.OptionParser): + def __init__(self): + super().__init__(add_help_option=False) + + def _process_args(self, largs, rargs, values): + while rargs: + try: + super()._process_args(largs, rargs, values) + except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error: + largs.append(error.opt_str) + + +@dataclasses.dataclass +class LttOptions: + computation_backends: Set[cb.ComputationBackend] = dataclasses.field( + default_factory=lambda: {cb.CPUBackend()} + ) + channel: Channel = Channel.STABLE + + @staticmethod + def computation_backend_parser_options(): + return [ + optparse.Option( + "--pytorch-computation-backend", + help=( + "Computation backend for compiled PyTorch distributions, " + "e.g. 'cu102', 'cu115', or 'cpu'. " + "Multiple computation backends can be passed as a comma-separated " + "list, e.g 'cu102,cu113,cu116'. " + "If not specified, the computation backend is detected from the " + "available hardware, preferring CUDA over CPU." + ), + ), + optparse.Option( + "--cpuonly", + action="store_true", + help=( + "Shortcut for '--pytorch-computation-backend=cpu'. " + "If '--computation-backend' is used simultaneously, " + "it takes precedence over '--cpuonly'." + ), + ), + ] + + @staticmethod + def channel_parser_option() -> optparse.Option: + return optparse.Option( + "--pytorch-channel", + help=( + "Channel to download PyTorch distributions from, e.g. 'stable' , " + "'test', 'nightly' and 'lts'. " + "If not specified, defaults to 'stable' unless '--pre' is given in " + "which case it defaults to 'test'." + ), + ) + + @staticmethod + def _parse(argv): + parser = PassThroughOptionParser() + + for option in LttOptions.computation_backend_parser_options(): + parser.add_option(option) + parser.add_option(LttOptions.channel_parser_option()) + parser.add_option("--pre", dest="pre", action="store_true") + + opts, _ = parser.parse_args(argv) + return opts + + @classmethod + def from_pip_argv(cls, argv: List[str]): + if not argv or argv[0] != "install": + return cls() + + opts = cls._parse(argv) + + if opts.pytorch_computation_backend is not None: + cbs = { + cb.ComputationBackend.from_str(string.strip()) + for string in opts.pytorch_computation_backend.split(",") + } + elif opts.cpuonly: + cbs = {cb.CPUBackend()} + elif "LTT_PYTORCH_COMPUTATION_BACKEND" in os.environ: + cbs = { + cb.ComputationBackend.from_str(string.strip()) + for string in os.environ["LTT_PYTORCH_COMPUTATION_BACKEND"].split(",") + } + else: + cbs = cb.detect_compatible_computation_backends() + + if opts.pytorch_channel is not None: + channel = Channel.from_str(opts.pytorch_channel) + elif opts.pre: + channel = Channel.TEST + else: + channel = Channel.STABLE + + return cls(cbs, channel) diff --git a/light_the_torch/_patch/packages.py b/light_the_torch/_patch/packages.py new file mode 100644 index 0000000..44c11e3 --- /dev/null +++ b/light_the_torch/_patch/packages.py @@ -0,0 +1,120 @@ +import abc +import dataclasses +import itertools +import re + +from pip._internal.models.search_scope import SearchScope + +import light_the_torch._cb as cb + +from .cli import Channel + +__all__ = ["packages"] + + +@dataclasses.dataclass +class _Package(abc.ABC): + name: str + + @abc.abstractmethod + def make_search_scope(self, options): + pass + + @abc.abstractmethod + def filter_candidates(self, candidates, options): + pass + + @abc.abstractmethod + def make_sort_key(self, candidate, options): + pass + + +packages = {} + + +class _PyTorchDistribution(_Package): + def _get_extra_index_urls(self, computation_backends, channel): + if channel == Channel.STABLE: + channel_paths = [""] + else: + channel_paths = [f"{channel.name.lower()}/"] + return [ + f"https://download.pytorch.org/whl/{channel_path}{backend}" + for channel_path, backend in itertools.product( + channel_paths, sorted(computation_backends) + ) + ] + + def make_search_scope(self, options): + return SearchScope( + find_links=[], + index_urls=self._get_extra_index_urls( + options.computation_backends, options.channel + ), + no_index=False, + ) + + _COMPUTATION_BACKEND_PATTERN = re.compile( + r"/(?P(cpu|cu\d+|rocm([\d.]+)))/" + ) + + def _extract_local_specifier(self, candidate): + local = candidate.version.local + + if local is None: + match = self._COMPUTATION_BACKEND_PATTERN.search(candidate.link.path) + local = match["computation_backend"] if match else "any" + + # Early PyTorch distributions used the "any" local specifier to indicate a + # pure Python binary. This was changed to no local specifier later. + # Setting this to "cpu" is technically not correct as it will exclude this + # binary if a non-CPU backend is requested. Still, this is probably the + # right thing to do, since the user requested a specific backend and + # although this binary will work with it, it was not compiled against it. + if local == "any": + local = "cpu" + + return local + + def filter_candidates(self, candidates, options): + return [ + candidate + for candidate in candidates + if self._extract_local_specifier(candidate) in options.computation_backends + ] + + def make_sort_key(self, candidate, options): + return ( + cb.ComputationBackend.from_str(self._extract_local_specifier(candidate)), + candidate.version.base_version, + ) + + +# FIXME: check whether all of these are hosted on all channels +# If not, change `_TorchData` below to a more general class +# FIXME: check if they are valid at all +for name in { + "torch", + "torch_model_archiver", + "torch_tb_profiler", + "torcharrow", + "torchaudio", + "torchcsprng", + "torchdistx", + "torchserve", + "torchtext", + "torchvision", +}: + packages[name] = _PyTorchDistribution(name) + + +class _TorchData(_PyTorchDistribution): + def make_search_scope(self, options): + if options.channel == Channel.STABLE: + return SearchScope( + find_links=[], + index_urls=["https://pypi.org/simple"], + no_index=False, + ) + + return super().make_search_scope(options) diff --git a/light_the_torch/_patch/patch.py b/light_the_torch/_patch/patch.py new file mode 100644 index 0000000..20ba1c3 --- /dev/null +++ b/light_the_torch/_patch/patch.py @@ -0,0 +1,151 @@ +import contextlib +import functools +import sys +import unittest.mock +from unittest import mock + +import pip._internal.cli.cmdoptions +from pip._internal.index.package_finder import CandidateEvaluator + +import light_the_torch as ltt +from .cli import LttOptions +from .packages import packages +from .utils import apply_fn_patch + + +def patch_pip_main(pip_main): + @functools.wraps(pip_main) + def wrapper(argv=None): + if argv is None: + argv = sys.argv[1:] + + with apply_patches(argv): + return pip_main(argv) + + return wrapper + + +@contextlib.contextmanager +def apply_patches(argv): + options = LttOptions.from_pip_argv(argv) + + patches = [ + patch_cli_version(), + patch_cli_options(), + patch_link_collection(packages, options), + patch_candidate_selection(packages, options), + ] + + with contextlib.ExitStack() as stack: + for patch in patches: + stack.enter_context(patch) + + yield stack + + +@contextlib.contextmanager +def patch_cli_version(): + with apply_fn_patch( + "pip", + "_internal", + "cli", + "main_parser", + "get_pip_version", + postprocessing=lambda input, output: f"ltt {ltt.__version__} from {ltt.__path__[0]}\n{output}", + ): + yield + + +@contextlib.contextmanager +def patch_cli_options(): + def postprocessing(input, output): + for option in LttOptions.computation_backend_parser_options(): + input.cmd_opts.add_option(option) + + index_group = pip._internal.cli.cmdoptions.index_group + + with apply_fn_patch( + "pip", + "_internal", + "cli", + "cmdoptions", + "add_target_python_options", + postprocessing=postprocessing, + ): + with unittest.mock.patch.dict(index_group): + options = index_group["options"].copy() + options.append(LttOptions.channel_parser_option) + index_group["options"] = options + yield + + +@contextlib.contextmanager +def patch_link_collection(packages, options): + @contextlib.contextmanager + def context(input): + package = packages.get(input.project_name) + if not package: + yield + return + + with mock.patch.object( + input.self, "search_scope", package.make_search_scope(options) + ): + yield + + with apply_fn_patch( + "pip", + "_internal", + "index", + "collector", + "LinkCollector", + "collect_sources", + context=context, + ): + yield + + +@contextlib.contextmanager +def patch_candidate_selection(packages, options): + def preprocessing(input): + if not input.candidates: + return + + # At this stage all candidates have the same name. Thus, if the first is + # not a PyTorch distribution, we don't need to check the rest and can + # return without changes. + package = packages.get(input.candidates[0].name) + if not package: + return + + input.candidates = list(package.filter_candidates(input.candidates, options)) + + def patched_sort_key(candidate_evaluator, candidate): + package = packages.get(candidate.name) + assert package + return package.make_sort_key(candidate, options) + + @contextlib.contextmanager + def context(input): + # At this stage all candidates have the same name. Thus, we don't need to + # mirror the exact key structure that the vanilla sort keys have. + if not input.candidates or input.candidates[0].name not in packages: + yield + return + + with unittest.mock.patch.object( + CandidateEvaluator, "_sort_key", new=patched_sort_key + ): + yield + + with apply_fn_patch( + "pip", + "_internal", + "index", + "package_finder", + "CandidateEvaluator", + "get_applicable_candidates", + preprocessing=preprocessing, + context=context, + ): + yield diff --git a/light_the_torch/_utils.py b/light_the_torch/_patch/utils.py similarity index 57% rename from light_the_torch/_utils.py rename to light_the_torch/_patch/utils.py index 19fa064..8cc3b3b 100644 --- a/light_the_torch/_utils.py +++ b/light_the_torch/_patch/utils.py @@ -6,17 +6,31 @@ from unittest import mock +from pip._vendor.packaging.requirements import Requirement -class InternalError(RuntimeError): - def __init__(self) -> None: - # TODO: check against pip version - # TODO: fix wording - msg = ( - "Unexpected internal pytorch-pip-shim error. If you ever encounter this " - "message during normal operation, please submit a bug report at " - "https://github.com/pmeier/pytorch-pip-shim/issues" +from light_the_torch._compat import importlib_metadata + + +class UnexpectedInternalError(Exception): + def __init__(self, msg) -> None: + actual_pip_version = Requirement(f"pip=={importlib_metadata.version('pip')}") + required_pip_version = next( + requirement + for requirement in ( + Requirement(requirement_string) + for requirement_string in importlib_metadata.requires("light_the_torch") + ) + if requirement.name == "pip" + ) + super().__init__( + f"{msg}\n\n" + f"This can happen when the actual pip version (`{actual_pip_version}`) " + f"and the one required by light-the-torch (`{required_pip_version}`) " + f"are out of sync.\n" + f"If that is the case, please reinstall light-the-torch. " + f"Otherwise, please submit a bug report at " + f"https://github.com/pmeier/light-the-torch/issues" ) - super().__init__(msg) class Input(dict): @@ -77,7 +91,7 @@ def apply_fn_patch( postprocessing=lambda input, output: output, ): target = ".".join(parts) - fn = import_fn(target) + fn = import_obj(target) @functools.wraps(fn) def new(*args, **kwargs): @@ -93,7 +107,7 @@ def new(*args, **kwargs): yield -def import_fn(target: str): +def import_obj(target: str): attrs = [] name = target while name: @@ -101,13 +115,25 @@ def import_fn(target: str): module = importlib.import_module(name) break except ImportError: - name, attr = name.rsplit(".", 1) - attrs.append(attr) + try: + name, attr = name.rsplit(".", 1) + except ValueError: + attr = name + name = "" + attrs.insert(0, attr) else: - raise InternalError + raise UnexpectedInternalError( + f"Tried to import `{target}`, " + f"but the top-level namespace `{attrs[0]}` doesn't seem to be a module." + ) obj = module - for attr in attrs[::-1]: - obj = getattr(obj, attr) + for attr in attrs: + try: + obj = getattr(obj, attr) + except AttributeError: + raise UnexpectedInternalError( + f"Failed to access `{attr}` from `{obj.__name__}`" + ) from None return obj diff --git a/tests/test_cli.py b/tests/test_cli.py index 766d05b..ef53b15 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -77,7 +77,7 @@ def check_fn(text): @pytest.fixture def set_argv(mocker): def patch(*options): - return mocker.patch.object(sys, "argv", ["ltt", *options]) + return mocker.patch_pip_main.object(sys, "argv", ["ltt", *options]) return patch diff --git a/tests/test_computation_backend.py b/tests/test_computation_backend.py index b26bca3..95b050b 100644 --- a/tests/test_computation_backend.py +++ b/tests/test_computation_backend.py @@ -152,7 +152,7 @@ def test_cuda_vs_rocm(self): @pytest.fixture def patch_nvidia_driver_version(mocker): def factory(version): - return mocker.patch( + return mocker.patch_pip_main( "light_the_torch._cb.subprocess.run", return_value=SimpleNamespace(stdout=f"driver_version\n{version}"), ) @@ -208,7 +208,7 @@ def cuda_backends_params(): class TestDetectCompatibleComputationBackends: def test_no_nvidia_driver(self, mocker): - mocker.patch( + mocker.patch_pip_main( "light_the_torch._cb.subprocess.run", side_effect=subprocess.CalledProcessError(1, ""), ) @@ -224,7 +224,9 @@ def test_cuda_backends( nvidia_driver_version, compatible_cuda_backends, ): - mocker.patch("light_the_torch._cb.platform.system", return_value=system) + mocker.patch_pip_main( + "light_the_torch._cb.platform.system", return_value=system + ) patch_nvidia_driver_version(nvidia_driver_version) backends = cb.detect_compatible_computation_backends() diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 3a69e0a..d9444a3 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -103,12 +103,12 @@ def patched_import(name, globals, locals, fromlist, level): return __import__(name, globals, locals, fromlist, level) - mocker.patch.object(builtins, "__import__", new=patched_import) + mocker.patch_pip_main.object(builtins, "__import__", new=patched_import) values = { name: module for name, module in sys.modules.items() if retain_condition(name) } - mocker.patch.dict(sys.modules, clear=True, values=values) + mocker.patch_pip_main.dict(sys.modules, clear=True, values=values) def test_version_not_installed(mocker):