Skip to content

Commit a9c43f9

Browse files
update back to use import optional dependencies
1 parent 74c77f0 commit a9c43f9

10 files changed

+69
-71
lines changed

griptape/drivers/assistant/openai_assistant_driver.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from griptape.artifacts import BaseArtifact, TextArtifact
99
from griptape.drivers.assistant import BaseAssistantDriver
1010
from griptape.events import EventBus, TextChunkEvent
11+
from griptape.utils import import_optional_dependency
1112
from griptape.utils.decorators import lazy_property
1213

1314
if TYPE_CHECKING:
@@ -22,16 +23,16 @@ class OpenAiAssistantDriver(BaseAssistantDriver):
2223
@staticmethod
2324
def _create_event_handler_class() -> type[AssistantEventHandler]: # pyright: ignore[reportInvalidTypeForm]
2425
"""Lazily import and create EventHandler class."""
25-
from openai import AssistantEventHandler
26+
AssistantEventHandler = import_optional_dependency("openai").AssistantEventHandler
2627

2728
class EventHandler(AssistantEventHandler):
2829
@override
29-
def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: # pyright: ignore[reportUndefinedVariable]
30+
def on_text_delta(self, delta, snapshot) -> None: # pyright: ignore[reportUndefinedVariable]
3031
if delta.value is not None:
3132
EventBus.publish_event(TextChunkEvent(token=delta.value))
3233

3334
@override
34-
def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None: # pyright: ignore[reportUndefinedVariable]
35+
def on_tool_call_delta(self, delta, snapshot) -> None: # pyright: ignore[reportUndefinedVariable]
3536
if delta.type == "code_interpreter" and delta.code_interpreter is not None:
3637
if delta.code_interpreter.input:
3738
EventBus.publish_event(TextChunkEvent(token=delta.code_interpreter.input))
@@ -55,14 +56,13 @@ def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
5556
)
5657
auto_create_thread: bool = field(default=True, kw_only=True)
5758

58-
_client: Optional[openai.OpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
59+
_client: Optional[openai.OpenAI] = field(
5960
default=None, kw_only=True, alias="client", metadata={"serializable": False}
6061
)
6162

6263
@lazy_property()
63-
def client(self) -> openai.OpenAI: # pyright: ignore[reportInvalidTypeForm]
64-
import openai
65-
64+
def client(self) -> openai.OpenAI:
65+
openai = import_optional_dependency("openai")
6666
return openai.OpenAI(
6767
base_url=self.base_url,
6868
api_key=self.api_key,

griptape/drivers/audio_transcription/openai_audio_transcription_driver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from griptape.artifacts import AudioArtifact, TextArtifact
99
from griptape.drivers.audio_transcription import BaseAudioTranscriptionDriver
10+
from griptape.utils import import_optional_dependency
1011
from griptape.utils.decorators import lazy_property
1112

1213
if TYPE_CHECKING:
@@ -15,19 +16,20 @@
1516

1617
@define
1718
class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
19+
# These defaults were changed from openai.api_type, openai.api_version, and openai.organization
20+
# to None because those module-level attributes don't exist in OpenAI SDK v1.0+
1821
api_type: Optional[str] = field(default=None, kw_only=True)
1922
api_version: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
2023
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
2124
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
2225
organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
23-
_client: Optional[openai.OpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
26+
_client: Optional[openai.OpenAI] = field(
2427
default=None, kw_only=True, alias="client", metadata={"serializable": False}
2528
)
2629

2730
@lazy_property()
28-
def client(self) -> openai.OpenAI: # pyright: ignore[reportInvalidTypeForm]
29-
import openai
30-
31+
def client(self) -> openai.OpenAI:
32+
openai = import_optional_dependency("openai")
3133
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
3234

3335
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:

griptape/drivers/embedding/azure_openai_embedding_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver
88
from griptape.tokenizers import OpenAiTokenizer
9+
from griptape.utils import import_optional_dependency
910
from griptape.utils.decorators import lazy_property
1011

1112
if TYPE_CHECKING:
@@ -43,14 +44,13 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
4344
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
4445
kw_only=True,
4546
)
46-
_client: Optional[openai.AzureOpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
47+
_client: Optional[openai.AzureOpenAI] = field(
4748
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4849
)
4950

5051
@lazy_property()
51-
def client(self) -> openai.AzureOpenAI: # pyright: ignore[reportInvalidTypeForm]
52-
import openai
53-
52+
def client(self) -> openai.AzureOpenAI:
53+
openai = import_optional_dependency("openai")
5454
return openai.AzureOpenAI(
5555
organization=self.organization,
5656
api_key=self.api_key,

griptape/drivers/embedding/openai_embedding_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from griptape.drivers.embedding import BaseEmbeddingDriver
88
from griptape.tokenizers import OpenAiTokenizer
9+
from griptape.utils import import_optional_dependency
910
from griptape.utils.decorators import lazy_property
1011

1112
if TYPE_CHECKING:
@@ -40,14 +41,13 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
4041
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
4142
kw_only=True,
4243
)
43-
_client: Optional[openai.OpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
44+
_client: Optional[openai.OpenAI] = field(
4445
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4546
)
4647

4748
@lazy_property()
48-
def client(self) -> openai.OpenAI: # pyright: ignore[reportInvalidTypeForm]
49-
import openai
50-
49+
def client(self) -> openai.OpenAI:
50+
openai = import_optional_dependency("openai")
5151
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
5252

5353
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:

griptape/drivers/image_generation/azure_openai_image_generation_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from attrs import Factory, define, field
66

77
from griptape.drivers.image_generation.openai import OpenAiImageGenerationDriver
8+
from griptape.utils import import_optional_dependency
89
from griptape.utils.decorators import lazy_property
910

1011
if TYPE_CHECKING:
@@ -37,14 +38,13 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
3738
metadata={"serializable": False},
3839
)
3940
api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
40-
_client: Optional[openai.AzureOpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
41+
_client: Optional[openai.AzureOpenAI] = field(
4142
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4243
)
4344

4445
@lazy_property()
45-
def client(self) -> openai.AzureOpenAI: # pyright: ignore[reportInvalidTypeForm]
46-
import openai
47-
46+
def client(self) -> openai.AzureOpenAI:
47+
openai = import_optional_dependency("openai")
4848
return openai.AzureOpenAI(
4949
organization=self.organization,
5050
api_key=self.api_key,

griptape/drivers/image_generation/openai_image_generation_driver.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from attrs import Factory, define, field, fields_dict
77

88
from griptape.drivers.image_generation import BaseImageGenerationDriver
9+
from griptape.utils import import_optional_dependency
910
from griptape.utils.decorators import lazy_property
1011

1112
if TYPE_CHECKING:
@@ -40,6 +41,8 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
4041
output_format: Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'.
4142
"""
4243

44+
# These defaults were changed from openai.api_type, openai.api_version, and openai.organization
45+
# to None because those module-level attributes don't exist in OpenAI SDK v1.0+
4346
api_type: Optional[str] = field(default=None, kw_only=True)
4447
api_version: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
4548
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
@@ -83,27 +86,23 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
8386
kw_only=True,
8487
metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
8588
)
86-
_client: Optional[openai.OpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
89+
_client: Optional[openai.OpenAI] = field(
8790
default=None, kw_only=True, alias="client", metadata={"serializable": False}
8891
)
8992
ignored_exception_types: tuple[type[Exception], ...] = field(
90-
default=Factory(lambda self: self._default_ignored_exception_types(), takes_self=True),
93+
default=Factory(
94+
lambda: (
95+
import_optional_dependency("openai").BadRequestError,
96+
import_optional_dependency("openai").AuthenticationError,
97+
import_optional_dependency("openai").PermissionDeniedError,
98+
import_optional_dependency("openai").NotFoundError,
99+
import_optional_dependency("openai").ConflictError,
100+
import_optional_dependency("openai").UnprocessableEntityError,
101+
),
102+
),
91103
kw_only=True,
92104
)
93105

94-
def _default_ignored_exception_types(self) -> tuple[type[Exception], ...]:
95-
"""Lazily import openai and return default exception types."""
96-
import openai
97-
98-
return (
99-
openai.BadRequestError,
100-
openai.AuthenticationError,
101-
openai.PermissionDeniedError,
102-
openai.NotFoundError,
103-
openai.ConflictError,
104-
openai.UnprocessableEntityError,
105-
)
106-
107106
@image_size.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
108107
def validate_image_size(self, attribute: str, value: str | None) -> None:
109108
"""Validates the image size based on the model.
@@ -129,9 +128,8 @@ def validate_image_size(self, attribute: str, value: str | None) -> None:
129128
raise ValueError(f"Image size, {value}, must be one of the following: {allowed_sizes}")
130129

131130
@lazy_property()
132-
def client(self) -> openai.OpenAI: # pyright: ignore[reportInvalidTypeForm]
133-
import openai
134-
131+
def client(self) -> openai.OpenAI:
132+
openai = import_optional_dependency("openai")
135133
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
136134

137135
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:

griptape/drivers/prompt/azure_openai_chat_prompt_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from attrs import Factory, define, field
66

77
from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
8+
from griptape.utils import import_optional_dependency
89
from griptape.utils.decorators import lazy_property
910

1011
if TYPE_CHECKING:
@@ -39,14 +40,13 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
3940
metadata={"serializable": False},
4041
)
4142
api_version: str = field(default="2024-10-21", kw_only=True, metadata={"serializable": True})
42-
_client: Optional[openai.AzureOpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
43+
_client: Optional[openai.AzureOpenAI] = field(
4344
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4445
)
4546

4647
@lazy_property()
47-
def client(self) -> openai.AzureOpenAI: # pyright: ignore[reportInvalidTypeForm]
48-
import openai
49-
48+
def client(self) -> openai.AzureOpenAI:
49+
openai = import_optional_dependency("openai")
5050
return openai.AzureOpenAI(
5151
organization=self.organization,
5252
api_key=self.api_key,

griptape/drivers/prompt/openai_chat_prompt_driver.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from griptape.configs.defaults_config import Defaults
3131
from griptape.drivers.prompt import BasePromptDriver
3232
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
33+
from griptape.utils import import_optional_dependency
3334
from griptape.utils.decorators import lazy_property
3435

3536
if TYPE_CHECKING:
@@ -92,34 +93,29 @@ class OpenAiChatPromptDriver(BasePromptDriver):
9293
)
9394
parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
9495
ignored_exception_types: tuple[type[Exception], ...] = field(
95-
default=Factory(lambda self: self._default_ignored_exception_types(), takes_self=True),
96+
default=Factory(
97+
lambda: (
98+
import_optional_dependency("openai").BadRequestError,
99+
import_optional_dependency("openai").AuthenticationError,
100+
import_optional_dependency("openai").PermissionDeniedError,
101+
import_optional_dependency("openai").NotFoundError,
102+
import_optional_dependency("openai").ConflictError,
103+
import_optional_dependency("openai").UnprocessableEntityError,
104+
),
105+
),
96106
kw_only=True,
97107
)
98108
modalities: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True})
99109
audio: dict = field(
100110
default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
101111
)
102-
_client: Optional[openai.OpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
112+
_client: Optional[openai.OpenAI] = field(
103113
default=None, kw_only=True, alias="client", metadata={"serializable": False}
104114
)
105115

106-
def _default_ignored_exception_types(self) -> tuple[type[Exception], ...]:
107-
"""Lazily import openai and return default exception types."""
108-
import openai
109-
110-
return (
111-
openai.BadRequestError,
112-
openai.AuthenticationError,
113-
openai.PermissionDeniedError,
114-
openai.NotFoundError,
115-
openai.ConflictError,
116-
openai.UnprocessableEntityError,
117-
)
118-
119116
@lazy_property()
120-
def client(self) -> openai.OpenAI: # pyright: ignore[reportInvalidTypeForm]
121-
import openai
122-
117+
def client(self) -> openai.OpenAI:
118+
openai = import_optional_dependency("openai")
123119
return openai.OpenAI(
124120
base_url=self.base_url,
125121
api_key=self.api_key,

griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from attrs import Factory, define, field
66

77
from griptape.drivers.text_to_speech.openai import OpenAiTextToSpeechDriver
8+
from griptape.utils import import_optional_dependency
89
from griptape.utils.decorators import lazy_property
910

1011
if TYPE_CHECKING:
@@ -38,14 +39,13 @@ class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver):
3839
metadata={"serializable": False},
3940
)
4041
api_version: str = field(default="2024-07-01-preview", kw_only=True, metadata={"serializable": True})
41-
_client: Optional[openai.AzureOpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
42+
_client: Optional[openai.AzureOpenAI] = field(
4243
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4344
)
4445

4546
@lazy_property()
46-
def client(self) -> openai.AzureOpenAI: # pyright: ignore[reportInvalidTypeForm]
47-
import openai
48-
47+
def client(self) -> openai.AzureOpenAI:
48+
openai = import_optional_dependency("openai")
4949
return openai.AzureOpenAI(
5050
organization=self.organization,
5151
api_key=self.api_key,

griptape/drivers/text_to_speech/openai_text_to_speech_driver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from griptape.artifacts.audio_artifact import AudioArtifact
88
from griptape.drivers.text_to_speech import BaseTextToSpeechDriver
9+
from griptape.utils import import_optional_dependency
910
from griptape.utils.decorators import lazy_property
1011

1112
if TYPE_CHECKING:
@@ -21,19 +22,20 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver):
2122
metadata={"serializable": True},
2223
)
2324
format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True})
25+
# These defaults were changed from openai.api_type, openai.api_version, and openai.organization
26+
# to None because those module-level attributes don't exist in OpenAI SDK v1.0+
2427
api_type: Optional[str] = field(default=None, kw_only=True)
2528
api_version: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
2629
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
2730
api_key: Optional[str] = field(default=None, kw_only=True)
2831
organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
29-
_client: Optional[openai.OpenAI] = field( # pyright: ignore[reportInvalidTypeForm]
32+
_client: Optional[openai.OpenAI] = field(
3033
default=None, kw_only=True, alias="client", metadata={"serializable": False}
3134
)
3235

3336
@lazy_property()
34-
def client(self) -> openai.OpenAI: # pyright: ignore[reportInvalidTypeForm]
35-
import openai
36-
37+
def client(self) -> openai.OpenAI:
38+
openai = import_optional_dependency("openai")
3739
return openai.OpenAI(
3840
api_key=self.api_key,
3941
base_url=self.base_url,

0 commit comments

Comments
 (0)