Skip to content

Commit 215c2f5

Browse files
GWeale and copybara-github
authored and committed
fix: Set LITELLM_MODE to PRODUCTION before importing LiteLLM
LiteLLM defaults to DEV mode, which automatically loads environment variables from a local `.env` file. This change sets LITELLM_MODE to PRODUCTION to prevent LiteLLM from implicitly loading `.env` files when used within ADK.

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 858723362
1 parent 135f763 commit 215c2f5

File tree

2 files changed

+205
-23
lines changed

2 files changed

+205
-23
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import base64
1818
import copy
19+
import importlib.util
1920
import json
2021
import logging
2122
import mimetypes
@@ -32,27 +33,20 @@
3233
from typing import Literal
3334
from typing import Optional
3435
from typing import Tuple
36+
from typing import TYPE_CHECKING
3537
from typing import TypedDict
3638
from typing import Union
3739
from urllib.parse import urlparse
3840
import uuid
3941
import warnings
4042

4143
from google.genai import types
42-
import litellm
43-
from litellm import acompletion
44-
from litellm import ChatCompletionAssistantMessage
45-
from litellm import ChatCompletionAssistantToolCall
46-
from litellm import ChatCompletionMessageToolCall
47-
from litellm import ChatCompletionSystemMessage
48-
from litellm import ChatCompletionToolMessage
49-
from litellm import ChatCompletionUserMessage
50-
from litellm import completion
51-
from litellm import CustomStreamWrapper
52-
from litellm import Function
53-
from litellm import Message
54-
from litellm import ModelResponse
55-
from litellm import OpenAIMessageContent
44+
45+
if not TYPE_CHECKING and importlib.util.find_spec("litellm") is None:
46+
raise ImportError(
47+
"LiteLLM support requires: pip install google-adk[extensions]"
48+
)
49+
5650
from pydantic import BaseModel
5751
from pydantic import Field
5852
from typing_extensions import override
@@ -61,8 +55,36 @@
6155
from .llm_request import LlmRequest
6256
from .llm_response import LlmResponse
6357

64-
# This will add functions to prompts if functions are provided.
65-
litellm.add_function_to_prompt = True
58+
if TYPE_CHECKING:
59+
import litellm
60+
from litellm import acompletion
61+
from litellm import ChatCompletionAssistantMessage
62+
from litellm import ChatCompletionAssistantToolCall
63+
from litellm import ChatCompletionMessageToolCall
64+
from litellm import ChatCompletionSystemMessage
65+
from litellm import ChatCompletionToolMessage
66+
from litellm import ChatCompletionUserMessage
67+
from litellm import completion
68+
from litellm import CustomStreamWrapper
69+
from litellm import Function
70+
from litellm import Message
71+
from litellm import ModelResponse
72+
from litellm import OpenAIMessageContent
73+
else:
74+
litellm = None
75+
acompletion = None
76+
ChatCompletionAssistantMessage = None
77+
ChatCompletionAssistantToolCall = None
78+
ChatCompletionMessageToolCall = None
79+
ChatCompletionSystemMessage = None
80+
ChatCompletionToolMessage = None
81+
ChatCompletionUserMessage = None
82+
completion = None
83+
CustomStreamWrapper = None
84+
Function = None
85+
Message = None
86+
ModelResponse = None
87+
OpenAIMessageContent = None
6688

6789
logger = logging.getLogger("google_adk." + __name__)
6890

@@ -109,6 +131,50 @@
109131
"before a response was recorded)."
110132
)
111133

134+
# Tracks whether the deferred LiteLLM import has already been performed.
_LITELLM_IMPORTED = False
# Names re-exported from the litellm package into this module's globals once
# the deferred import runs.
_LITELLM_GLOBAL_SYMBOLS = (
    "ChatCompletionAssistantMessage",
    "ChatCompletionAssistantToolCall",
    "ChatCompletionMessageToolCall",
    "ChatCompletionSystemMessage",
    "ChatCompletionToolMessage",
    "ChatCompletionUserMessage",
    "CustomStreamWrapper",
    "Function",
    "Message",
    "ModelResponse",
    "OpenAIMessageContent",
    "acompletion",
    "completion",
)


def _ensure_litellm_imported() -> None:
  """Performs the deferred LiteLLM import, once, with safe defaults.

  LiteLLM defaults to DEV mode, which auto-loads a local `.env` at import time.
  ADK should not implicitly load `.env` just because LiteLLM is installed, so
  LITELLM_MODE is defaulted to PRODUCTION right before the import.

  Users can opt into LiteLLM's default behavior by setting LITELLM_MODE=DEV.
  """
  global _LITELLM_IMPORTED
  if _LITELLM_IMPORTED:
    return

  # Only a default: an explicit LITELLM_MODE in the environment wins.
  # https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py#L80-L82
  os.environ.setdefault("LITELLM_MODE", "PRODUCTION")

  import litellm as litellm_module

  # This will add functions to prompts if functions are provided.
  litellm_module.add_function_to_prompt = True

  # Publish the module and its symbols at module scope, replacing the None
  # placeholders installed at import time.
  module_globals = globals()
  module_globals["litellm"] = litellm_module
  for symbol_name in _LITELLM_GLOBAL_SYMBOLS:
    module_globals[symbol_name] = getattr(litellm_module, symbol_name)

  _redirect_litellm_loggers_to_stdout()
  # Set last so a failed import above leaves the function retryable.
  _LITELLM_IMPORTED = True
177+
112178

113179
def _map_finish_reason(
114180
finish_reason: Any,
@@ -344,6 +410,7 @@ async def acompletion(
344410
Returns:
345411
The model response as a message.
346412
"""
413+
_ensure_litellm_imported()
347414

348415
return await acompletion(
349416
model=model,
@@ -367,6 +434,7 @@ def completion(
367434
Returns:
368435
The response from the model.
369436
"""
437+
_ensure_litellm_imported()
370438

371439
return completion(
372440
model=model,
@@ -513,6 +581,7 @@ async def _content_to_message_param(
513581
Returns:
514582
A litellm Message, a list of litellm Messages.
515583
"""
584+
_ensure_litellm_imported()
516585

517586
tool_messages: list[Message] = []
518587
non_tool_parts: list[types.Part] = []
@@ -622,6 +691,8 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:
622691
if not messages:
623692
return messages
624693

694+
_ensure_litellm_imported()
695+
625696
healed_messages: List[Message] = []
626697
pending_tool_call_ids: List[str] = []
627698

@@ -691,6 +762,7 @@ async def _get_content(
691762
Returns:
692763
The litellm content.
693764
"""
765+
_ensure_litellm_imported()
694766

695767
parts_list = list(parts)
696768
if len(parts_list) == 1:
@@ -925,6 +997,7 @@ def _build_tool_call_from_json_dict(
925997
candidate: Any, *, index: int
926998
) -> Optional[ChatCompletionMessageToolCall]:
927999
"""Creates a tool call object from JSON content embedded in text."""
1000+
_ensure_litellm_imported()
9281001

9291002
if not isinstance(candidate, dict):
9301003
return None
@@ -972,11 +1045,12 @@ def _parse_tool_calls_from_text(
9721045
text_block: str,
9731046
) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
9741047
"""Extracts inline JSON tool calls from LiteLLM text responses."""
975-
9761048
tool_calls = []
9771049
if not text_block:
9781050
return tool_calls, None
9791051

1052+
_ensure_litellm_imported()
1053+
9801054
remainder_segments = []
9811055
cursor = 0
9821056
text_length = len(text_block)
@@ -1014,7 +1088,6 @@ def _split_message_content_and_tool_calls(
10141088
message: Message,
10151089
) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]:
10161090
"""Returns message content and tool calls, parsing inline JSON when needed."""
1017-
10181091
existing_tool_calls = message.get("tool_calls") or []
10191092
normalized_tool_calls = (
10201093
list(existing_tool_calls) if existing_tool_calls else []
@@ -1180,6 +1253,7 @@ def _model_response_to_chunk(
11801253
Yields:
11811254
A tuple of text or function or usage metadata chunk and finish reason.
11821255
"""
1256+
_ensure_litellm_imported()
11831257

11841258
message = None
11851259
if response.get("choices", None):
@@ -1255,6 +1329,7 @@ def _model_response_to_generate_content_response(
12551329
Returns:
12561330
The LlmResponse.
12571331
"""
1332+
_ensure_litellm_imported()
12581333

12591334
message = None
12601335
finish_reason = None
@@ -1313,6 +1388,7 @@ def _message_to_generate_content_response(
13131388
Returns:
13141389
The LlmResponse.
13151390
"""
1391+
_ensure_litellm_imported()
13161392

13171393
parts: List[types.Part] = []
13181394
if not thought_parts:
@@ -1440,6 +1516,8 @@ async def _get_completion_inputs(
14401516
The litellm inputs (message list, tool dictionary, response format and
14411517
generation params).
14421518
"""
1519+
_ensure_litellm_imported()
1520+
14431521
# Determine provider for file handling
14441522
provider = _get_provider_from_model(model)
14451523

@@ -1665,11 +1743,6 @@ def _redirect_litellm_loggers_to_stdout() -> None:
16651743
handler.stream = sys.stdout
16661744

16671745

1668-
# Redirect LiteLLM loggers to stdout immediately after import to ensure
1669-
# INFO-level logs are not incorrectly treated as errors in cloud environments.
1670-
_redirect_litellm_loggers_to_stdout()
1671-
1672-
16731746
class LiteLlm(BaseLlm):
16741747
"""Wrapper around litellm.
16751748
@@ -1732,6 +1805,7 @@ async def generate_content_async(
17321805
Yields:
17331806
LlmResponse: The model response.
17341807
"""
1808+
_ensure_litellm_imported()
17351809

17361810
self._maybe_append_user_content(llm_request)
17371811
_append_fallback_user_content_if_missing(llm_request)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import importlib.util
16+
import os
17+
import subprocess
18+
import sys
19+
20+
import pytest
21+
22+
23+
def _subprocess_env() -> dict[str, str]:
24+
env = dict(os.environ)
25+
src_path = os.path.join(os.getcwd(), "src")
26+
pythonpath = env.get("PYTHONPATH", "")
27+
env["PYTHONPATH"] = (
28+
f"{src_path}{os.pathsep}{pythonpath}" if pythonpath else src_path
29+
)
30+
return env
31+
32+
33+
def test_importing_models_does_not_import_litellm_or_set_mode():
  """Importing google.adk.models must not import litellm or set LITELLM_MODE."""
  env = _subprocess_env()
  env.pop("LITELLM_MODE", None)

  script = (
      "import os, sys\n"
      "import google.adk.models\n"
      "print('litellm' in sys.modules)\n"
      "print(os.environ.get('LITELLM_MODE'))\n"
  )
  completed = subprocess.run(
      [sys.executable, "-c", script],
      check=True,
      capture_output=True,
      text=True,
      env=env,
  )
  assert completed.stdout.strip().splitlines() == ["False", "None"]
55+
56+
57+
def test_ensure_litellm_imported_defaults_to_production():
  """The deferred import defaults LITELLM_MODE to PRODUCTION when unset."""
  if importlib.util.find_spec("litellm") is None:
    pytest.skip("litellm is not installed")

  env = _subprocess_env()
  env.pop("LITELLM_MODE", None)

  script = (
      "import os\n"
      "from google.adk.models.lite_llm import _ensure_litellm_imported\n"
      "_ensure_litellm_imported()\n"
      "print(os.environ.get('LITELLM_MODE'))\n"
  )
  completed = subprocess.run(
      [sys.executable, "-c", script],
      check=True,
      capture_output=True,
      text=True,
      env=env,
  )
  assert completed.stdout.strip() == "PRODUCTION"
82+
83+
84+
def test_ensure_litellm_imported_does_not_override():
  """An explicit LITELLM_MODE in the environment is left untouched."""
  if importlib.util.find_spec("litellm") is None:
    pytest.skip("litellm is not installed")

  env = _subprocess_env()
  env["LITELLM_MODE"] = "DEV"

  script = (
      "import os\n"
      "from google.adk.models.lite_llm import _ensure_litellm_imported\n"
      "_ensure_litellm_imported()\n"
      "print(os.environ.get('LITELLM_MODE'))\n"
  )
  completed = subprocess.run(
      [sys.executable, "-c", script],
      check=True,
      capture_output=True,
      text=True,
      env=env,
  )
  assert completed.stdout.strip() == "DEV"

0 commit comments

Comments
 (0)