-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathprovider_matrix.py
More file actions
105 lines (90 loc) · 3.42 KB
/
Copy pathprovider_matrix.py
File metadata and controls
105 lines (90 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Shared provider capability matrix for v2 tests."""
from __future__ import annotations

import importlib.util
import sys
from pathlib import Path

import pytest

from instructor import Provider
from instructor.v2.core.provider_specs import PROVIDER_SPECS
from instructor.v2.core.registry import mode_registry
# Providers that can be exercised by the v2 tests: their spec names both a
# handler module and a ``from_function`` entry.
TEST_PROVIDER_SPECS = dict(
    (prov, prov_spec)
    for prov, prov_spec in PROVIDER_SPECS.items()
    if prov_spec.handler_module is not None
    and prov_spec.from_function is not None
)

# Declared modes for every provider in PROVIDER_SPECS (not only the
# test-ready subset above).
PROVIDER_HANDLER_MODES = dict(
    (prov, prov_spec.supported_modes) for prov, prov_spec in PROVIDER_SPECS.items()
)


def _capability_cases(capability: str) -> tuple[tuple[Provider, object], ...]:
    """Pair each test-ready provider with every entry of one capability list."""
    return tuple(
        (prov, entry)
        for prov, prov_spec in TEST_PROVIDER_SPECS.items()
        for entry in getattr(prov_spec.capabilities, capability)
    )


def _providers_with(capability: str) -> tuple[Provider, ...]:
    """Return the test-ready providers whose named capability is truthy."""
    return tuple(
        prov
        for prov, prov_spec in TEST_PROVIDER_SPECS.items()
        if getattr(prov_spec.capabilities, capability)
    )


PARTIAL_STREAM_CASES = _capability_cases("partial_stream_modes")
ITERABLE_STREAM_CASES = _capability_cases("iterable_stream_modes")
TYPED_MULTIMODAL_PROVIDERS = _providers_with("multimodal_inputs")
TYPED_MULTIMODAL_CASES = _capability_cases("multimodal_inputs")
EXPLICIT_PARALLEL_PROVIDERS = _providers_with("explicit_parallel_tools")

# Two directories above this file; presumably this module lives at
# <root>/tests/v2/ — revisit if the file moves.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Providers whose handlers have been loaded (or found already registered).
_HANDLERS_LOADED: set[Provider] = set()
def handler_module_path(provider: Provider) -> Path | None:
    """Map a provider's dotted ``handler_module`` to a ``.py`` path under the project root.

    Returns ``None`` when the provider's spec declares no handler module. The
    returned path is not checked for existence.
    """
    dotted = PROVIDER_SPECS[provider].handler_module
    if dotted is not None:
        relative = dotted.replace(".", "/") + ".py"
        return _PROJECT_ROOT / relative
    return None
def _is_expected_missing_dependency(provider: Provider, exc: ImportError) -> bool:
    """Tell whether *exc* was caused by the provider's own SDK not being importable.

    Providers without an ``sdk_module`` never match. Otherwise the top-level
    package of the failing import (``exc.name``) is compared with the top-level
    package of ``sdk_module``. When ``exc.name`` is unset, the check falls back
    to looking for ``No module named '<root>'`` in the exception message.
    """
    sdk_module = PROVIDER_SPECS[provider].sdk_module
    if sdk_module is None:
        return False
    root = sdk_module.partition(".")[0]
    missing = getattr(exc, "name", None)
    if not missing:
        return f"No module named '{root}'" in str(exc)
    return missing.partition(".")[0] == root
def ensure_handlers_loaded(
    provider: Provider, *, skip_missing_dependency: bool = False
) -> None:
    """Load a provider's handler module once so registration tests share setup.

    The handler file comes from ``handler_module_path`` and is executed under the
    synthetic module name ``tests.v2.handlers_<provider.value>``. Nothing is
    executed when every mode listed in ``PROVIDER_HANDLER_MODES`` for the
    provider is already registered. The call is a silent no-op (and the provider
    is not marked loaded) when there is no handler module or its file is absent.

    Args:
        provider: Provider whose handlers should be registered.
        skip_missing_dependency: When true, an ``ImportError`` caused by the
            provider's own SDK being absent skips the calling test via
            ``pytest.skip`` instead of failing it.

    Raises:
        ImportError: If no loader can be built for the handler file.
        Exception: Any other error raised while executing the handler module
            propagates unchanged.
    """
    if provider in _HANDLERS_LOADED:
        return

    # Fast path: the handlers may already be registered (presumably by a
    # regular import elsewhere), so there is no need to execute the file again.
    provider_modes = PROVIDER_HANDLER_MODES.get(provider, ())
    if provider_modes and all(
        mode_registry.is_registered(provider, mode) for mode in provider_modes
    ):
        _HANDLERS_LOADED.add(provider)
        return

    path = handler_module_path(provider)
    if path is None or not path.exists():
        return

    module_name = f"tests.v2.handlers_{provider.value}"
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load handler module for {provider}")
    module = importlib.util.module_from_spec(spec)

    # Register before executing, as the importlib "import a source file
    # directly" recipe does. Code run at class-creation time looks the defining
    # module up via sys.modules[cls.__module__]. Examples: dataclass string
    # annotations, pydantic forward refs, pickling.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException as exc:
        # Never leave a half-initialised module behind in sys.modules.
        sys.modules.pop(module_name, None)
        if (
            skip_missing_dependency
            and isinstance(exc, ImportError)  # also covers ModuleNotFoundError
            and _is_expected_missing_dependency(provider, exc)
        ):
            pytest.skip(
                f"{provider.value} handlers require optional dependency "  # ty: ignore[too-many-positional-arguments]
                f"{PROVIDER_SPECS[provider].sdk_module}"
            )
        raise
    _HANDLERS_LOADED.add(provider)