|
1 | 1 | """Tests for LLM completion functionality, configuration, and metrics tracking.""" |
2 | 2 |
|
| 3 | +import threading |
3 | 4 | from collections.abc import Sequence |
4 | 5 | from typing import Any, ClassVar |
5 | 6 | from unittest.mock import MagicMock, patch |
|
19 | 20 | ) |
20 | 21 | from pydantic import SecretStr |
21 | 22 |
|
| 23 | +import openhands.sdk.llm.llm as llm_module |
22 | 24 | from openhands.sdk.llm import ( |
23 | 25 | LLM, |
24 | 26 | Message, |
@@ -76,6 +78,61 @@ def default_config(): |
76 | 78 | ) |
77 | 79 |
|
78 | 80 |
|
| 81 | +def test_litellm_modify_params_context_serializes_threads(): |
| 82 | + first_llm = LLM.model_construct(modify_params=True) |
| 83 | + second_llm = LLM.model_construct(modify_params=False) |
| 84 | + original = getattr(llm_module.litellm, "modify_params", None) |
| 85 | + |
| 86 | + entered_first = threading.Event() |
| 87 | + release_first = threading.Event() |
| 88 | + started_second = threading.Event() |
| 89 | + entered_second = threading.Event() |
| 90 | + observed: list[tuple[str, bool]] = [] |
| 91 | + errors: list[BaseException] = [] |
| 92 | + |
| 93 | + def run_first(): |
| 94 | + try: |
| 95 | + with first_llm._litellm_modify_params_ctx(True): |
| 96 | + observed.append(("first", llm_module.litellm.modify_params)) |
| 97 | + entered_first.set() |
| 98 | + release_first.wait(timeout=2) |
| 99 | + except BaseException as exc: |
| 100 | + errors.append(exc) |
| 101 | + |
| 102 | + def run_second(): |
| 103 | + entered_first.wait(timeout=2) |
| 104 | + started_second.set() |
| 105 | + try: |
| 106 | + with second_llm._litellm_modify_params_ctx(False): |
| 107 | + observed.append(("second", llm_module.litellm.modify_params)) |
| 108 | + entered_second.set() |
| 109 | + except BaseException as exc: |
| 110 | + errors.append(exc) |
| 111 | + |
| 112 | + first_thread = threading.Thread(target=run_first) |
| 113 | + second_thread = threading.Thread(target=run_second) |
| 114 | + try: |
| 115 | + first_thread.start() |
| 116 | + assert entered_first.wait(timeout=2) |
| 117 | + |
| 118 | + second_thread.start() |
| 119 | + assert started_second.wait(timeout=2) |
| 120 | + assert not entered_second.wait(timeout=0.2) |
| 121 | + |
| 122 | + release_first.set() |
| 123 | + first_thread.join(timeout=2) |
| 124 | + second_thread.join(timeout=2) |
| 125 | + finally: |
| 126 | + release_first.set() |
| 127 | + llm_module.litellm.modify_params = original |
| 128 | + |
| 129 | + assert not first_thread.is_alive() |
| 130 | + assert not second_thread.is_alive() |
| 131 | + assert errors == [] |
| 132 | + assert observed == [("first", True), ("second", False)] |
| 133 | + assert llm_module.litellm.modify_params == original |
| 134 | + |
| 135 | + |
79 | 136 | @patch("openhands.sdk.llm.llm.litellm_completion") |
80 | 137 | def test_llm_completion_basic(mock_completion): |
81 | 138 | """Test basic LLM completion functionality.""" |
|
0 commit comments