
Commit 18d16c8

Added finish_reason to LLMResult (#410)
1 parent 3267b50 commit 18d16c8

File tree

3 files changed: +33 −9 lines

packages/lmi/src/lmi/llms.py
packages/lmi/src/lmi/types.py
packages/lmi/tests/test_llms.py

packages/lmi/src/lmi/llms.py

Lines changed: 6 additions & 0 deletions

@@ -1044,6 +1044,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult
                     cost=cost,
                     system_fingerprint=completions.system_fingerprint,
                     reasoning_content=reasoning_content,
+                    finish_reason=choice.finish_reason,
                 )
             )
         return results
@@ -1081,6 +1082,7 @@ async def acompletion_iter(
         outputs = []
         logprobs = []
         role = None
+        finish_reason: str | None = None
         reasoning_content = []
         used_model = None
         async for completion in stream_completions:
@@ -1094,6 +1096,9 @@ async def acompletion_iter(
                 logprobs.append(logprob_content[0].logprob or 0)
             outputs.append(delta.content or "")
             role = delta.role or role
+            # The usage-only chunk (when include_usage=True) has finish_reason=None,
+            # so retain the last non-None finish_reason value
+            finish_reason = choice.finish_reason or finish_reason
             if hasattr(delta, "reasoning_content"):
                 reasoning_content.append(delta.reasoning_content or "")
         text = "".join(outputs)
@@ -1123,6 +1128,7 @@ async def acompletion_iter(
             cache_read_tokens=cache_read,
             cache_creation_tokens=cache_creation,
             cost=cost,
+            finish_reason=finish_reason,
         )

         if text:
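For context on the retention logic in the third hunk above, here is a minimal, self-contained sketch (using a hypothetical Choice dataclass, not lmi's or litellm's actual stream types) of why `finish_reason = choice.finish_reason or finish_reason` is needed: when include_usage=True, OpenAI-style streams end with a usage-only chunk whose finish_reason is None, so naively keeping the last chunk's value would discard the real stop reason.

```python
from dataclasses import dataclass


@dataclass
class Choice:
    finish_reason: str | None


# Hypothetical stream: two content chunks, the terminating chunk,
# then the usage-only chunk (include_usage=True) with finish_reason=None
chunks = [Choice(None), Choice(None), Choice("stop"), Choice(None)]

finish_reason: str | None = None
for choice in chunks:
    # Retain the last non-None value, as acompletion_iter does above
    finish_reason = choice.finish_reason or finish_reason

assert finish_reason == "stop"  # "last chunk wins" would have yielded None
```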

packages/lmi/src/lmi/types.py

Lines changed: 7 additions & 0 deletions

@@ -123,6 +123,13 @@ class LLMResult(BaseModel):
     reasoning_content: str | None = Field(
         default=None, description="Reasoning content from LLMs such as DeepSeek-R1."
     )
+    finish_reason: str | None = Field(
+        default=None,
+        description=(
+            "The reason the model stopped generating tokens, or None if not available."
+        ),
+        examples=["stop", "length", "tool_calls", "refusal"],
+    )

     def __str__(self) -> str:
         return self.text or ""
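A quick sketch of how the new field behaves, using a pared-down stand-in model (only the fields visible in this hunk; the real LLMResult defines many more):

```python
from pydantic import BaseModel, Field


class LLMResultSketch(BaseModel):
    """Stand-in carrying only the two fields from the hunk above."""

    reasoning_content: str | None = Field(
        default=None, description="Reasoning content from LLMs such as DeepSeek-R1."
    )
    finish_reason: str | None = Field(
        default=None,
        description=(
            "The reason the model stopped generating tokens, or None if not available."
        ),
        examples=["stop", "length", "tool_calls", "refusal"],
    )


# Defaults to None when a backend doesn't report a stop reason
assert LLMResultSketch().finish_reason is None
# Otherwise it carries the provider's value, e.g. "length" when max_tokens is hit
assert LLMResultSketch(finish_reason="length").finish_reason == "length"
# The examples metadata surfaces in the generated JSON schema
schema = LLMResultSketch.model_json_schema()
assert "stop" in schema["properties"]["finish_reason"]["examples"]
```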

packages/lmi/tests/test_llms.py

Lines changed: 20 additions & 9 deletions

@@ -381,12 +381,14 @@ def accum(x) -> None:
     assert completion.completion_count > 0
     assert str(completion) == "".join(outputs)
     assert completion.cost > 0
+    assert completion.finish_reason == "stop"

     completion = await llm.call_single(
         messages=messages,
     )
     assert completion.seconds_to_last_token > 0
     assert completion.cost > 0
+    assert completion.finish_reason == "stop"

     # check with mixed callbacks
     async def ac(x) -> None:
@@ -507,13 +509,15 @@ def _build_mock_completion(
     delta_content: str = "",
     delta_reasoning_content: str = "hmmm",
     delta_role: str = "assistant",
+    finish_reason: str = "unknown",
     usage: Any = None,
 ) -> Mock:
     return Mock(
         model=model,
         choices=[
             Mock(
                 logprobs=logprobs,
+                finish_reason=finish_reason,
                 delta=Mock(
                     content=delta_content,
                     reasoning_content=delta_reasoning_content,
@@ -545,9 +549,10 @@ def _build_mock_completion(
         logprobs=Mock(content=[Mock(logprob=-0.5)])
     )

-    # Mock completion with usage info
+    # Mock completion with usage info (final chunk has finish_reason)
     mock_completion_usage = _build_mock_completion(
-        usage=Mock(prompt_tokens=10, completion_tokens=5)
+        usage=Mock(prompt_tokens=10, completion_tokens=5),
+        finish_reason="stop",
     )

     # Create async generator that yields mock completions
@@ -576,6 +581,7 @@ async def mock_stream_iter():  # noqa: RUF029
     assert result.logprob == -0.5
     assert result.prompt_count == 10
     assert result.completion_count == 5
+    assert result.finish_reason == "stop"


 class DummyOutputSchema(BaseModel):
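The streaming test above drives the consumer with unittest.mock objects. A condensed sketch of that pattern (build_chunk here is a local stand-in mirroring the shape _build_mock_completion produces, not the test module's actual helper): content chunks default to finish_reason="unknown", the final usage chunk carries "stop", and the or-retention in the consumer keeps "stop".

```python
from unittest.mock import Mock


def build_chunk(
    content: str = "", finish_reason: str | None = "unknown", usage: Mock | None = None
) -> Mock:
    # Mirrors what the streaming loop reads: chunk.choices[0].delta,
    # chunk.choices[0].finish_reason, and chunk.usage
    return Mock(
        model="test-model",
        usage=usage,
        choices=[
            Mock(
                finish_reason=finish_reason,
                delta=Mock(content=content, role="assistant"),
            )
        ],
    )


stream = [
    build_chunk(content="Hello "),
    build_chunk(content="world"),
    # Final chunk with usage info, carrying the definitive stop reason
    build_chunk(usage=Mock(prompt_tokens=10, completion_tokens=5), finish_reason="stop"),
]

finish_reason = None
outputs = []
for chunk in stream:
    choice = chunk.choices[0]
    outputs.append(choice.delta.content or "")
    finish_reason = choice.finish_reason or finish_reason

assert "".join(outputs) == "Hello world"
assert finish_reason == "stop"
```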
@@ -787,6 +793,7 @@ async def test_single_completion(self, model_name: str) -> None:
         assert len(result.messages) == 1
         assert result.messages[0].content
         assert not hasattr(result.messages[0], "tool_calls"), "Expected normal message"
+        assert result.finish_reason == "stop"

         model = self.MODEL_CLS(name=model_name, config={"n": 2})
         result = await model.call_single(messages)
@@ -795,6 +802,7 @@ async def test_single_completion(self, model_name: str) -> None:
         assert len(result.messages) == 1
         assert result.messages[0].content
         assert not hasattr(result.messages[0], "tool_calls"), "Expected normal message"
+        assert result.finish_reason == "stop"

     @pytest.mark.asyncio
     @pytest.mark.vcr
@@ -857,17 +865,17 @@ def double(x: int) -> int:
             messages, tools=tools, tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL
         )
         assert isinstance(results, list)
-        assert isinstance(results[0].messages, list)
-
-        tool_message = results[0].messages[0]
-
+        (result,) = results
+        assert isinstance(result.messages, list)
+        tool_message = result.messages[0]
         assert isinstance(tool_message, ToolRequestMessage), (
             "It should have selected a tool"
         )
         assert not tool_message.content
         assert tool_message.tool_calls[0].function.arguments["x"] == 8, (
             "LLM failed in select the correct tool or arguments"
         )
+        assert result.finish_reason == "tool_calls"

         # Simulate the observation
         observation = ToolResponseMessage(
@@ -882,9 +890,11 @@ def double(x: int) -> int:
             messages, tools=tools, tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL
         )
         assert isinstance(results, list)
-        assert isinstance(results[0].messages, list)
-        assert results[0].messages[0].content
-        assert "16" in results[0].messages[0].content
+        (result,) = results
+        assert isinstance(result.messages, list)
+        assert result.messages[0].content
+        assert "16" in result.messages[0].content
+        assert result.finish_reason == "stop"

     @pytest.mark.asyncio
     @pytest.mark.parametrize(
@@ -1201,6 +1211,7 @@ def mock_router_method(_self, _override_config=None):

     assert results.text == "I'm sorry, but I can't assist with that request."
     assert results.model == CommonLLMNames.GPT_41.value
+    assert results.finish_reason == "stop"
     assert "the llm request was refused" in caplog.text.lower()
     assert "attempting to fallback" in caplog.text.lower()