integrations/langchain/src/databricks_langchain/chat_models.py (54 additions, 6 deletions)
@@ -715,6 +715,23 @@ def _convert_response_to_chat_result(self, response: ChatCompletion) -> ChatResult:
                 "completion_tokens": response.usage.completion_tokens,
                 "total_tokens": response.usage.total_tokens,
             }
+            # Anthropic cache tokens (extra fields via Pydantic extra='allow')
+            cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
+            cache_read = getattr(response.usage, "cache_read_input_tokens", None)
+            if cache_creation is not None:
+                llm_output["usage"]["cache_creation_input_tokens"] = cache_creation
+                llm_output["cache_creation_input_tokens"] = cache_creation
+            if cache_read is not None:
+                llm_output["usage"]["cache_read_input_tokens"] = cache_read
+                llm_output["cache_read_input_tokens"] = cache_read
+            # OpenAI cache tokens (standard field)
+            if response.usage.prompt_tokens_details:
+                _cached = getattr(
+                    response.usage.prompt_tokens_details, "cached_tokens", None
+                )
+                if _cached is not None:
+                    llm_output["usage"]["cached_tokens"] = _cached
+                    llm_output["cached_tokens"] = _cached
             # Add individual token counts for backwards compatibility with tests
             llm_output["prompt_tokens"] = response.usage.prompt_tokens
             llm_output["completion_tokens"] = response.usage.completion_tokens
@@ -772,11 +789,25 @@ def _extract_completion_usage_from_chunk(
         input_tokens = getattr(chunk.usage, "prompt_tokens", None)
         output_tokens = getattr(chunk.usage, "completion_tokens", None)
         if input_tokens is not None and output_tokens is not None:
-            return {
+            usage_dict: Dict[str, int] = {
                 "input_tokens": input_tokens,
                 "output_tokens": output_tokens,
                 "total_tokens": input_tokens + output_tokens,
             }
+            # Anthropic cache tokens (extra fields via Pydantic extra='allow')
+            cache_creation = getattr(chunk.usage, "cache_creation_input_tokens", None)
+            cache_read = getattr(chunk.usage, "cache_read_input_tokens", None)
+            if cache_creation is not None:
+                usage_dict["cache_creation_input_tokens"] = cache_creation
+            if cache_read is not None:
+                usage_dict["cache_read_input_tokens"] = cache_read
+            # OpenAI cache tokens (standard field)
+            prompt_details = getattr(chunk.usage, "prompt_tokens_details", None)
+            if prompt_details:
+                _cached = getattr(prompt_details, "cached_tokens", None)
+                if _cached is not None:
+                    usage_dict["cached_tokens"] = _cached
+            return usage_dict
         return None

     def _build_usage_chunk_from_completions(
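For a Claude-style final chunk, the updated helper returns a dict shaped like the sketch below (illustrative numbers; an OpenAI-style chunk would instead carry a single "cached_tokens" entry):

    usage_dict = {
        "input_tokens": 150,
        "output_tokens": 50,
        "total_tokens": 200,
        "cache_creation_input_tokens": 20,
        "cache_read_input_tokens": 30,
    }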
@@ -801,14 +832,31 @@ def _build_usage_chunk_from_completions(
                     )
                 )
         else:
+            input_token_details = {}
+            cache_creation = usage.get("cache_creation_input_tokens")
+            cache_read = usage.get("cache_read_input_tokens")
+            cached_tokens = usage.get("cached_tokens")
+            if cache_creation is not None:
+                input_token_details["cache_creation"] = cache_creation
+            if cache_read is not None:
+                input_token_details["cache_read"] = cache_read
+            if cached_tokens is not None:
+                input_token_details["cache_read"] = cached_tokens
+
+            usage_metadata_kwargs = {
+                "input_tokens": usage["input_tokens"],
+                "output_tokens": usage["output_tokens"],
+                "total_tokens": usage["total_tokens"],
+            }
+            if input_token_details:
+                usage_metadata_kwargs["input_token_details"] = InputTokenDetails(
+                    **input_token_details
+                )
+
             return ChatGenerationChunk(
                 message=AIMessageChunk(
                     content="",
-                    usage_metadata=UsageMetadata(
-                        input_tokens=usage["input_tokens"],
-                        output_tokens=usage["output_tokens"],
-                        total_tokens=usage["total_tokens"],
-                    ),
+                    usage_metadata=UsageMetadata(**usage_metadata_kwargs),
                 )
             )
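Together, the two streaming hunks surface cache counts to callers through UsageMetadata's input_token_details; note that the hunk above folds OpenAI's cached_tokens into the "cache_read" key. A hedged consumption sketch (the endpoint name is a placeholder and valid Databricks credentials are assumed):

    from databricks_langchain import ChatDatabricks
    from langchain_core.messages import HumanMessage

    llm = ChatDatabricks(model="databricks-claude-sonnet-4-5")  # placeholder endpoint
    for chunk in llm.stream([HumanMessage(content="Hello")], stream_usage=True):
        if chunk.usage_metadata is not None:
            # UsageMetadata is a TypedDict, so plain dict access works here
            details = chunk.usage_metadata.get("input_token_details", {})
            print(details.get("cache_creation"), details.get("cache_read"))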
integrations/langchain/tests/unit_tests/test_chat_models.py (127 additions, 0 deletions)
@@ -259,6 +259,9 @@ def test_chat_model_stream_usage_chunk_emission():
     mock_usage = Mock()
     mock_usage.prompt_tokens = 10
     mock_usage.completion_tokens = 5
+    mock_usage.cache_creation_input_tokens = None
+    mock_usage.cache_read_input_tokens = None
+    mock_usage.prompt_tokens_details = None

     mock_chunks = [
         Mock(
@@ -310,6 +313,9 @@ def test_chat_model_stream_no_duplicate_usage_chunks():
     mock_usage = Mock()
     mock_usage.prompt_tokens = 20
     mock_usage.completion_tokens = 8
+    mock_usage.cache_creation_input_tokens = None
+    mock_usage.cache_read_input_tokens = None
+    mock_usage.prompt_tokens_details = None

     # Multiple chunks with usage data to test the duplicate prevention logic
     mock_chunks = [
@@ -367,6 +373,9 @@ def test_chat_model_stream_usage_only_final_chunk():
     mock_usage = Mock()
     mock_usage.prompt_tokens = 15
     mock_usage.completion_tokens = 10
+    mock_usage.cache_creation_input_tokens = None
+    mock_usage.cache_read_input_tokens = None
+    mock_usage.prompt_tokens_details = None

     # Simulate GPT-5 streaming behavior: content chunks followed by usage-only chunk
     mock_chunks = [
@@ -865,6 +874,75 @@ def test_convert_response_to_chat_result_llm_output(llm: ChatDatabricks) -> None:
     assert usage_metadata["total_tokens"] == _MOCK_CHAT_RESPONSE["usage"]["total_tokens"]


+def test_convert_response_to_chat_result_anthropic_cache_tokens(llm: ChatDatabricks) -> None:
+    """Test that _convert_response_to_chat_result includes Anthropic cache tokens in llm_output."""
+    message = ChatCompletionMessage(role="assistant", content="Hello", tool_calls=None)
+    choice = Choice(index=0, message=message, finish_reason="stop", logprobs=None)
+    usage = _create_claude_completion_usage()
+    response = ChatCompletion(
+        id=_MOCK_CHAT_RESPONSE["id"],
+        choices=[choice],
+        created=_MOCK_CHAT_RESPONSE["created"],
+        model="databricks-claude-sonnet-4-5",
+        object="chat.completion",
+        usage=usage,
+    )
+
+    result = llm._convert_response_to_chat_result(response)
+
+    # Verify Anthropic cache tokens in llm_output
+    assert result.llm_output["usage"]["cache_creation_input_tokens"] == 20
+    assert result.llm_output["usage"]["cache_read_input_tokens"] == 30
+    assert result.llm_output["cache_creation_input_tokens"] == 20
+    assert result.llm_output["cache_read_input_tokens"] == 30
+
+
+def test_convert_response_to_chat_result_openai_cache_tokens(llm: ChatDatabricks) -> None:
+    """Test that _convert_response_to_chat_result includes OpenAI cache tokens in llm_output."""
+    message = ChatCompletionMessage(role="assistant", content="Hello", tool_calls=None)
+    choice = Choice(index=0, message=message, finish_reason="stop", logprobs=None)
+    usage = _create_openai_completion_usage()
+    response = ChatCompletion(
+        id=_MOCK_CHAT_RESPONSE["id"],
+        choices=[choice],
+        created=_MOCK_CHAT_RESPONSE["created"],
+        model="gpt-4o",
+        object="chat.completion",
+        usage=usage,
+    )
+
+    result = llm._convert_response_to_chat_result(response)
+
+    # Verify OpenAI cache tokens in llm_output
+    assert result.llm_output["usage"]["cached_tokens"] == 20
+    assert result.llm_output["cached_tokens"] == 20
+
+
+def test_convert_response_to_chat_result_no_cache_tokens(llm: ChatDatabricks) -> None:
+    """Test that _convert_response_to_chat_result works without cache tokens."""
+    message = ChatCompletionMessage(role="assistant", content="Hello", tool_calls=None)
+    choice = Choice(index=0, message=message, finish_reason="stop", logprobs=None)
+    usage = CompletionUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
+    response = ChatCompletion(
+        id=_MOCK_CHAT_RESPONSE["id"],
+        choices=[choice],
+        created=_MOCK_CHAT_RESPONSE["created"],
+        model="test-model",
+        object="chat.completion",
+        usage=usage,
+    )
+
+    result = llm._convert_response_to_chat_result(response)
+
+    # Verify no cache tokens in llm_output
+    assert "cache_creation_input_tokens" not in result.llm_output["usage"]
+    assert "cache_read_input_tokens" not in result.llm_output["usage"]
+    assert "cached_tokens" not in result.llm_output["usage"]
+    assert "cache_creation_input_tokens" not in result.llm_output
+    assert "cache_read_input_tokens" not in result.llm_output
+    assert "cached_tokens" not in result.llm_output
+
+
 def test_convert_lc_messages_to_responses_api_basic():
     """Test _convert_lc_messages_to_responses_api with basic messages."""
     messages: list[BaseMessage] = [
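The tests above rely on two usage-factory helpers, _create_claude_completion_usage() and _create_openai_completion_usage(), whose definitions are collapsed out of this diff view. A hypothetical reconstruction consistent with the assertions (prompt_tokens reported pre-summed as 150 = 100 + 30 + 20 for Claude; cached_tokens=20 for OpenAI):

    from openai.types.completion_usage import CompletionUsage, PromptTokensDetails


    def _create_claude_completion_usage() -> CompletionUsage:
        # Hypothetical: cache counts ride along as extra fields (the OpenAI
        # SDK models allow extras); prompt_tokens already includes them.
        return CompletionUsage(
            prompt_tokens=150,
            completion_tokens=50,
            total_tokens=200,
            cache_creation_input_tokens=20,
            cache_read_input_tokens=30,
        )


    def _create_openai_completion_usage() -> CompletionUsage:
        # Hypothetical: OpenAI reports cached tokens via the standard
        # prompt_tokens_details field.
        return CompletionUsage(
            prompt_tokens=100,
            completion_tokens=50,
            total_tokens=150,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=20),
        )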
@@ -2060,6 +2138,55 @@ def test_chat_databricks_stream_with_detailed_usage_metadata():
     assert usage_metadata["output_token_details"]["reasoning"] == 10


+def test_chat_databricks_stream_with_claude_cache_tokens():
+    """Test streaming with stream_usage=True includes Claude cache token details."""
+    with patch("databricks_langchain.chat_models.get_openai_client") as mock_get_client:
+        mock_client = Mock()
+        mock_get_client.return_value = mock_client
+
+        mock_chunk1 = Mock()
+        mock_chunk1.choices = [Mock()]
+        mock_chunk1.choices[0].delta.model_dump.return_value = {
+            "role": "assistant",
+            "content": "Hello",
+        }
+        mock_chunk1.choices[0].finish_reason = None
+        mock_chunk1.choices[0].logprobs = None
+        mock_chunk1.usage = None
+
+        # Final chunk with Claude-style usage (cache tokens as extra fields)
+        claude_usage = _create_claude_completion_usage()
+        mock_chunk2 = Mock()
+        mock_chunk2.choices = [Mock()]
+        mock_chunk2.choices[0].delta.model_dump.return_value = {
+            "role": "assistant",
+            "content": " world",
+        }
+        mock_chunk2.choices[0].finish_reason = "stop"
+        mock_chunk2.choices[0].logprobs = None
+        mock_chunk2.usage = claude_usage
+
+        mock_client.chat.completions.create.return_value = iter([mock_chunk1, mock_chunk2])
+
+        llm = ChatDatabricks(model="databricks-claude-sonnet-4-5")
+        chunks = list(llm.stream([HumanMessage(content="Hello")], stream_usage=True))
+
+        # Find usage chunk
+        usage_chunks = [
+            chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
+        ]
+        assert len(usage_chunks) == 1
+
+        usage_chunk = usage_chunks[0]
+        usage_metadata = usage_chunk.usage_metadata
+        assert usage_metadata is not None
+        # Claude sums prompt_tokens + cache_read + cache_creation for input_tokens
+        assert usage_metadata["input_tokens"] == 150  # 100 + 30 + 20
+        assert usage_metadata["output_tokens"] == 50
+        assert usage_metadata["input_token_details"]["cache_read"] == 30
+        assert usage_metadata["input_token_details"]["cache_creation"] == 20
+
+
 def test_chat_databricks_responses_api_invoke_returns_usage_metadata():
     """Test that responses API invoke returns AIMessage with usage_metadata."""
     with patch("databricks_langchain.chat_models.get_openai_client") as mock_get_client: