|
14 | 14 | from rez.utils.logging_ import print_debug, print_warning |
15 | 15 | from rez.exceptions import RezPluginError |
16 | 16 | from zipimport import zipimporter |
17 | | -from typing import overload, Any, TypeVar |
| 17 | +from typing import overload, Any, TypeVar, TYPE_CHECKING |
18 | 18 | import pkgutil |
19 | 19 | import os.path |
20 | 20 | import sys |
21 | 21 | import types |
22 | 22 |
|
| 23 | +if TYPE_CHECKING: |
| 24 | + from typing import Literal # not available in typing module until 3.8 |
| 25 | + from rez.shells import Shell |
| 26 | + from rez.release_vcs import ReleaseVCS |
| 27 | + from rez.release_hook import ReleaseHook |
| 28 | + from rez.build_process import BuildProcess |
| 29 | + from rez.build_system import BuildSystem |
| 30 | + from rez.package_repository import PackageRepository |
| 31 | + from rez.command import Command |
| 32 | + |
23 | 33 | T = TypeVar("T") |
24 | 34 |
|
25 | 35 |
|
@@ -362,18 +372,41 @@ def get_plugins(self, plugin_type: str) -> list[str]: |
362 | 372 | return list(self._get_plugin_type(plugin_type).plugin_classes.keys()) |
363 | 373 |
|
364 | 374 | @overload |
365 | | - def get_plugin_class(self, plugin_type: str, plugin_name: str) -> type: |
| 375 | + def get_plugin_class(self, plugin_type: Literal["shell"], plugin_name: str) -> type[Shell]: |
| 376 | + pass |
| 377 | + |
| 378 | + @overload |
| 379 | + def get_plugin_class(self, plugin_type: Literal["release_vcs"], plugin_name: str) -> type[ReleaseVCS]: |
| 380 | + pass |
| 381 | + |
| 382 | + @overload |
| 383 | + def get_plugin_class(self, plugin_type: Literal["release_hook"], plugin_name: str) -> type[ReleaseHook]: |
| 384 | + pass |
| 385 | + |
| 386 | + @overload |
| 387 | + def get_plugin_class(self, plugin_type: Literal["package_repository"], plugin_name: str) -> type[PackageRepository]: |
366 | 388 | pass |
367 | 389 |
|
368 | 390 | @overload |
369 | | - def get_plugin_class(self, plugin_type: str, plugin_name: str, expected_type: type[T]) -> type[T]: |
| 391 | + def get_plugin_class(self, plugin_type: Literal["build_system"], plugin_name: str) -> type[BuildSystem]: |
370 | 392 | pass |
371 | 393 |
|
372 | | - def get_plugin_class(self, plugin_type: str, plugin_name: str, expected_type: type | None = None) -> type: |
| 394 | + @overload |
| 395 | + def get_plugin_class(self, plugin_type: Literal["package_repository"], plugin_name: str) -> type[PackageRepository]: |
| 396 | + pass |
| 397 | + |
| 398 | + @overload |
| 399 | + def get_plugin_class(self, plugin_type: Literal["build_process"], plugin_name: str) -> type[BuildProcess]: |
| 400 | + pass |
| 401 | + |
| 402 | + @overload |
| 403 | + def get_plugin_class(self, plugin_type: Literal["command"], plugin_name: str) -> type[Command]: |
| 404 | + pass |
| 405 | + |
| 406 | + def get_plugin_class(self, plugin_type: str, plugin_name: str) -> type: |
373 | 407 | """Return the class registered under the given plugin name.""" |
374 | 408 | plugin = self._get_plugin_type(plugin_type) |
375 | | - cls = plugin.get_plugin_class(plugin_name) |
376 | | - return cls |
| 409 | + return plugin.get_plugin_class(plugin_name) |
377 | 410 |
|
378 | 411 | def get_plugin_module(self, plugin_type: str, plugin_name: str) -> types.ModuleType: |
379 | 412 | """Return the module defining the class registered under the given |
|
0 commit comments