Skip to content

Commit 89938ee

Browse files
update for all driver discoverability
1 parent 4407f84 commit 89938ee

File tree

6 files changed

+331
-115
lines changed

6 files changed

+331
-115
lines changed

griptape/common/_lazy_loader.py

Lines changed: 6 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import importlib
66
import importlib.util
77
import pkgutil
8-
from pathlib import Path
98
from typing import Optional
109

1110
# Driver-specific mapping: class suffix -> driver type directory
@@ -50,10 +49,12 @@ def find_class_module(module_base_path: str, class_name: str, file_suffix: str =
5049
"""
5150
try:
5251
base_module = importlib.import_module(module_base_path)
53-
base_dir = Path(base_module.__file__).parent
52+
if base_module.__file__ is None:
53+
return None
54+
base_dir = str(base_module.__file__).rsplit("/", 1)[0]
5455

5556
# Walk through all modules in this directory
56-
for module_info in pkgutil.walk_packages([str(base_dir)], prefix=f"{module_base_path}."):
57+
for module_info in pkgutil.walk_packages([base_dir], prefix=f"{module_base_path}."):
5758
# Skip if looking for specific suffix and module doesn't match
5859
if file_suffix and not module_info.name.endswith(file_suffix):
5960
continue
@@ -103,10 +104,10 @@ def find_driver_module(class_name: str) -> Optional[str]:
103104
driver_type_module = importlib.import_module(driver_type_path)
104105
if driver_type_module.__file__ is None:
105106
return None
106-
driver_type_dir = Path(driver_type_module.__file__).parent
107+
driver_type_dir = str(driver_type_module.__file__).rsplit("/", 1)[0]
107108

108109
# Walk through all modules in this driver type directory
109-
for module_info in pkgutil.walk_packages([str(driver_type_dir)], prefix=f"{driver_type_path}."):
110+
for module_info in pkgutil.walk_packages([driver_type_dir], prefix=f"{driver_type_path}."):
110111
try:
111112
spec = importlib.util.find_spec(module_info.name)
112113
if spec is not None:
@@ -120,88 +121,3 @@ def find_driver_module(class_name: str) -> Optional[str]:
120121
pass
121122

122123
return None
123-
124-
125-
def discover_all_classes(module_base_path: str, file_suffix: str = "") -> list[str]:
126-
"""Scan filesystem to discover all available classes (for __dir__).
127-
128-
Args:
129-
module_base_path: Base module path (e.g., "griptape.structures")
130-
file_suffix: File suffix to look for (e.g., ".py" matches all, "_task.py" matches tasks)
131-
132-
Returns:
133-
List of all class names found in the directory
134-
"""
135-
try:
136-
base_module = importlib.import_module(module_base_path)
137-
if base_module.__file__ is None:
138-
return []
139-
base_dir = Path(base_module.__file__).parent
140-
except Exception:
141-
return []
142-
143-
class_names = []
144-
145-
for file in base_dir.glob("*.py"):
146-
if file.name.startswith("_") or file.name.startswith("base_"):
147-
continue
148-
if file_suffix and not file.stem.endswith(file_suffix.replace(".py", "")):
149-
continue
150-
151-
# Convert filename to class name (snake_case -> PascalCase)
152-
class_name = _snake_to_pascal_case(file.stem)
153-
class_names.append(class_name)
154-
155-
# Also check subdirectories for tools
156-
for subdir in base_dir.iterdir():
157-
if subdir.is_dir() and not subdir.name.startswith("_"):
158-
tool_file = subdir / "tool.py"
159-
if tool_file.exists():
160-
class_name = _snake_to_pascal_case(subdir.name) + "Tool"
161-
class_names.append(class_name)
162-
163-
return class_names
164-
165-
166-
def discover_all_drivers() -> list[str]:
167-
"""Scan filesystem to discover all available drivers (for __dir__).
168-
169-
This is a specialized version for drivers that scans all driver type subdirectories.
170-
171-
Returns:
172-
List of all driver class names found in the drivers directory
173-
"""
174-
try:
175-
drivers_module = importlib.import_module("griptape.drivers")
176-
if drivers_module.__file__ is None:
177-
return []
178-
drivers_path = Path(drivers_module.__file__).parent
179-
except Exception:
180-
return []
181-
182-
driver_classes = []
183-
184-
for driver_type_dir in drivers_path.iterdir():
185-
if not driver_type_dir.is_dir() or driver_type_dir.name.startswith("_"):
186-
continue
187-
188-
for file in driver_type_dir.glob("*_driver.py"):
189-
if file.name.startswith("base_"):
190-
continue
191-
# Convert filename to class name
192-
class_name = _snake_to_pascal_case(file.stem)
193-
driver_classes.append(class_name)
194-
195-
return driver_classes
196-
197-
198-
def _snake_to_pascal_case(name: str) -> str:
199-
"""Convert snake_case to PascalCase.
200-
201-
Args:
202-
name: snake_case string (e.g., "prompt_task")
203-
204-
Returns:
205-
PascalCase string (e.g., "PromptTask")
206-
"""
207-
return "".join(word.capitalize() for word in name.split("_"))

griptape/drivers/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from griptape.utils.deprecation import DeprecationModuleWrapper
7-
from griptape.common._lazy_loader import find_driver_module, discover_all_drivers
7+
from griptape.common._lazy_loader import find_driver_module
88

99
# Import base classes eagerly (they're always needed for type checking and inheritance)
1010
from .prompt import BasePromptDriver
@@ -67,15 +67,13 @@ def __dir__() -> list[str]:
6767
"""Support dir() and IDE autocomplete by listing all available drivers.
6868
6969
Returns:
70-
List of all available names in this module (base classes + discovered drivers)
70+
List of all available names in this module (from __all__)
7171
"""
72-
# Combine eagerly loaded base classes with dynamically discovered drivers
73-
base_names = [name for name in globals() if not name.startswith("_")]
74-
discovered = discover_all_drivers()
75-
return sorted(set(base_names + discovered))
72+
# Return __all__ which contains the complete and correctly-capitalized list
73+
return __all__
7674

7775

78-
__all__ = [
76+
__all__ = [ # pyright: ignore[reportUnsupportedDunderAll]
7977
# Base classes (eagerly loaded)
8078
"BasePromptDriver",
8179
"BaseConversationMemoryDriver",

griptape/structures/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22
from typing import Any
33

4-
from griptape.common._lazy_loader import find_class_module, discover_all_classes
4+
from griptape.common._lazy_loader import find_class_module
55

66

77
def __getattr__(name: str) -> Any:
@@ -37,11 +37,10 @@ def __dir__() -> list[str]:
3737
"""Support dir() and IDE autocomplete.
3838
3939
Returns:
40-
List of all available structure names
40+
List of all available structure names (from __all__)
4141
"""
42-
base_names = [name for name in globals() if not name.startswith("_")]
43-
discovered = discover_all_classes("griptape.structures")
44-
return sorted(set(base_names + discovered))
42+
# Return __all__ which contains the complete list
43+
return __all__
4544

4645

47-
__all__ = ["Agent", "Pipeline", "Structure", "Workflow"]
46+
__all__ = ["Agent", "Pipeline", "Structure", "Workflow"] # pyright: ignore[reportUnsupportedDunderAll]

griptape/tasks/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22
from typing import Any
33

4-
from griptape.common._lazy_loader import find_class_module, discover_all_classes
4+
from griptape.common._lazy_loader import find_class_module
55

66

77
def __getattr__(name: str) -> Any:
@@ -37,14 +37,13 @@ def __dir__() -> list[str]:
3737
"""Support dir() and IDE autocomplete.
3838
3939
Returns:
40-
List of all available task names
40+
List of all available task names (from __all__)
4141
"""
42-
base_names = [name for name in globals() if not name.startswith("_")]
43-
discovered = discover_all_classes("griptape.tasks")
44-
return sorted(set(base_names + discovered))
42+
# Return __all__ which contains the complete list
43+
return __all__
4544

4645

47-
__all__ = [
46+
__all__ = [ # pyright: ignore[reportUnsupportedDunderAll]
4847
"ActionsSubtask",
4948
"OutputSchemaValidationSubtask",
5049
"AssistantTask",

griptape/tools/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22
from typing import Any
33

4-
from griptape.common._lazy_loader import find_class_module, discover_all_classes
4+
from griptape.common._lazy_loader import find_class_module
55

66

77
def __getattr__(name: str) -> Any:
@@ -37,14 +37,13 @@ def __dir__() -> list[str]:
3737
"""Support dir() and IDE autocomplete.
3838
3939
Returns:
40-
List of all available tool names
40+
List of all available tool names (from __all__)
4141
"""
42-
base_names = [name for name in globals() if not name.startswith("_")]
43-
discovered = discover_all_classes("griptape.tools")
44-
return sorted(set(base_names + discovered))
42+
# Return __all__ which contains the complete list
43+
return __all__
4544

4645

47-
__all__ = [
46+
__all__ = [ # pyright: ignore[reportUnsupportedDunderAll]
4847
"AudioTranscriptionTool",
4948
"BaseTool",
5049
"BaseImageGenerationTool",

0 commit comments

Comments
 (0)