Skip to content

Commit

Permalink
core: fix batch race condition in FakeListChatModel (#26924)
Browse files Browse the repository at this point in the history
fixed #26273
  • Loading branch information
efriis authored Oct 3, 2024
1 parent 87fc5ce commit ab4dab9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
28 changes: 28 additions & 0 deletions libs/core/langchain_core/language_models/fake_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig


class FakeMessagesListChatModel(BaseChatModel):
Expand Down Expand Up @@ -128,6 +129,33 @@ async def _astream(
def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses}

# manually override batch to preserve batch ordering with no concurrency
def batch(
self,
inputs: list[Any],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> list[BaseMessage]:
if isinstance(config, list):
return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
return [self.invoke(m, config, **kwargs) for m in inputs]

async def abatch(
self,
inputs: list[Any],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> list[BaseMessage]:
if isinstance(config, list):
# do Not use an async iterator here because need explicit ordering
return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
# do Not use an async iterator here because need explicit ordering
return [await self.ainvoke(m, config, **kwargs) for m in inputs]


class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
Expand Down
19 changes: 18 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
from langchain_core.language_models import (
FakeListChatModel,
GenericFakeChatModel,
ParrotFakeChatModel,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
Expand Down Expand Up @@ -205,3 +209,16 @@ def test_chat_model_inputs() -> None:
assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
content="blah"
)


def test_fake_list_chat_model_batch() -> None:
expected = [
_any_id_ai_message(content="a"),
_any_id_ai_message(content="b"),
_any_id_ai_message(content="c"),
]
for _ in range(20):
# run this 20 times to test race condition in batch
fake = FakeListChatModel(responses=["a", "b", "c"])
resp = fake.batch(["1", "2", "3"])
assert resp == expected
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,13 @@ async def test_global_cache_abatch() -> None:
assert results[0].content == "hello"
assert results[1].content == "hello"

## RACE CONDITION -- note behavior is different from sync
# Now, reset cache and test the race condition
# For now we just hard-code the result, if this changes
# we can investigate further
global_cache = InMemoryCache()
set_llm_cache(global_cache)
assert global_cache._cache == {}
results = await chat_model.abatch(["prompt", "prompt"])
# suspecting that tasks will be scheduled and executed in order
# if this ever fails, we can relax to a set comparison
# Cache misses likely guaranteed?

assert results[0].content == "meow"
assert results[1].content == "woof"
assert results[1].content == "meow"
finally:
set_llm_cache(None)

Expand Down

0 comments on commit ab4dab9

Please sign in to comment.