diff --git a/pyproject.toml b/pyproject.toml index dce6a0870e1..0b0164862a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "packaging", "pluggy>=1.5,<2", "tomli>=1; python_version<'3.11'", + "typing-extensions; python_version<'3.10'", ] optional-dependencies.dev = [ "argcomplete", diff --git a/src/_pytest/fixtures.py b/src/_pytest/fixtures.py index 8b79dbcb932..454a1e52a6c 100644 --- a/src/_pytest/fixtures.py +++ b/src/_pytest/fixtures.py @@ -75,6 +75,13 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec + from typing_extensions import TypeAlias +else: + from typing import ParamSpec + from typing import TypeAlias + if TYPE_CHECKING: from _pytest.python import CallSpec2 @@ -84,14 +91,20 @@ # The value of the fixture -- return/yield of the fixture function (type variable). FixtureValue = TypeVar("FixtureValue") -# The type of the fixture function (type variable). -FixtureFunction = TypeVar("FixtureFunction", bound=Callable[..., object]) -# The type of a fixture function (type alias generic in fixture value). -_FixtureFunc = Union[ - Callable[..., FixtureValue], Callable[..., Generator[FixtureValue]] + +# The parameters that a fixture function receives. +FixtureParams = ParamSpec("FixtureParams") + +# A dict of fixture name -> its FixtureDef. +FixtureDefDict: TypeAlias = dict[str, "FixtureDef[Any]"] + +# The type of fixture function (type alias generic in fixture params and value). +_FixtureFunc: TypeAlias = Union[ + Callable[FixtureParams, FixtureValue], + Callable[FixtureParams, Generator[FixtureValue, None, None]], ] # The type of FixtureDef.cached_result (type alias generic in fixture value). -_FixtureCachedResult = Union[ +_FixtureCachedResult: TypeAlias = Union[ tuple[ # The result. FixtureValue, @@ -360,7 +373,7 @@ def __init__( pyfuncitem: Function, fixturename: str | None, arg2fixturedefs: dict[str, Sequence[FixtureDef[Any]]], - fixture_defs: dict[str, FixtureDef[Any]], + fixture_defs: FixtureDefDict, *, _ispytest: bool = False, ) -> None: @@ -886,7 +899,9 @@ def toterminal(self, tw: TerminalWriter) -> None: def call_fixture_func( - fixturefunc: _FixtureFunc[FixtureValue], request: FixtureRequest, kwargs + fixturefunc: _FixtureFunc[FixtureParams, FixtureValue], + request: FixtureRequest, + kwargs: FixtureParams.kwargs, ) -> FixtureValue: if inspect.isgeneratorfunction(fixturefunc): fixturefunc = cast(Callable[..., Generator[FixtureValue]], fixturefunc) @@ -957,7 +972,7 @@ def __init__( config: Config, baseid: str | None, argname: str, - func: _FixtureFunc[FixtureValue], + func: _FixtureFunc[Any, FixtureValue], scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] | None, params: Sequence[object] | None, ids: tuple[object | None, ...] | Callable[[Any], object | None] | None = None, @@ -1113,7 +1128,7 @@ def __repr__(self) -> str: def resolve_fixture_function( fixturedef: FixtureDef[FixtureValue], request: FixtureRequest -) -> _FixtureFunc[FixtureValue]: +) -> _FixtureFunc[Any, FixtureValue]: """Get the actual callable that can be called to obtain the fixture value.""" fixturefunc = fixturedef.func @@ -1192,7 +1207,9 @@ class FixtureFunctionMarker: def __post_init__(self, _ispytest: bool) -> None: check_ispytest(_ispytest) - def __call__(self, function: FixtureFunction) -> FixtureFunctionDefinition: + def __call__( + self, function: Callable[FixtureParams, FixtureValue] + ) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]: if inspect.isclass(function): raise ValueError("class fixtures not supported (maybe in the future)") @@ -1219,12 +1236,10 @@ def __call__(self, function: FixtureFunction) -> FixtureFunctionDefinition: return fixture_definition -# TODO: paramspec/return type annotation tracking and storing -class FixtureFunctionDefinition: +class FixtureFunctionDefinition(Generic[FixtureParams, FixtureValue]): def __init__( self, - *, - function: Callable[..., Any], + function: Callable[FixtureParams, FixtureValue], fixture_function_marker: FixtureFunctionMarker, instance: object | None = None, _ispytest: bool = False, @@ -1237,7 +1252,7 @@ def __init__( self._fixture_function_marker = fixture_function_marker if instance is not None: self._fixture_function = cast( - Callable[..., Any], function.__get__(instance) + Callable[FixtureParams, FixtureValue], function.__get__(instance) ) else: self._fixture_function = function @@ -1246,7 +1261,9 @@ def __init__( def __repr__(self) -> str: return f"" - def __get__(self, instance, owner=None): + def __get__( + self, instance: object, owner: type | None = None + ) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]: """Behave like a method if the function it was applied to was a method.""" return FixtureFunctionDefinition( function=self._fixture_function, @@ -1270,14 +1287,14 @@ def _get_wrapped_function(self) -> Callable[..., Any]: @overload def fixture( - fixture_function: Callable[..., object], + fixture_function: Callable[FixtureParams, FixtureValue], *, scope: _ScopeName | Callable[[str, Config], _ScopeName] = ..., params: Iterable[object] | None = ..., autouse: bool = ..., ids: Sequence[object | None] | Callable[[Any], object | None] | None = ..., name: str | None = ..., -) -> FixtureFunctionDefinition: ... +) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]: ... @overload @@ -1293,14 +1310,14 @@ def fixture( def fixture( - fixture_function: FixtureFunction | None = None, + fixture_function: Callable[FixtureParams, FixtureValue] | None = None, *, scope: _ScopeName | Callable[[str, Config], _ScopeName] = "function", params: Iterable[object] | None = None, autouse: bool = False, ids: Sequence[object | None] | Callable[[Any], object | None] | None = None, name: str | None = None, -) -> FixtureFunctionMarker | FixtureFunctionDefinition: +) -> FixtureFunctionMarker | FixtureFunctionDefinition[FixtureParams, FixtureValue]: """Decorator to mark a fixture factory function. This decorator can be used, with or without parameters, to define a @@ -1688,7 +1705,7 @@ def _register_fixture( self, *, name: str, - func: _FixtureFunc[object], + func: _FixtureFunc[Any, object], nodeid: str | None, scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] = "function", params: Sequence[object] | None = None, diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 85e3cb0ae71..f17a696799a 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -51,6 +51,7 @@ from _pytest.config.argparsing import Parser from _pytest.deprecated import check_ispytest from _pytest.fixtures import FixtureDef +from _pytest.fixtures import FixtureDefDict from _pytest.fixtures import FixtureRequest from _pytest.fixtures import FuncFixtureInfo from _pytest.fixtures import get_scope_node @@ -1085,7 +1086,7 @@ def get_direct_param_fixture_func(request: FixtureRequest) -> Any: # Used for storing pseudo fixturedefs for direct parametrization. -name2pseudofixturedef_key = StashKey[dict[str, FixtureDef[Any]]]() +name2pseudofixturedef_key = StashKey[FixtureDefDict]() @final @@ -1271,7 +1272,7 @@ def parametrize( if node is None: name2pseudofixturedef = None else: - default: dict[str, FixtureDef[Any]] = {} + default: FixtureDefDict = {} name2pseudofixturedef = node.stash.setdefault( name2pseudofixturedef_key, default )