Skip to content

Commit c779973

Browse files
committed
feat(mistralai): surface citation metadata
1 parent b3dff4a commit c779973

2 files changed

Lines changed: 148 additions & 4 deletions

File tree

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,62 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
141141
return base62_str.rjust(9, "0")
142142

143143

144+
def _extract_mistral_citations(content: Any) -> list[dict[str, Any]]:
145+
"""Extract Mistral reference blocks from content."""
146+
if not isinstance(content, list):
147+
return []
148+
return [
149+
{key: value for key, value in block.items() if key != "index"}
150+
for block in content
151+
if isinstance(block, dict) and block.get("type") == "reference"
152+
]
153+
154+
155+
def _normalize_mistral_assistant_content(
156+
raw_content: Any,
157+
) -> tuple[str | list[str | dict], list[dict[str, Any]]]:
158+
"""Normalize Mistral assistant content and extract citation blocks."""
159+
if not isinstance(raw_content, list):
160+
return raw_content or "", []
161+
162+
citations = _extract_mistral_citations(raw_content)
163+
if not citations:
164+
return cast("list[str | dict]", raw_content), []
165+
166+
text_parts: list[str] = []
167+
should_flatten = True
168+
for block in raw_content:
169+
if isinstance(block, str):
170+
text_parts.append(block)
171+
elif isinstance(block, dict):
172+
if block.get("type") == "reference" or (
173+
block.get("type") == "text" and set(block) <= {"type", "text"}
174+
):
175+
text = block.get("text")
176+
text_parts.append(text if isinstance(text, str) else str(text or ""))
177+
else:
178+
should_flatten = False
179+
else:
180+
should_flatten = False
181+
182+
if should_flatten:
183+
return "".join(text_parts), citations
184+
return cast("list[str | dict]", raw_content), citations
185+
186+
144187
def _convert_mistral_chat_message_to_message(
145188
_message: dict,
146189
) -> BaseMessage:
147190
role = _message["role"]
148191
if role != "assistant":
149192
msg = f"Expected role to be 'assistant', got {role}"
150193
raise ValueError(msg)
151-
# Mistral returns None for tool invocations
152-
content = _message.get("content", "") or ""
194+
# Mistral returns None for tool invocations. It can also return typed content
195+
# blocks for citations; keep the answer text backward compatible and surface
196+
# citation metadata separately.
197+
content, citations = _normalize_mistral_assistant_content(
198+
_message.get("content", "")
199+
)
153200

154201
additional_kwargs: dict = {}
155202
tool_calls = []
@@ -166,12 +213,15 @@ def _convert_mistral_chat_message_to_message(
166213
tool_calls.append(parsed)
167214
except Exception as e:
168215
invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
216+
response_metadata: dict[str, Any] = {"model_provider": "mistralai"}
217+
if citations:
218+
response_metadata["citations"] = citations
169219
return AIMessage(
170220
content=content,
171221
additional_kwargs=additional_kwargs,
172222
tool_calls=tool_calls,
173223
invalid_tool_calls=invalid_tool_calls,
174-
response_metadata={"model_provider": "mistralai"},
224+
response_metadata=response_metadata,
175225
)
176226

177227

@@ -255,6 +305,7 @@ def _convert_chunk_to_message_chunk(
255305
content = _delta.get("content") or ""
256306
if output_version == "v1" and isinstance(content, str):
257307
content = [{"type": "text", "text": content}]
308+
citations = _extract_mistral_citations(content)
258309
if isinstance(content, list):
259310
for block in content:
260311
if isinstance(block, dict):
@@ -273,7 +324,9 @@ def _convert_chunk_to_message_chunk(
273324
return HumanMessageChunk(content=content), index, index_type
274325
if role == "assistant" or default_class == AIMessageChunk:
275326
additional_kwargs: dict = {}
276-
response_metadata = {}
327+
response_metadata: dict[str, Any] = {}
328+
if citations:
329+
response_metadata["citations"] = citations
277330
if raw_tool_calls := _delta.get("tool_calls"):
278331
additional_kwargs["tool_calls"] = raw_tool_calls
279332
try:

libs/partners/mistralai/tests/unit_tests/test_chat_models.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from langchain_core.callbacks.base import BaseCallbackHandler
1111
from langchain_core.messages import (
1212
AIMessage,
13+
AIMessageChunk,
1314
BaseMessage,
1415
ChatMessage,
1516
HumanMessage,
@@ -21,6 +22,7 @@
2122

2223
from langchain_mistralai.chat_models import ( # type: ignore[import]
2324
ChatMistralAI,
25+
_convert_chunk_to_message_chunk,
2426
_convert_message_to_mistral_chat_message,
2527
_convert_mistral_chat_message_to_message,
2628
_convert_tool_call_id_to_mistral_compatible,
@@ -290,6 +292,95 @@ def test__convert_dict_to_message_with_missing_content() -> None:
290292
assert result == expected_output
291293

292294

295+
def test__convert_dict_to_message_with_citations() -> None:
296+
cited_text = "the temperature is 20 degrees C"
297+
expected_citation = {
298+
"type": "reference",
299+
"reference_ids": [0],
300+
"text": cited_text,
301+
}
302+
citation_content = [
303+
{"type": "text", "text": "According to the document, "},
304+
expected_citation,
305+
{"type": "text", "text": " on average."},
306+
]
307+
message = {"role": "assistant", "content": citation_content}
308+
result = _convert_mistral_chat_message_to_message(message)
309+
310+
assert result.content == (
311+
"According to the document, the temperature is 20 degrees C on average."
312+
)
313+
assert result.response_metadata["model_provider"] == "mistralai"
314+
assert result.response_metadata["citations"] == [expected_citation]
315+
316+
317+
def test_create_chat_result_with_citations() -> None:
318+
chat = ChatMistralAI()
319+
expected_citation = {"type": "reference", "reference_ids": [0], "text": "42"}
320+
response = {
321+
"choices": [
322+
{
323+
"message": {
324+
"role": "assistant",
325+
"content": [
326+
{"type": "text", "text": "The answer is "},
327+
expected_citation,
328+
{"type": "text", "text": "."},
329+
],
330+
},
331+
"finish_reason": "stop",
332+
}
333+
]
334+
}
335+
336+
result = chat._create_chat_result(response)
337+
message = result.generations[0].message
338+
339+
assert message.content == "The answer is 42."
340+
assert message.response_metadata["citations"] == [expected_citation]
341+
342+
343+
def test__convert_chunk_to_message_chunk_with_citations() -> None:
344+
expected_citation = {"type": "reference", "reference_ids": [0], "text": "42"}
345+
text_chunk = {
346+
"choices": [
347+
{
348+
"delta": {"role": "assistant", "content": "The answer is "},
349+
"finish_reason": None,
350+
}
351+
],
352+
}
353+
reference_chunk = {
354+
"choices": [
355+
{
356+
"delta": {
357+
"role": "assistant",
358+
"content": [
359+
dict(expected_citation),
360+
],
361+
},
362+
"finish_reason": "stop",
363+
}
364+
],
365+
"model": "mistral-small-latest",
366+
}
367+
368+
result_1, index, index_type = _convert_chunk_to_message_chunk(
369+
text_chunk, AIMessageChunk, -1, "", None
370+
)
371+
result_2, _, _ = _convert_chunk_to_message_chunk(
372+
reference_chunk, AIMessageChunk, index, index_type, None
373+
)
374+
375+
assert isinstance(result_2, AIMessageChunk)
376+
assert result_2.response_metadata["citations"] == [expected_citation]
377+
378+
full = result_1 + result_2
379+
assert isinstance(full, AIMessageChunk)
380+
assert full.response_metadata["citations"] == [expected_citation]
381+
assert full.response_metadata["finish_reason"] == "stop"
382+
383+
293384
def test_custom_token_counting() -> None:
294385
def token_encoder(text: str) -> list[int]:
295386
return [1, 2, 3]

0 commit comments

Comments
 (0)