Skip to content

Commit ff1358d

Browse files
committed
Add requires_package decorator
1 parent 4dc0343 commit ff1358d

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import importlib.metadata
5+
import inspect
6+
import types
7+
from typing import Callable, TypeVar
8+
9+
import packaging.specifiers
10+
import packaging.version
11+
from typing_extensions import ParamSpec
12+
13+
P = ParamSpec('P')
14+
R = TypeVar('R')
15+
16+
17+
def requires_package(
18+
package_name: str, requirements: str, method_name: str
19+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
20+
required_versions = packaging.specifiers.SpecifierSet(requirements)
21+
22+
def decorator(method: Callable[P, R]) -> Callable[P, R]:
23+
@functools.wraps(method)
24+
def function(*args: P.args, **kwargs: P.kwargs) -> R:
25+
error: str | None = None
26+
try:
27+
package_version = importlib.metadata.version(package_name)
28+
installed_version = packaging.version.parse(package_version)
29+
30+
if installed_version not in required_versions:
31+
error = f', but you have "{installed_version}" installed'
32+
except importlib.metadata.PackageNotFoundError:
33+
error = ", but you don't have any version installed"
34+
if error is not None:
35+
e = RuntimeError(
36+
f'"{method_name}" requires package "{package_name}{requirements}" to be installed{error}'
37+
)
38+
current_frame = inspect.currentframe()
39+
new_tb: types.TracebackType | None = None
40+
if current_frame and current_frame.f_back:
41+
caller_frame = current_frame.f_back
42+
new_tb = types.TracebackType(
43+
tb_next=None,
44+
tb_frame=caller_frame,
45+
tb_lasti=caller_frame.f_lasti,
46+
tb_lineno=caller_frame.f_lineno,
47+
)
48+
raise e.with_traceback(new_tb)
49+
50+
return method(*args, **kwargs)
51+
return function
52+
return decorator

tests/utils/test_packages.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
3+
import importlib.metadata
4+
5+
import pytest
6+
7+
from yandex_cloud_ml_sdk._utils.packages import requires_package
8+
9+
10+
@requires_package('somerandomname', '>1', 'mysupermethod')
11+
def mysupermethod():
12+
pass
13+
14+
15+
@requires_package('yandexcloud', '<0.100', 'mysupermethod')
16+
def mysupermethod2():
17+
pass
18+
19+
20+
@requires_package('yandexcloud', '>0.100', 'mysupermethod')
21+
def mysupermethod3():
22+
pass
23+
24+
25+
def test_requires_package():
26+
with pytest.raises(RuntimeError) as exc_info:
27+
mysupermethod()
28+
assert exc_info.value.args[0] == (
29+
'"mysupermethod" requires package "somerandomname>1"'
30+
' to be installed, but you don\'t have any version installed'
31+
)
32+
assert str(exc_info.traceback[0].path) == __file__
33+
assert str(exc_info.traceback[-1].path) == __file__
34+
35+
installed_version = importlib.metadata.version('yandexcloud')
36+
with pytest.raises(
37+
RuntimeError,
38+
match=(
39+
'"mysupermethod" requires package "yandexcloud<0.100" to be installed, '
40+
f'but you have "{installed_version}" installed'
41+
)
42+
):
43+
mysupermethod2()
44+
45+
mysupermethod3()

0 commit comments

Comments
 (0)