Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,62 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
return base62_str.rjust(9, "0")


def _extract_mistral_citations(content: Any) -> list[dict[str, Any]]:
"""Extract Mistral reference blocks from content."""
if not isinstance(content, list):
return []
return [
{key: value for key, value in block.items() if key != "index"}
for block in content
if isinstance(block, dict) and block.get("type") == "reference"
]


def _normalize_mistral_assistant_content(
raw_content: Any,
) -> tuple[str | list[str | dict], list[dict[str, Any]]]:
"""Normalize Mistral assistant content and extract citation blocks."""
if not isinstance(raw_content, list):
return raw_content or "", []

citations = _extract_mistral_citations(raw_content)
if not citations:
return cast("list[str | dict]", raw_content), []

text_parts: list[str] = []
should_flatten = True
for block in raw_content:
if isinstance(block, str):
text_parts.append(block)
elif isinstance(block, dict):
if block.get("type") == "reference" or (
block.get("type") == "text" and set(block) <= {"type", "text"}
):
text = block.get("text")
text_parts.append(text if isinstance(text, str) else str(text or ""))
else:
should_flatten = False
else:
should_flatten = False

if should_flatten:
return "".join(text_parts), citations
return cast("list[str | dict]", raw_content), citations


def _convert_mistral_chat_message_to_message(
_message: dict,
) -> BaseMessage:
role = _message["role"]
if role != "assistant":
msg = f"Expected role to be 'assistant', got {role}"
raise ValueError(msg)
# Mistral returns None for tool invocations
content = _message.get("content", "") or ""
# Mistral returns None for tool invocations. It can also return typed content
# blocks for citations; keep the answer text backward compatible and surface
# citation metadata separately.
content, citations = _normalize_mistral_assistant_content(
_message.get("content", "")
)

additional_kwargs: dict = {}
tool_calls = []
Expand All @@ -166,12 +213,15 @@ def _convert_mistral_chat_message_to_message(
tool_calls.append(parsed)
except Exception as e:
invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
response_metadata: dict[str, Any] = {"model_provider": "mistralai"}
if citations:
response_metadata["citations"] = citations
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata={"model_provider": "mistralai"},
response_metadata=response_metadata,
)


Expand Down Expand Up @@ -255,6 +305,7 @@ def _convert_chunk_to_message_chunk(
content = _delta.get("content") or ""
if output_version == "v1" and isinstance(content, str):
content = [{"type": "text", "text": content}]
citations = _extract_mistral_citations(content)
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
Expand All @@ -273,7 +324,9 @@ def _convert_chunk_to_message_chunk(
return HumanMessageChunk(content=content), index, index_type
if role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: dict = {}
response_metadata = {}
response_metadata: dict[str, Any] = {}
if citations:
response_metadata["citations"] = citations
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
Expand Down
91 changes: 91 additions & 0 deletions libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
Expand All @@ -21,6 +22,7 @@

from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
_convert_chunk_to_message_chunk,
_convert_message_to_mistral_chat_message,
_convert_mistral_chat_message_to_message,
_convert_tool_call_id_to_mistral_compatible,
Expand Down Expand Up @@ -290,6 +292,95 @@ def test__convert_dict_to_message_with_missing_content() -> None:
assert result == expected_output


def test__convert_dict_to_message_with_citations() -> None:
    """Text and reference blocks flatten to a string; citations go to metadata."""
    cited_text = "the temperature is 20 degrees C"
    expected_citation = {
        "type": "reference",
        "reference_ids": [0],
        "text": cited_text,
    }
    citation_content = [
        {"type": "text", "text": "According to the document, "},
        expected_citation,
        {"type": "text", "text": " on average."},
    ]
    message = {"role": "assistant", "content": citation_content}
    result = _convert_mistral_chat_message_to_message(message)

    # The answer text is the in-order concatenation of every block's text.
    assert result.content == (
        "According to the document, the temperature is 20 degrees C on average."
    )
    assert result.response_metadata["model_provider"] == "mistralai"
    assert result.response_metadata["citations"] == [expected_citation]


def test_create_chat_result_with_citations() -> None:
    """`_create_chat_result` exposes flattened text and citation metadata."""
    chat = ChatMistralAI()
    expected_citation = {"type": "reference", "reference_ids": [0], "text": "42"}
    response = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": "The answer is "},
                        expected_citation,
                        {"type": "text", "text": "."},
                    ],
                },
                "finish_reason": "stop",
            }
        ]
    }

    result = chat._create_chat_result(response)
    message = result.generations[0].message

    assert message.content == "The answer is 42."
    assert message.response_metadata["citations"] == [expected_citation]


def test__convert_chunk_to_message_chunk_with_citations() -> None:
    """Citations from a streamed reference chunk survive chunk aggregation."""
    expected_citation = {"type": "reference", "reference_ids": [0], "text": "42"}
    text_chunk = {
        "choices": [
            {
                "delta": {"role": "assistant", "content": "The answer is "},
                "finish_reason": None,
            }
        ],
    }
    reference_chunk = {
        "choices": [
            {
                "delta": {
                    "role": "assistant",
                    "content": [
                        # Pass a copy so `expected_citation` stays untouched
                        # for the equality checks below.
                        dict(expected_citation),
                    ],
                },
                "finish_reason": "stop",
            }
        ],
        "model": "mistral-small-latest",
    }

    result_1, index, index_type = _convert_chunk_to_message_chunk(
        text_chunk, AIMessageChunk, -1, "", None
    )
    # Feed the index state returned by the first chunk into the second call.
    result_2, _, _ = _convert_chunk_to_message_chunk(
        reference_chunk, AIMessageChunk, index, index_type, None
    )

    assert isinstance(result_2, AIMessageChunk)
    assert result_2.response_metadata["citations"] == [expected_citation]

    full = result_1 + result_2
    assert isinstance(full, AIMessageChunk)
    assert full.response_metadata["citations"] == [expected_citation]
    assert full.response_metadata["finish_reason"] == "stop"


def test_custom_token_counting() -> None:
def token_encoder(text: str) -> list[int]:
return [1, 2, 3]
Expand Down
Loading