
Commit ab4dab9

Author: Erick Friis
core: fix batch race condition in FakeListChatModel (#26924)
fixed #26273
1 parent 87fc5ce commit ab4dab9

File tree

3 files changed: +48 −9 lines

libs/core/langchain_core/language_models/fake_chat_models.py

Lines changed: 28 additions & 0 deletions
@@ -13,6 +13,7 @@
 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.runnables import RunnableConfig
 
 
 class FakeMessagesListChatModel(BaseChatModel):
@@ -128,6 +129,33 @@ async def _astream(
     def _identifying_params(self) -> dict[str, Any]:
         return {"responses": self.responses}
 
+    # manually override batch to preserve batch ordering with no concurrency
+    def batch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        return [self.invoke(m, config, **kwargs) for m in inputs]
+
+    async def abatch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            # do not use an async iterator here because we need explicit ordering
+            return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        # do not use an async iterator here because we need explicit ordering
+        return [await self.ainvoke(m, config, **kwargs) for m in inputs]
+
 
 class FakeChatModel(SimpleChatModel):
     """Fake Chat Model wrapper for testing purposes."""

libs/core/tests/unit_tests/fake/test_fake_chat_model.py

Lines changed: 18 additions & 1 deletion
@@ -5,7 +5,11 @@
 from uuid import UUID
 
 from langchain_core.callbacks.base import AsyncCallbackHandler
-from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
+from langchain_core.language_models import (
+    FakeListChatModel,
+    GenericFakeChatModel,
+    ParrotFakeChatModel,
+)
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.stubs import (
@@ -205,3 +209,16 @@ def test_chat_model_inputs() -> None:
     assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
         content="blah"
     )
+
+
+def test_fake_list_chat_model_batch() -> None:
+    expected = [
+        _any_id_ai_message(content="a"),
+        _any_id_ai_message(content="b"),
+        _any_id_ai_message(content="c"),
+    ]
+    for _ in range(20):
+        # run this 20 times to test race condition in batch
+        fake = FakeListChatModel(responses=["a", "b", "c"])
+        resp = fake.batch(["1", "2", "3"])
+        assert resp == expected
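
The commit overrides abatch as well; an analogous async check, sketched under the same assumptions (not part of this commit), could look like:

import asyncio

from langchain_core.language_models import FakeListChatModel


async def check_abatch_order() -> None:
    # abatch also awaits each input in order, so responses should line up
    # with inputs just as the sync test above expects.
    fake = FakeListChatModel(responses=["a", "b", "c"])
    resp = await fake.abatch(["1", "2", "3"])
    assert [r.content for r in resp] == ["a", "b", "c"]


asyncio.run(check_abatch_order())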

libs/core/tests/unit_tests/language_models/chat_models/test_cache.py

Lines changed: 2 additions & 8 deletions
@@ -199,19 +199,13 @@ async def test_global_cache_abatch() -> None:
         assert results[0].content == "hello"
         assert results[1].content == "hello"
 
-        ## RACE CONDITION -- note behavior is different from sync
-        # Now, reset cache and test the race condition
-        # For now we just hard-code the result, if this changes
-        # we can investigate further
         global_cache = InMemoryCache()
         set_llm_cache(global_cache)
         assert global_cache._cache == {}
         results = await chat_model.abatch(["prompt", "prompt"])
-        # suspecting that tasks will be scheduled and executed in order
-        # if this ever fails, we can relax to a set comparison
-        # Cache misses likely guaranteed?
+
         assert results[0].content == "meow"
-        assert results[1].content == "woof"
+        assert results[1].content == "meow"
     finally:
         set_llm_cache(None)
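
The assertion change follows from the ordering fix: with abatch now awaiting inputs one at a time, the first "prompt" call populates the cache and the second, identical "prompt" is served from it, so both results carry the same content. A small self-contained sketch of that cache-hit behavior (a hypothetical setup, not the test's exact fixture):

import asyncio

from langchain_core.caches import InMemoryCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListChatModel


async def check_cache_hit() -> None:
    # The first "prompt" misses the cache and stores "meow"; the second,
    # identical "prompt" is answered from the cache instead of consuming
    # the next canned response ("woof").
    set_llm_cache(InMemoryCache())
    try:
        model = FakeListChatModel(responses=["meow", "woof"], cache=True)
        results = await model.abatch(["prompt", "prompt"])
        assert [r.content for r in results] == ["meow", "meow"]
    finally:
        set_llm_cache(None)


asyncio.run(check_cache_hit())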
