Skip to content

Commit 418aa72

Browse files
committed
Refactor LLM chats: separate streaming logic and enforce strict typing
Major refactor of the LLM chat architecture to improve code organization, maintainability, and type safety.

Key Changes:
- Split `LLMChat` subclasses into distinct non-streaming and streaming implementations. Streaming logic (used primarily in notebooks) was complicating the core classes; this split makes the primary actors more concise and less error-prone.
- Moved provider-specific implementations into separate files: `openai.py` and `genai.py`.
- Replaced the generic `LLMResponse` with a strictly typed version that enforces types for `tool_usage` and `token_usage`.
- Updated the `invoke` method to accept explicit arguments.
- Migrated the OpenAI integration from the `completion` API to the more user-friendly `responses` API.

Testing:
- Added coverage for common use cases using real APIs (tests run only if the required environment keys are present).
1 parent 258c82c commit 418aa72

30 files changed

+2383
-1055
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2025 Kaggle Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# %% [markdown]
16+
# ---
17+
# title: Example of a game that requires tool use.
18+
# ---
19+
20+
# %%
21+
import random
22+
23+
import kaggle_benchmarks as kbench
24+
from kaggle_benchmarks.kaggle import models
25+
26+
SECRET_NUMBER = random.randint(1, 10)
27+
28+
29+
def guess_number(guess: int) -> str:
    """Make a guess in the number guessing game."""
    # The docstring above is kept verbatim: this function is handed to the
    # model via `tools=[guess_number]`, and the docstring may be surfaced to
    # the model as the tool description (assumption - confirm in the library).
    #
    # The returned hint tells the caller which way to move: "Higher" means
    # the guess was below SECRET_NUMBER, "Lower" means it was above it.
    if guess < SECRET_NUMBER:
        return "Higher"
    if guess > SECRET_NUMBER:
        return "Lower"
    # Neither below nor above the secret: the guess matched.
    return "Correct!"
37+
38+
39+
@kbench.task(name="guess-the-number-game")
def play_game(llm):
    # Task: ask the model to find SECRET_NUMBER while giving it access to the
    # `guess_number` tool. After the opening prompt the model is re-prompted
    # at most four more times, stopping early once its integer answer matches
    # the secret; the final answer is then asserted equal to the secret.
    opening = "I'm thinking of a number between 1 and 10. Can you guess it?"
    answer = llm.prompt(opening, schema=int, tools=[guess_number])

    retries_left = 4
    while retries_left > 0 and answer != SECRET_NUMBER:
        # The previous answer is sent back as the next prompt.
        answer = llm.prompt(answer, schema=int, tools=[guess_number])
        retries_left -= 1

    kbench.assertions.assert_equal(
        SECRET_NUMBER,
        answer,
        expectation=f"LLM should have guessed the secret number. The secret number was {SECRET_NUMBER}",
    )
54+
55+
56+
# %%
57+
58+
llm_with_genai_api = models.load_model(
59+
model_name=kbench.llm.name,
60+
api="genai",
61+
)
62+
63+
play_game.run(llm=llm_with_genai_api)
64+
65+
# %%
66+
67+
llm_with_openai_api = models.load_model(
68+
model_name=kbench.llm.name,
69+
api="openai",
70+
)
71+
72+
play_game.run(llm_with_openai_api)
73+
74+
# %%

documentation/examples/prompt_with_tools.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

documentation/examples/use_calculator_tool.py

Lines changed: 10 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,16 @@
1414

1515
# %% [markdown]
1616
# ---
17-
# title: Manual Calculator Tool Calling
17+
# title: Calculator Tool
1818
# ---
1919
# %%
20-
import json
21-
22-
from kaggle_benchmarks import actors, assertions, llm, messages, task
20+
from kaggle_benchmarks import actors, assertions, llm, task
2321

2422
tool = actors.Actor(name="Tool", role="tool", avatar="🛠️")
2523

2624

2725
def run_simple_calculator(a: float, b: float, operator: str) -> float:
26+
"""Calculates the result of an arithmetic operation like +, -, *, or /."""
2827
if operator == "+":
2928
return a + b
3029
if operator == "-":
@@ -37,72 +36,23 @@ def run_simple_calculator(a: float, b: float, operator: str) -> float:
3736

3837

3938
@task("Calculator Tool Use")
40-
def use_calculator(
41-
llm, problem: str, expected_answer: float, stream_mode: bool = False
42-
) -> None:
43-
calculator_tool = {
44-
"type": "function",
45-
"function": {
46-
"name": "simple_calculator",
47-
"description": "Calculates the result of an arithmetic operation.",
48-
"parameters": {
49-
"type": "object",
50-
"properties": {
51-
"a": {"type": "number", "description": "The first number."},
52-
"b": {"type": "number", "description": "The second number."},
53-
"operator": {
54-
"type": "string",
55-
"description": "The operator (+, -, *, /).",
56-
},
57-
},
58-
"required": ["a", "b", "operator"],
59-
},
60-
},
61-
}
62-
llm.stream_responses = stream_mode
63-
64-
actors.user.send(problem)
65-
66-
tool_call_msg = llm.respond(tools=[calculator_tool])
67-
tool_calls = tool_call_msg.tool_calls
68-
assertions.assert_true(
69-
bool(tool_calls), "LLM was expected to call a tool, but it did not."
70-
)
71-
72-
tool_call = tool_calls[0]
73-
function_args = json.loads(tool_call["function"]["arguments"])
74-
# Removes 'signature' parameter in thinking mode.
75-
function_args.pop("signature", None)
76-
tool_result = ""
77-
try:
78-
tool_result = run_simple_calculator(**function_args)
79-
except Exception as e:
80-
tool_result = f"Error executing tool: {type(e).__name__} - {e}"
81-
82-
tool.send(
83-
messages.Message(
84-
sender=tool,
85-
content=str(tool_result),
86-
_meta={"tool_call_id": tool_call["id"]},
87-
)
39+
def use_calculator(llm, problem: str, expected_answer: float) -> None:
40+
final_answer = llm.prompt(problem, tools=[run_simple_calculator])
41+
assertions.assert_tool_was_invoked(
42+
run_simple_calculator, "LLM was expected to call a tool, but it did not."
8843
)
8944

90-
final_answer_msg = llm.respond()
91-
final_answer = final_answer_msg.content
92-
9345
assertions.assert_true(
9446
str(expected_answer) in final_answer,
9547
f"Expected '{expected_answer}' to be in the final answer, but got '{final_answer}'.",
9648
)
9749

9850

51+
# %%
52+
9953
problem = "What is 485 multiplied by 12?"
10054
expected = 485 * 12
10155

102-
# %%
103-
use_calculator.run(llm, problem=problem, expected_answer=expected, stream_mode=True)
104-
105-
# %%
106-
use_calculator.run(llm, problem=problem, expected_answer=expected, stream_mode=False)
56+
use_calculator.run(llm, problem=problem, expected_answer=expected)
10757

10858
# %%

src/kaggle_benchmarks/actors/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from kaggle_benchmarks.actors.base import Actor, assertion, system, user
15+
from kaggle_benchmarks.actors.base import Actor, Tool, assertion, system, user
16+
from kaggle_benchmarks.actors.genai import GoogleGenAI, StreamingGoogleGenAI
1617
from kaggle_benchmarks.actors.llms import LLMChat
18+
from kaggle_benchmarks.actors.openai import (
19+
ModelProxyOpenAI,
20+
OpenAIResponsesAPI,
21+
StreamingOpenAIResponsesAPI,
22+
)

src/kaggle_benchmarks/actors/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def __str__(self) -> str:
7070
return f"{self.avatar} {self.name}"
7171

7272

73+
class Tool(Actor):
    """An actor whose role is fixed to ``"tool"``."""

    def __init__(self, name: str = "tool") -> None:
        """Create the actor with the given display name.

        Args:
            name: Name passed through to ``Actor``; defaults to ``"tool"``.
        """
        # Only the name is configurable; the role is always "tool".
        super().__init__(name=name, role="tool")
76+
77+
7378
system = Actor(name="System", role="system", avatar="⚙️")
7479
assertion = Actor(name="Assertion", role="system", avatar="🚨️")
7580
user = Actor(name="User", role="user", avatar="👤")

0 commit comments

Comments
 (0)