13 changes: 13 additions & 0 deletions llama-index-core/llama_index/core/base/llms/types.py
@@ -443,6 +443,18 @@ class ThinkingBlock(BaseModel):
)


class ToolCallBlock(BaseModel):
    """A content block representing a tool call requested by the LLM."""

    block_type: Literal["tool_call"] = "tool_call"
tool_call_id: Optional[str] = Field(
default=None, description="ID of the tool call, if provided"
)
tool_name: str = Field(description="Name of the called tool")
tool_kwargs: dict[str, Any] | str = Field(
default_factory=dict, # type: ignore
description="Arguments provided to the tool, if available",
)


ContentBlock = Annotated[
Union[
TextBlock,
@@ -454,6 +466,7 @@ class ThinkingBlock(BaseModel):
CitableBlock,
CitationBlock,
ThinkingBlock,
ToolCallBlock,
],
Field(discriminator="block_type"),
]
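
A minimal usage sketch (not part of the diff): building a message that carries a tool call as a first-class content block and round-tripping it through serialization. The tool name, ID, and arguments are illustrative; the sketch assumes ChatMessage validates its blocks through the discriminated union above.

from llama_index.core.base.llms.types import (
    ChatMessage,
    MessageRole,
    TextBlock,
    ToolCallBlock,
)

msg = ChatMessage(
    role=MessageRole.ASSISTANT,
    blocks=[
        TextBlock(text="Looking that up."),
        ToolCallBlock(
            tool_call_id="call_1",  # illustrative ID
            tool_name="search",  # hypothetical tool
            tool_kwargs={"query": "llama index"},
        ),
    ],
)

# block_type is the discriminator, so the blocks rehydrate as the right classes.
restored = ChatMessage.model_validate(msg.model_dump())
assert isinstance(restored.blocks[1], ToolCallBlock)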
5 changes: 3 additions & 2 deletions llama-index-core/llama_index/core/memory/memory.py
@@ -29,6 +29,7 @@
CitableBlock,
CitationBlock,
ThinkingBlock,
ToolCallBlock,
)
from llama_index.core.bridge.pydantic import (
BaseModel,
@@ -349,7 +350,7 @@ def _estimate_token_count(
] = []

for block in message_or_blocks.blocks:
-            if not isinstance(block, CachePoint):
+            if not isinstance(block, (CachePoint, ToolCallBlock)):
blocks.append(block)

# Estimate the token count for the additional kwargs
@@ -367,7 +368,7 @@
blocks = []
for msg in messages:
for block in msg.blocks:
-            if not isinstance(block, CachePoint):
+            if not isinstance(block, (CachePoint, ToolCallBlock)):
blocks.append(block)

# Estimate the token count for the additional kwargs
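
Tool-call blocks are now skipped by memory's token estimator alongside CachePoint; since the OpenAI integration below still mirrors tool calls into additional_kwargs, which are counted separately, counting the blocks as well would double-count them. A minimal sketch of the filtering idea (block values are illustrative):

from llama_index.core.base.llms.types import CachePoint, TextBlock, ToolCallBlock

blocks = [TextBlock(text="hello"), ToolCallBlock(tool_name="search")]
# Only content-bearing blocks feed the tokenizer-based estimate.
countable = [b for b in blocks if not isinstance(b, (CachePoint, ToolCallBlock))]
assert len(countable) == 1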
17 changes: 17 additions & 0 deletions llama-index-core/tests/base/llms/test_types.py
@@ -18,6 +18,7 @@
CachePoint,
CacheControl,
ThinkingBlock,
ToolCallBlock,
)
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.bridge.pydantic import ValidationError
@@ -473,3 +474,19 @@ def test_thinking_block():
assert block.additional_information == {"total_thinking_tokens": 1000}
assert block.content == "hello world"
assert block.num_tokens == 100


def test_tool_call_block():
default_block = ToolCallBlock(tool_name="hello_world")
assert default_block.block_type == "tool_call"
assert default_block.tool_call_id is None
assert default_block.tool_name == "hello_world"
assert default_block.tool_kwargs == {}
custom_block = ToolCallBlock(
tool_name="hello_world",
tool_call_id="1",
tool_kwargs={"test": 1},
)
assert custom_block.tool_call_id == "1"
assert custom_block.tool_name == "hello_world"
assert custom_block.tool_kwargs == {"test": 1}
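
tool_kwargs is typed dict[str, Any] | str because providers stream tool arguments as (possibly partial) JSON strings; a string payload validates unparsed. A quick sketch with an illustrative partial payload:

from llama_index.core.base.llms.types import ToolCallBlock

block = ToolCallBlock(tool_name="search", tool_kwargs='{"query": "lla')
assert isinstance(block.tool_kwargs, str)  # stays a string until a consumer parses it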
@@ -43,6 +43,8 @@
CompletionResponseGen,
LLMMetadata,
MessageRole,
ToolCallBlock,
TextBlock,
)
from llama_index.core.bridge.pydantic import (
Field,
@@ -121,9 +123,15 @@ def encode(self, text: str) -> List[int]: # fmt: skip


def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]
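
A behavioral sketch of the rewritten helper: non-tool blocks are preserved and only the first tool call survives, re-appended at the end of the block list. It assumes a ChatResponse can be built from a bare message and that the helper is importable from this module.

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    MessageRole,
    ToolCallBlock,
)
from llama_index.llms.openai.base import force_single_tool_call

resp = ChatResponse(
    message=ChatMessage(
        role=MessageRole.ASSISTANT,
        blocks=[
            ToolCallBlock(tool_name="a", tool_call_id="1"),
            ToolCallBlock(tool_name="b", tool_call_id="2"),
        ],
    )
)
force_single_tool_call(resp)
assert [b.tool_name for b in resp.message.blocks] == ["a"]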


class OpenAI(FunctionCallingLLM):
@@ -528,6 +536,7 @@ def gen() -> ChatResponseGen:
messages=message_dicts,
**self._get_model_kwargs(stream=True, **kwargs),
):
blocks = []
response = cast(ChatCompletionChunk, response)
if len(response.choices) > 0:
delta = response.choices[0].delta
@@ -545,17 +554,27 @@ def gen() -> ChatResponseGen:
role = delta.role or MessageRole.ASSISTANT
content_delta = delta.content or ""
content += content_delta
blocks.append(TextBlock(text=content))

additional_kwargs = {}
if is_function:
tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
if tool_calls:
additional_kwargs["tool_calls"] = tool_calls
for tool_call in tool_calls:
if tool_call.function:
blocks.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_kwargs=tool_call.function.arguments or {},
tool_name=tool_call.function.name or "",
)
)

yield ChatResponse(
message=ChatMessage(
role=role,
content=content,
blocks=blocks,
additional_kwargs=additional_kwargs,
),
delta=content_delta,
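
With blocks now populated on every streamed chunk, a caller can watch tool calls accumulate without reaching into additional_kwargs. A hedged sketch (model name and prompt are illustrative, and tool_kwargs may still be a partial JSON string mid-stream):

from llama_index.core.base.llms.types import ChatMessage, ToolCallBlock
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4o-mini")  # any function-calling model
for chunk in llm.stream_chat([ChatMessage(role="user", content="What is 2 + 2?")]):
    tool_blocks = [b for b in chunk.message.blocks if isinstance(b, ToolCallBlock)]
    if tool_blocks:
        # arguments grow chunk by chunk; parse only once the stream ends
        print(tool_blocks[-1].tool_name, tool_blocks[-1].tool_kwargs)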
@@ -785,6 +804,7 @@ async def gen() -> ChatResponseAsyncGen:
messages=message_dicts,
**self._get_model_kwargs(stream=True, **kwargs),
):
blocks = []
response = cast(ChatCompletionChunk, response)
if len(response.choices) > 0:
# check if the first chunk has neither content nor tool_calls
@@ -812,17 +832,27 @@ async def gen() -> ChatResponseAsyncGen:
role = delta.role or MessageRole.ASSISTANT
content_delta = delta.content or ""
content += content_delta
blocks.append(TextBlock(text=content))

additional_kwargs = {}
if is_function:
tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
if tool_calls:
additional_kwargs["tool_calls"] = tool_calls
for tool_call in tool_calls:
if tool_call.function:
blocks.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_kwargs=tool_call.function.arguments or {},
tool_name=tool_call.function.name or "",
)
)

yield ChatResponse(
message=ChatMessage(
role=role,
content=content,
blocks=blocks,
additional_kwargs=additional_kwargs,
),
delta=content_delta,
@@ -960,36 +990,71 @@ def get_tool_calls_from_response(
**kwargs: Any,
) -> List[ToolSelection]:
"""Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
-
-        if len(tool_calls) < 1:
-            if error_on_no_tool_call:
-                raise ValueError(
-                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
-                )
-            else:
-                return []
-
-        tool_selections = []
-        for tool_call in tool_calls:
-            if tool_call.type != "function":
-                raise ValueError("Invalid tool type. Unsupported by OpenAI llm")
-
-            # this should handle both complete and partial jsons
-            try:
-                argument_dict = parse_partial_json(tool_call.function.arguments)
-            except (ValueError, TypeError, JSONDecodeError):
-                argument_dict = {}
-
-            tool_selections.append(
-                ToolSelection(
-                    tool_id=tool_call.id,
-                    tool_name=tool_call.function.name,
-                    tool_kwargs=argument_dict,
-                )
-            )
-
-        return tool_selections
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
+        if tool_calls:
+            if len(tool_calls) < 1:
+                if error_on_no_tool_call:
+                    raise ValueError(
+                        f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                    )
+                else:
+                    return []
+
+            tool_selections = []
+            for tool_call in tool_calls:
+                # this should handle both complete and partial jsons
+                try:
+                    if isinstance(tool_call.tool_kwargs, str):
+                        argument_dict = parse_partial_json(tool_call.tool_kwargs)
+                    else:
+                        argument_dict = tool_call.tool_kwargs
+                except (ValueError, TypeError, JSONDecodeError):
+                    argument_dict = {}
+
+                tool_selections.append(
+                    ToolSelection(
+                        tool_id=tool_call.tool_call_id or "",
+                        tool_name=tool_call.tool_name,
+                        tool_kwargs=argument_dict,
+                    )
+                )
+
+            return tool_selections
+        else:  # keep it backward-compatible
+            tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+
+            if len(tool_calls) < 1:
+                if error_on_no_tool_call:
+                    raise ValueError(
+                        f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                    )
+                else:
+                    return []
+
+            tool_selections = []
+            for tool_call in tool_calls:
+                if tool_call.type != "function":
+                    raise ValueError("Invalid tool type. Unsupported by OpenAI llm")
+
+                # this should handle both complete and partial jsons
+                try:
+                    argument_dict = parse_partial_json(tool_call.function.arguments)
+                except (ValueError, TypeError, JSONDecodeError):
+                    argument_dict = {}
+
+                tool_selections.append(
+                    ToolSelection(
+                        tool_id=tool_call.id,
+                        tool_name=tool_call.function.name,
+                        tool_kwargs=argument_dict,
+                    )
+                )
+
+            return tool_selections
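
Whichever branch runs, callers still receive plain ToolSelection objects, so downstream agent code is unchanged. A small sketch, assuming an OpenAI instance llm and an illustrative tool call with string arguments:

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    MessageRole,
    ToolCallBlock,
)

resp = ChatResponse(
    message=ChatMessage(
        role=MessageRole.ASSISTANT,
        blocks=[
            ToolCallBlock(
                tool_call_id="1",
                tool_name="add",
                tool_kwargs='{"a": 1, "b": 2}',  # parsed via parse_partial_json
            )
        ],
    )
)
selections = llm.get_tool_calls_from_response(resp)
assert selections[0].tool_kwargs == {"a": 1, "b": 2}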

def _prepare_schema(
self, llm_kwargs: Optional[Dict[str, Any]], output_cls: Type[Model]