-
Notifications
You must be signed in to change notification settings - Fork 23k
feat(groq): map context-length errors to ContextOverflowError
#37676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,12 +6,14 @@ | |
| import warnings | ||
| from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence | ||
| from operator import itemgetter | ||
| from typing import Any, Literal, cast | ||
| from typing import Any, Literal, NoReturn, cast | ||
|
|
||
| import groq | ||
| from langchain_core.callbacks import ( | ||
| AsyncCallbackManagerForLLMRun, | ||
| CallbackManagerForLLMRun, | ||
| ) | ||
| from langchain_core.exceptions import ContextOverflowError | ||
| from langchain_core.language_models import ( | ||
| LanguageModelInput, | ||
| ModelProfile, | ||
|
|
@@ -88,6 +90,28 @@ def _get_default_model_profile(model_name: str) -> ModelProfile: | |
| return default.copy() | ||
|
|
||
|
|
||
| class GroqContextOverflowError(groq.BadRequestError, ContextOverflowError): | ||
| """`BadRequestError` raised when input exceeds Groq's context limit.""" | ||
|
|
||
|
|
||
| def _handle_groq_invalid_request(e: groq.BadRequestError) -> NoReturn: | ||
| """Promote context-length errors to `GroqContextOverflowError`. | ||
|
|
||
| Groq surfaces an over-long prompt as a 400 `BadRequestError` whose body | ||
| carries `"code": "context_length_exceeded"`. The SDK does not expose that | ||
| code as an attribute, so it is matched against the stringified error (which | ||
| includes the JSON body) as well, with `"reduce the length"` from the message | ||
| as a secondary signal. | ||
| """ | ||
| if ( | ||
| getattr(e, "code", None) == "context_length_exceeded" | ||
| or "context_length_exceeded" in str(e) | ||
| or "reduce the length" in str(e) | ||
| ): | ||
| raise GroqContextOverflowError(str(e), response=e.response, body=e.body) from e | ||
| raise e | ||
|
Comment on lines
+106
to
+112
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I notice a lot of developers use single-character names in open source. Coming from other languages like C++ and Java, I wouldn't encourage single-letter variable names. As the Zen of Python says, readability counts; I would use at least `ex` to make it more explicit that `e` means exception.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair point in general on readability. I used […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am OK either way, as it's not a major issue. One of the hardest things to solve in computer science is naming (the other is caching), so we can leave it there. But a quote came to mind: "A foolish consistency is the hobgoblin of little minds." |
||
|
|
||
|
|
||
| class ChatGroq(BaseChatModel): | ||
| r"""Groq Chat large language models API. | ||
|
|
||
|
|
@@ -624,7 +648,10 @@ def _generate( | |
| **params, | ||
| **kwargs, | ||
| } | ||
| response = self.client.create(messages=message_dicts, **params) | ||
| try: | ||
| response = self.client.create(messages=message_dicts, **params) | ||
| except groq.BadRequestError as e: | ||
| _handle_groq_invalid_request(e) | ||
| return self._create_chat_result(response, params) | ||
|
|
||
| async def _agenerate( | ||
|
|
@@ -645,7 +672,10 @@ async def _agenerate( | |
| **params, | ||
| **kwargs, | ||
| } | ||
| response = await self.async_client.create(messages=message_dicts, **params) | ||
| try: | ||
| response = await self.async_client.create(messages=message_dicts, **params) | ||
| except groq.BadRequestError as e: | ||
| _handle_groq_invalid_request(e) | ||
| return self._create_chat_result(response, params) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just curious: do we reach the return statement on line 679 if an exception is thrown by […]
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. We do not reach the |
||
|
|
||
| def _stream( | ||
|
|
@@ -659,8 +689,12 @@ def _stream( | |
|
|
||
| params = {**params, **kwargs, "stream": True} | ||
|
|
||
| try: | ||
| stream = self.client.create(messages=message_dicts, **params) | ||
| except groq.BadRequestError as e: | ||
| _handle_groq_invalid_request(e) | ||
| default_chunk_class: type[BaseMessageChunk] = AIMessageChunk | ||
| for chunk in self.client.create(messages=message_dicts, **params): | ||
| for chunk in stream: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so what happens if
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, the loop does not run on an exception. |
||
| if not isinstance(chunk, dict): | ||
| chunk = chunk.model_dump() # noqa: PLW2901 | ||
| if len(chunk["choices"]) == 0: | ||
|
|
@@ -711,10 +745,12 @@ async def _astream( | |
|
|
||
| params = {**params, **kwargs, "stream": True} | ||
|
|
||
| try: | ||
| stream = await self.async_client.create(messages=message_dicts, **params) | ||
| except groq.BadRequestError as e: | ||
| _handle_groq_invalid_request(e) | ||
| default_chunk_class: type[BaseMessageChunk] = AIMessageChunk | ||
| async for chunk in await self.async_client.create( | ||
| messages=message_dicts, **params | ||
| ): | ||
| async for chunk in stream: | ||
| if not isinstance(chunk, dict): | ||
| chunk = chunk.model_dump() # noqa: PLW2901 | ||
| if len(chunk["choices"]) == 0: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,8 +5,11 @@ | |
| from typing import Any | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| import groq | ||
| import httpx | ||
| import langchain_core.load as lc_load | ||
| import pytest | ||
| from langchain_core.exceptions import ContextOverflowError | ||
| from langchain_core.messages import ( | ||
| AIMessage, | ||
| AIMessageChunk, | ||
|
|
@@ -21,6 +24,7 @@ | |
|
|
||
| from langchain_groq.chat_models import ( | ||
| ChatGroq, | ||
| GroqContextOverflowError, | ||
| _convert_chunk_to_message_chunk, | ||
| _convert_dict_to_message, | ||
| _create_usage_metadata, | ||
|
|
@@ -1095,3 +1099,144 @@ def test_format_message_content_mixed() -> None: | |
| {"type": "image_url", "image_url": {"url": "data:image/png;base64,<data>"}}, | ||
| ] | ||
| assert expected == _format_message_content(content) | ||
|
|
||
|
|
||
| def _bad_request_error(body: dict[str, Any], status_code: int = 400) -> Exception: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this have to be a private method? It looks like it, since this is not a unit test case but a helper for the remaining test cases?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, |
||
| """Build a `groq.BadRequestError` the way the SDK does for a 4xx response. | ||
|
|
||
| The Groq SDK formats the error message as ``Error code: <status> - <body>``, | ||
| so the JSON body (including its ``code`` field) ends up in ``str(e)``. This | ||
| mirrors `groq._base_client.BaseClient._make_status_error_from_response`. | ||
| """ | ||
| request = httpx.Request("POST", "https://api.groq.com/openai/v1/chat/completions") | ||
| response = httpx.Response(status_code=status_code, request=request) | ||
| message = f"Error code: {status_code} - {body}" | ||
| return groq.BadRequestError(message, response=response, body=body.get("error")) | ||
|
|
||
|
|
||
| # Verified shape of a real Groq context-overflow response. See | ||
| # https://console.groq.com/docs/errors and letta-ai/letta#1963. | ||
| _CONTEXT_OVERFLOW_BODY = { | ||
| "error": { | ||
| "message": "Please reduce the length of the messages or completion.", | ||
| "type": "invalid_request_error", | ||
| "param": "messages", | ||
| "code": "context_length_exceeded", | ||
| } | ||
| } | ||
|
Comment on lines
+1119
to
+1126
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am not sure whether this has to be a private variable. Also, it's in upper case, so I assume it is some kind of constant? That is a matter of convention, not necessarily a syntax error, though.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is a module-level test constant, so the upper-case name follows the PEP 8 convention for constants, and the underscore keeps it scoped to this test module rather than exported. It is fixed input data reused across the promotion tests, so treating it as a named constant rather than rebuilding it in each test keeps the cases readable. Happy to adjust the name if you would rather it not be a constant. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am OK with it if you are convinced it should stay the way it is. Besides, naming is not a major issue; it's more about readability for you and for other people like myself. |
||
|
|
||
|
|
||
| def test_context_overflow_error_invoke_sync() -> None: | ||
| """Context-length errors surface as `ContextOverflowError` on invoke.""" | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| mock_client = MagicMock() | ||
| mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY) | ||
| llm.client = mock_client | ||
|
|
||
| with pytest.raises(ContextOverflowError) as exc_info: | ||
| llm.invoke([HumanMessage(content="test")]) | ||
|
|
||
| assert "context_length_exceeded" in str(exc_info.value) | ||
| assert isinstance(exc_info.value, GroqContextOverflowError) | ||
|
|
||
|
|
||
| async def test_context_overflow_error_invoke_async() -> None: | ||
| """Context-length errors surface as `ContextOverflowError` on ainvoke.""" | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| mock_async = MagicMock() | ||
|
|
||
| async def _create(**_kwargs: Any) -> dict[str, Any]: | ||
| raise _bad_request_error(_CONTEXT_OVERFLOW_BODY) | ||
|
|
||
| mock_async.create = _create | ||
| llm.async_client = mock_async | ||
|
|
||
| with pytest.raises(ContextOverflowError) as exc_info: | ||
| await llm.ainvoke([HumanMessage(content="test")]) | ||
|
|
||
| assert "context_length_exceeded" in str(exc_info.value) | ||
| assert isinstance(exc_info.value, GroqContextOverflowError) | ||
|
|
||
|
|
||
| def test_context_overflow_error_stream_sync() -> None: | ||
| """Context-length errors surface as `ContextOverflowError` on stream.""" | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| mock_client = MagicMock() | ||
| mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY) | ||
| llm.client = mock_client | ||
|
|
||
| with pytest.raises(ContextOverflowError) as exc_info: | ||
| list(llm.stream([HumanMessage(content="test")])) | ||
|
|
||
| assert "context_length_exceeded" in str(exc_info.value) | ||
| assert isinstance(exc_info.value, GroqContextOverflowError) | ||
|
|
||
|
|
||
| async def test_context_overflow_error_stream_async() -> None: | ||
| """Context-length errors surface as `ContextOverflowError` on astream.""" | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| mock_async = MagicMock() | ||
|
|
||
| async def _create(**_kwargs: Any) -> Any: | ||
| raise _bad_request_error(_CONTEXT_OVERFLOW_BODY) | ||
|
|
||
| mock_async.create = _create | ||
| llm.async_client = mock_async | ||
|
|
||
| with pytest.raises(ContextOverflowError) as exc_info: | ||
| async for _ in llm.astream([HumanMessage(content="test")]): | ||
| pass | ||
|
|
||
| assert "context_length_exceeded" in str(exc_info.value) | ||
| assert isinstance(exc_info.value, GroqContextOverflowError) | ||
|
|
||
|
|
||
| def test_context_overflow_error_backwards_compatibility() -> None: | ||
| """`ContextOverflowError` is also catchable as `groq.BadRequestError`.""" | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| mock_client = MagicMock() | ||
| mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY) | ||
| llm.client = mock_client | ||
|
|
||
| with pytest.raises(groq.BadRequestError) as exc_info: | ||
| llm.invoke([HumanMessage(content="test")]) | ||
|
|
||
| assert isinstance(exc_info.value, groq.BadRequestError) | ||
| assert isinstance(exc_info.value, ContextOverflowError) | ||
|
|
||
|
|
||
| def test_unrelated_invalid_request_error_not_promoted() -> None: | ||
| """Unrelated `BadRequestError`s should stay a plain `BadRequestError`.""" | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| other_error = { | ||
| "error": { | ||
| "message": "Invalid value for 'temperature'.", | ||
| "type": "invalid_request_error", | ||
| "code": "invalid_value", | ||
| } | ||
| } | ||
| mock_client = MagicMock() | ||
| mock_client.create.side_effect = _bad_request_error(other_error) | ||
| llm.client = mock_client | ||
|
|
||
| with pytest.raises(groq.BadRequestError) as exc_info: | ||
| llm.invoke([HumanMessage(content="test")]) | ||
|
|
||
| assert not isinstance(exc_info.value, ContextOverflowError) | ||
|
|
||
|
|
||
| def test_context_overflow_error_carries_response_metadata() -> None: | ||
| """Promoted `GroqContextOverflowError` preserves `response`/`body`. | ||
|
|
||
| Downstream catchers that introspect `.response.status_code` rely on this. | ||
| """ | ||
| llm = ChatGroq(model="foo", max_retries=0) | ||
| mock_client = MagicMock() | ||
| mock_client.create.side_effect = _bad_request_error(_CONTEXT_OVERFLOW_BODY) | ||
| llm.client = mock_client | ||
|
|
||
| with pytest.raises(GroqContextOverflowError) as exc_info: | ||
| llm.invoke([HumanMessage(content="test")]) | ||
|
|
||
| assert exc_info.value.response.status_code == 400 | ||
| assert exc_info.value.body == _CONTEXT_OVERFLOW_BODY["error"] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious why this method is made private with a leading '_'. Is it so that others can't invoke this method, and hence it was made private?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the leading underscore marks it as module-private; it is an internal helper that the two
`BadRequestError` handlers call and is not part of the public API. That matches the existing convention in this file, where the other helpers are written the same way, for example `_get_default_model_profile`, `_create_chat_result`, and `_convert_dict_to_message`. Keeping it private also lets the implementation change later without it being a public surface anyone depends on.