Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions libs/partners/groq/langchain_groq/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import Any, Literal, cast
from typing import Any, Literal, NoReturn, cast

import groq
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import (
LanguageModelInput,
ModelProfile,
Expand Down Expand Up @@ -88,6 +90,28 @@ def _get_default_model_profile(model_name: str) -> ModelProfile:
return default.copy()


class GroqContextOverflowError(groq.BadRequestError, ContextOverflowError):
"""`BadRequestError` raised when input exceeds Groq's context limit."""


def _handle_groq_invalid_request(e: groq.BadRequestError) -> NoReturn:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious why this method is made private to this class with '_' to start with, is it so others can't invoke this method and hence made private?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the leading underscore marks it as module-private, it is an internal helper that the two BadRequestError handlers call and is not part of the public API. That matches the existing convention in this file, where the other helpers are written the same way, for example _get_default_model_profile, _create_chat_result, and _convert_dict_to_message. Keeping it private also lets the implementation change later without it being a public surface anyone depends on.

"""Promote context-length errors to `GroqContextOverflowError`.

Groq surfaces an over-long prompt as a 400 `BadRequestError` whose body
carries `"code": "context_length_exceeded"`. The SDK does not expose that
code as an attribute, so it is matched against the stringified error (which
includes the JSON body) as well, with `"reduce the length"` from the message
as a secondary signal.
"""
if (
getattr(e, "code", None) == "context_length_exceeded"
or "context_length_exceeded" in str(e)
or "reduce the length" in str(e)
):
raise GroqContextOverflowError(str(e), response=e.response, body=e.body) from e
raise e
Comment on lines +106 to +112

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice lot of developers use single char in open source, having coming other languages like C++, Java, I wouldnt encourage usage of single syllabel for variable names, As u know as Zen of Python, readibility counts, i would use atleast ex to make it more explict that 'e' meant exception

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point in general on readability. I used e here to stay consistent with the rest of this file and the sibling partner integrations I mirrored, the existing except Exception as e at the bottom of this module uses the same name. The scope is two lines and the type is right there in the signature (e: groq.BadRequestError), so the meaning stays clear. Happy to rename to exc if the maintainers prefer that, the file also has one except ImportError as exc, so there is a mild mix already.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am Ok eitherway, as it's not a major issue and one of hardest things to solve in computer science is Naming and the other one is caching so we can leave it there.

But a quote came to my mind "A foolish consistency is the hobgoblin of not great minds,"



class ChatGroq(BaseChatModel):
r"""Groq Chat large language models API.

Expand Down Expand Up @@ -624,7 +648,10 @@ def _generate(
**params,
**kwargs,
}
response = self.client.create(messages=message_dicts, **params)
try:
response = self.client.create(messages=message_dicts, **params)
except groq.BadRequestError as e:
_handle_groq_invalid_request(e)
return self._create_chat_result(response, params)

async def _agenerate(
Expand All @@ -645,7 +672,10 @@ async def _agenerate(
**params,
**kwargs,
}
response = await self.async_client.create(messages=message_dicts, **params)
try:
response = await self.async_client.create(messages=message_dicts, **params)
except groq.BadRequestError as e:
_handle_groq_invalid_request(e)
return self._create_chat_result(response, params)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious, do we reach line 679 of return statment if there is exception thrown by async_client.create() or do we need finally?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. We do not reach the return on a context-overflow error. _handle_groq_invalid_request is annotated -> NoReturn and always raises: it either raises the promoted GroqContextOverflowError or re-raises the original BadRequestError. So when create() throws a BadRequestError, the except calls the helper, the helper raises, and control never falls through to _create_chat_result. No finally is needed, and we deliberately do not want one here because there is no resource to clean up, response is simply never assigned on the error path.


def _stream(
Expand All @@ -659,8 +689,12 @@ def _stream(

params = {**params, **kwargs, "stream": True}

try:
stream = self.client.create(messages=message_dicts, **params)
except groq.BadRequestError as e:
_handle_groq_invalid_request(e)
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
for chunk in stream:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so what happens if client.create() throws exception what would be stream here and probably loop wont be executed right?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the loop does not run on an exception. self.client.create(...) is called inside the try, and if it raises BadRequestError the except calls _handle_groq_invalid_request, which is NoReturn and always raises. So execution leaves the function before stream is ever bound, and for chunk in stream is never reached. stream only gets a value when create() returned normally, so there is no risk of iterating an unbound or partial stream.

if not isinstance(chunk, dict):
chunk = chunk.model_dump() # noqa: PLW2901
if len(chunk["choices"]) == 0:
Expand Down Expand Up @@ -711,10 +745,12 @@ async def _astream(

params = {**params, **kwargs, "stream": True}

try:
stream = await self.async_client.create(messages=message_dicts, **params)
except groq.BadRequestError as e:
_handle_groq_invalid_request(e)
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await self.async_client.create(
messages=message_dicts, **params
):
async for chunk in stream:
if not isinstance(chunk, dict):
chunk = chunk.model_dump() # noqa: PLW2901
if len(chunk["choices"]) == 0:
Expand Down
145 changes: 145 additions & 0 deletions libs/partners/groq/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import groq
import httpx
import langchain_core.load as lc_load
import pytest
from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand All @@ -21,6 +24,7 @@

from langchain_groq.chat_models import (
ChatGroq,
GroqContextOverflowError,
_convert_chunk_to_message_chunk,
_convert_dict_to_message,
_create_usage_metadata,
Expand Down Expand Up @@ -1095,3 +1099,144 @@ def test_format_message_content_mixed() -> None:
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<data>"}},
]
assert expected == _format_message_content(content)


def _bad_request_error(body: dict[str, Any], status_code: int = 400) -> Exception:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this has to be private method? looks like it as this is not unit test case but helper method for remaining test cases?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, _bad_request_error is a fixture builder, not a test case, so the underscore keeps pytest from collecting it as a test and signals it is a local helper for the tests below. pytest discovers test_-prefixed functions, so naming it _bad_request_error avoids it being mistaken for one while still being importable within the module.

"""Build a `groq.BadRequestError` the way the SDK does for a 4xx response.

The Groq SDK formats the error message as ``Error code: <status> - <body>``,
so the JSON body (including its ``code`` field) ends up in ``str(e)``. This
mirrors `groq._base_client.BaseClient._make_status_error_from_response`.
"""
request = httpx.Request("POST", "https://api.groq.com/openai/v1/chat/completions")
response = httpx.Response(status_code=status_code, request=request)
message = f"Error code: {status_code} - {body}"
return groq.BadRequestError(message, response=response, body=body.get("error"))


# Verified shape of a real Groq context-overflow response. See
# https://console.groq.com/docs/errors and letta-ai/letta#1963.
_CONTEXT_OVERFLOW_BODY = {
"error": {
"message": "Please reduce the length of the messages or completion.",
"type": "invalid_request_error",
"param": "messages",
"code": "context_length_exceeded",
}
}
Comment on lines +1119 to +1126

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this has to be private variable? Also, it's in upper case, I would assume if this is some kinda constant? as per convention, not necessarily syntax error though

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a module-level test constant, so the upper-case name follows the PEP 8 convention for constants, and the underscore keeps it scoped to this test module rather than exported. It is fixed input data reused across the promotion tests, so treating it as a named constant rather than rebuilding it in each test keeps the cases readable. Happy to adjust the name if you would rather it not be a constant.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am Ok if u are convinced the way it is. Besides, naming is not major issue, it's more of readibility for u and other guys like myself



def test_context_overflow_error_invoke_sync() -> None:
"""Context-length errors surface as `ContextOverflowError` on invoke."""
llm = ChatGroq(model="foo", max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY)
llm.client = mock_client

with pytest.raises(ContextOverflowError) as exc_info:
llm.invoke([HumanMessage(content="test")])

assert "context_length_exceeded" in str(exc_info.value)
assert isinstance(exc_info.value, GroqContextOverflowError)


async def test_context_overflow_error_invoke_async() -> None:
"""Context-length errors surface as `ContextOverflowError` on ainvoke."""
llm = ChatGroq(model="foo", max_retries=0)
mock_async = MagicMock()

async def _create(**_kwargs: Any) -> dict[str, Any]:
raise _bad_request_error(_CONTEXT_OVERFLOW_BODY)

mock_async.create = _create
llm.async_client = mock_async

with pytest.raises(ContextOverflowError) as exc_info:
await llm.ainvoke([HumanMessage(content="test")])

assert "context_length_exceeded" in str(exc_info.value)
assert isinstance(exc_info.value, GroqContextOverflowError)


def test_context_overflow_error_stream_sync() -> None:
"""Context-length errors surface as `ContextOverflowError` on stream."""
llm = ChatGroq(model="foo", max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY)
llm.client = mock_client

with pytest.raises(ContextOverflowError) as exc_info:
list(llm.stream([HumanMessage(content="test")]))

assert "context_length_exceeded" in str(exc_info.value)
assert isinstance(exc_info.value, GroqContextOverflowError)


async def test_context_overflow_error_stream_async() -> None:
"""Context-length errors surface as `ContextOverflowError` on astream."""
llm = ChatGroq(model="foo", max_retries=0)
mock_async = MagicMock()

async def _create(**_kwargs: Any) -> Any:
raise _bad_request_error(_CONTEXT_OVERFLOW_BODY)

mock_async.create = _create
llm.async_client = mock_async

with pytest.raises(ContextOverflowError) as exc_info:
async for _ in llm.astream([HumanMessage(content="test")]):
pass

assert "context_length_exceeded" in str(exc_info.value)
assert isinstance(exc_info.value, GroqContextOverflowError)


def test_context_overflow_error_backwards_compatibility() -> None:
"""`ContextOverflowError` is also catchable as `groq.BadRequestError`."""
llm = ChatGroq(model="foo", max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY)
llm.client = mock_client

with pytest.raises(groq.BadRequestError) as exc_info:
llm.invoke([HumanMessage(content="test")])

assert isinstance(exc_info.value, groq.BadRequestError)
assert isinstance(exc_info.value, ContextOverflowError)


def test_unrelated_invalid_request_error_not_promoted() -> None:
"""Unrelated `BadRequestError`s should stay a plain `BadRequestError`."""
llm = ChatGroq(model="foo", max_retries=0)
other_error = {
"error": {
"message": "Invalid value for 'temperature'.",
"type": "invalid_request_error",
"code": "invalid_value",
}
}
mock_client = MagicMock()
mock_client.create.side_effect = _bad_request_error(other_error)
llm.client = mock_client

with pytest.raises(groq.BadRequestError) as exc_info:
llm.invoke([HumanMessage(content="test")])

assert not isinstance(exc_info.value, ContextOverflowError)


def test_context_overflow_error_carries_response_metadata() -> None:
"""Promoted `GroqContextOverflowError` preserves `response`/`body`.

Downstream catchers that introspect `.response.status_code` rely on this.
"""
llm = ChatGroq(model="foo", max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY)
llm.client = mock_client

with pytest.raises(GroqContextOverflowError) as exc_info:
llm.invoke([HumanMessage(content="test")])

assert exc_info.value.response.status_code == 400
assert exc_info.value.body == _CONTEXT_OVERFLOW_BODY["error"]
Loading