Skip to content

Commit 9eb80d5

Browse files
committed
fix: guard litellm modify_params context
1 parent dc3528c commit 9eb80d5

2 files changed

Lines changed: 66 additions & 6 deletions

File tree

openhands-sdk/openhands/sdk/llm/llm.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import json
55
import os
6+
import threading
67
import warnings
78
from collections.abc import Callable, Sequence
89
from contextlib import contextmanager
@@ -456,6 +457,7 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin):
456457
_is_subscription: bool = PrivateAttr(default=False)
457458
_litellm_provider: str | None = PrivateAttr(default=None)
458459
_prompt_cache_key: str | None = PrivateAttr(default=None)
460+
_litellm_modify_params_lock: ClassVar[threading.RLock] = threading.RLock()
459461

460462
model_config: ClassVar[ConfigDict] = ConfigDict(
461463
extra="ignore", arbitrary_types_allowed=True
@@ -1200,12 +1202,13 @@ def _transport_call(
12001202

12011203
@contextmanager
12021204
def _litellm_modify_params_ctx(self, flag: bool):
1203-
old = getattr(litellm, "modify_params", None)
1204-
try:
1205-
litellm.modify_params = flag
1206-
yield
1207-
finally:
1208-
litellm.modify_params = old
1205+
with self._litellm_modify_params_lock:
1206+
old = getattr(litellm, "modify_params", None)
1207+
try:
1208+
litellm.modify_params = flag
1209+
yield
1210+
finally:
1211+
litellm.modify_params = old
12091212

12101213
# =========================================================================
12111214
# Capabilities, formatting, and info

tests/sdk/llm/test_llm_completion.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for LLM completion functionality, configuration, and metrics tracking."""
22

3+
import threading
34
from collections.abc import Sequence
45
from typing import Any, ClassVar
56
from unittest.mock import MagicMock, patch
@@ -19,6 +20,7 @@
1920
)
2021
from pydantic import SecretStr
2122

23+
import openhands.sdk.llm.llm as llm_module
2224
from openhands.sdk.llm import (
2325
LLM,
2426
Message,
@@ -76,6 +78,61 @@ def default_config():
7678
)
7779

7880

81+
def test_litellm_modify_params_context_serializes_threads():
82+
first_llm = LLM.model_construct(modify_params=True)
83+
second_llm = LLM.model_construct(modify_params=False)
84+
original = getattr(llm_module.litellm, "modify_params", None)
85+
86+
entered_first = threading.Event()
87+
release_first = threading.Event()
88+
started_second = threading.Event()
89+
entered_second = threading.Event()
90+
observed: list[tuple[str, bool]] = []
91+
errors: list[BaseException] = []
92+
93+
def run_first():
94+
try:
95+
with first_llm._litellm_modify_params_ctx(True):
96+
observed.append(("first", llm_module.litellm.modify_params))
97+
entered_first.set()
98+
release_first.wait(timeout=2)
99+
except BaseException as exc:
100+
errors.append(exc)
101+
102+
def run_second():
103+
entered_first.wait(timeout=2)
104+
started_second.set()
105+
try:
106+
with second_llm._litellm_modify_params_ctx(False):
107+
observed.append(("second", llm_module.litellm.modify_params))
108+
entered_second.set()
109+
except BaseException as exc:
110+
errors.append(exc)
111+
112+
first_thread = threading.Thread(target=run_first)
113+
second_thread = threading.Thread(target=run_second)
114+
try:
115+
first_thread.start()
116+
assert entered_first.wait(timeout=2)
117+
118+
second_thread.start()
119+
assert started_second.wait(timeout=2)
120+
assert not entered_second.wait(timeout=0.2)
121+
122+
release_first.set()
123+
first_thread.join(timeout=2)
124+
second_thread.join(timeout=2)
125+
finally:
126+
release_first.set()
127+
llm_module.litellm.modify_params = original
128+
129+
assert not first_thread.is_alive()
130+
assert not second_thread.is_alive()
131+
assert errors == []
132+
assert observed == [("first", True), ("second", False)]
133+
assert llm_module.litellm.modify_params == original
134+
135+
79136
@patch("openhands.sdk.llm.llm.litellm_completion")
80137
def test_llm_completion_basic(mock_completion):
81138
"""Test basic LLM completion functionality."""

0 commit comments

Comments
 (0)