Skip to content

Commit fa3ddef

Browse files
authored
feat: integrate bedrock converse with tool call block (#20099)
1 parent 60115f3 commit fa3ddef

File tree

5 files changed

+3153
-1966
lines changed

5 files changed

+3153
-1966
lines changed

llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MessageRole,
2525
TextBlock,
2626
ThinkingBlock,
27+
ToolCallBlock,
2728
)
2829
from llama_index.core.bridge.pydantic import Field, PrivateAttr
2930
from llama_index.core.callbacks import CallbackManager
@@ -365,18 +366,17 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
365366
def _get_content_and_tool_calls(
366367
self, response: Optional[Dict[str, Any]] = None, content: Dict[str, Any] = None
367368
) -> Tuple[
368-
List[Union[TextBlock, ThinkingBlock]], Dict[str, Any], List[str], List[str]
369+
List[Union[TextBlock, ThinkingBlock, ToolCallBlock]], List[str], List[str]
369370
]:
370371
assert response is not None or content is not None, (
371372
f"Either response or content must be provided. Got response: {response}, content: {content}"
372373
)
373374
assert response is None or content is None, (
374375
f"Only one of response or content should be provided. Got response: {response}, content: {content}"
375376
)
376-
tool_calls = []
377377
tool_call_ids = []
378378
status = []
379-
blocks = []
379+
blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
380380
if content is not None:
381381
content_list = [content]
382382
else:
@@ -401,15 +401,21 @@ def _get_content_and_tool_calls(
401401
tool_usage["toolUseId"] = content_block["toolUseId"]
402402
if "name" not in tool_usage:
403403
tool_usage["name"] = content_block["name"]
404-
tool_calls.append(tool_usage)
404+
blocks.append(
405+
ToolCallBlock(
406+
tool_name=tool_usage.get("name", ""),
407+
tool_call_id=tool_usage.get("toolUseId"),
408+
tool_kwargs=tool_usage.get("input", {}),
409+
)
410+
)
405411
if tool_result := content_block.get("toolResult", None):
406412
for tool_result_content in tool_result["content"]:
407413
if text := tool_result_content.get("text", None):
408414
text_content += text
409415
tool_call_ids.append(tool_result_content.get("toolUseId", ""))
410416
status.append(tool_result.get("status", ""))
411417

412-
return blocks, tool_calls, tool_call_ids, status
418+
return blocks, tool_call_ids, status
413419

414420
@llm_chat_callback()
415421
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
@@ -436,16 +442,13 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
436442
**all_kwargs,
437443
)
438444

439-
blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
440-
response
441-
)
445+
blocks, tool_call_ids, status = self._get_content_and_tool_calls(response)
442446

443447
return ChatResponse(
444448
message=ChatMessage(
445449
role=MessageRole.ASSISTANT,
446450
blocks=blocks,
447451
additional_kwargs={
448-
"tool_calls": tool_calls,
449452
"tool_call_id": tool_call_ids,
450453
"status": status,
451454
},
@@ -540,7 +543,7 @@ def gen() -> ChatResponseGen:
540543
current_tool_call, tool_use_delta
541544
)
542545

543-
blocks: List[Union[TextBlock, ThinkingBlock]] = [
546+
blocks: List[Union[TextBlock, ThinkingBlock, ToolCallBlock]] = [
544547
TextBlock(text=content.get("text", ""))
545548
]
546549
if thinking != "":
@@ -553,13 +556,21 @@ def gen() -> ChatResponseGen:
553556
},
554557
),
555558
)
559+
if tool_calls:
560+
for tool_call in tool_calls:
561+
blocks.append(
562+
ToolCallBlock(
563+
tool_kwargs=tool_call.get("input", {}),
564+
tool_name=tool_call.get("name", ""),
565+
tool_call_id=tool_call.get("toolUseId"),
566+
)
567+
)
556568

557569
yield ChatResponse(
558570
message=ChatMessage(
559571
role=role,
560572
blocks=blocks,
561573
additional_kwargs={
562-
"tool_calls": tool_calls,
563574
"tool_call_id": [
564575
tc.get("toolUseId", "") for tc in tool_calls
565576
],
@@ -579,7 +590,7 @@ def gen() -> ChatResponseGen:
579590
# Add to our list of tool calls
580591
tool_calls.append(current_tool_call)
581592

582-
blocks: List[Union[TextBlock, ThinkingBlock]] = [
593+
blocks: List[Union[TextBlock, ThinkingBlock, ToolCallBlock]] = [
583594
TextBlock(text=content.get("text", ""))
584595
]
585596
if thinking != "":
@@ -593,12 +604,21 @@ def gen() -> ChatResponseGen:
593604
),
594605
)
595606

607+
if tool_calls:
608+
for tool_call in tool_calls:
609+
blocks.append(
610+
ToolCallBlock(
611+
tool_kwargs=tool_call.get("input", {}),
612+
tool_name=tool_call.get("name", ""),
613+
tool_call_id=tool_call.get("toolUseId"),
614+
)
615+
)
616+
596617
yield ChatResponse(
597618
message=ChatMessage(
598619
role=role,
599620
blocks=blocks,
600621
additional_kwargs={
601-
"tool_calls": tool_calls,
602622
"tool_call_id": [
603623
tc.get("toolUseId", "") for tc in tool_calls
604624
],
@@ -615,7 +635,7 @@ def gen() -> ChatResponseGen:
615635
# Handle metadata event - this contains the final token usage
616636
if usage := metadata.get("usage"):
617637
# Yield a final response with correct token usage
618-
blocks: List[Union[TextBlock, ThinkingBlock]] = [
638+
blocks: List[Union[TextBlock, ThinkingBlock, ToolCallBlock]] = [
619639
TextBlock(text=content.get("text", ""))
620640
]
621641
if thinking != "":
@@ -628,13 +648,21 @@ def gen() -> ChatResponseGen:
628648
},
629649
),
630650
)
651+
if tool_calls:
652+
for tool_call in tool_calls:
653+
blocks.append(
654+
ToolCallBlock(
655+
tool_kwargs=tool_call.get("input", {}),
656+
tool_name=tool_call.get("name", ""),
657+
tool_call_id=tool_call.get("toolUseId"),
658+
)
659+
)
631660

632661
yield ChatResponse(
633662
message=ChatMessage(
634663
role=role,
635664
blocks=blocks,
636665
additional_kwargs={
637-
"tool_calls": tool_calls,
638666
"tool_call_id": [
639667
tc.get("toolUseId", "") for tc in tool_calls
640668
],
@@ -685,16 +713,13 @@ async def achat(
685713
**all_kwargs,
686714
)
687715

688-
blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
689-
response
690-
)
716+
blocks, tool_call_ids, status = self._get_content_and_tool_calls(response)
691717

692718
return ChatResponse(
693719
message=ChatMessage(
694720
role=MessageRole.ASSISTANT,
695721
blocks=blocks,
696722
additional_kwargs={
697-
"tool_calls": tool_calls,
698723
"tool_call_id": tool_call_ids,
699724
"status": status,
700725
},
@@ -789,7 +814,7 @@ async def gen() -> ChatResponseAsyncGen:
789814
current_tool_call = join_two_dicts(
790815
current_tool_call, tool_use_delta
791816
)
792-
blocks: List[Union[TextBlock, ThinkingBlock]] = [
817+
blocks: List[Union[TextBlock, ThinkingBlock, ToolCallBlock]] = [
793818
TextBlock(text=content.get("text", ""))
794819
]
795820
if thinking != "":
@@ -803,12 +828,21 @@ async def gen() -> ChatResponseAsyncGen:
803828
),
804829
)
805830

831+
if tool_calls:
832+
for tool_call in tool_calls:
833+
blocks.append(
834+
ToolCallBlock(
835+
tool_kwargs=tool_call.get("input", {}),
836+
tool_name=tool_call.get("name", ""),
837+
tool_call_id=tool_call.get("toolUseId"),
838+
)
839+
)
840+
806841
yield ChatResponse(
807842
message=ChatMessage(
808843
role=role,
809844
blocks=blocks,
810845
additional_kwargs={
811-
"tool_calls": tool_calls,
812846
"tool_call_id": [
813847
tc.get("toolUseId", "") for tc in tool_calls
814848
],
@@ -828,7 +862,7 @@ async def gen() -> ChatResponseAsyncGen:
828862
# Add to our list of tool calls
829863
tool_calls.append(current_tool_call)
830864

831-
blocks: List[Union[TextBlock, ThinkingBlock]] = [
865+
blocks: List[Union[TextBlock, ThinkingBlock, ToolCallBlock]] = [
832866
TextBlock(text=content.get("text", ""))
833867
]
834868
if thinking != "":
@@ -842,12 +876,21 @@ async def gen() -> ChatResponseAsyncGen:
842876
),
843877
)
844878

879+
if tool_calls:
880+
for tool_call in tool_calls:
881+
blocks.append(
882+
ToolCallBlock(
883+
tool_kwargs=tool_call.get("input", {}),
884+
tool_name=tool_call.get("name", ""),
885+
tool_call_id=tool_call.get("toolUseId"),
886+
)
887+
)
888+
845889
yield ChatResponse(
846890
message=ChatMessage(
847891
role=role,
848892
blocks=blocks,
849893
additional_kwargs={
850-
"tool_calls": tool_calls,
851894
"tool_call_id": [
852895
tc.get("toolUseId", "") for tc in tool_calls
853896
],
@@ -864,7 +907,7 @@ async def gen() -> ChatResponseAsyncGen:
864907
# Handle metadata event - this contains the final token usage
865908
if usage := metadata.get("usage"):
866909
# Yield a final response with correct token usage
867-
blocks: List[Union[TextBlock, ThinkingBlock]] = [
910+
blocks: List[Union[TextBlock, ThinkingBlock, ToolCallBlock]] = [
868911
TextBlock(text=content.get("text", ""))
869912
]
870913
if thinking != "":
@@ -878,12 +921,21 @@ async def gen() -> ChatResponseAsyncGen:
878921
),
879922
)
880923

924+
if tool_calls:
925+
for tool_call in tool_calls:
926+
blocks.append(
927+
ToolCallBlock(
928+
tool_kwargs=tool_call.get("input", {}),
929+
tool_name=tool_call.get("name", ""),
930+
tool_call_id=tool_call.get("toolUseId"),
931+
)
932+
)
933+
881934
yield ChatResponse(
882935
message=ChatMessage(
883936
role=role,
884937
blocks=blocks,
885938
additional_kwargs={
886-
"tool_calls": tool_calls,
887939
"tool_call_id": [
888940
tc.get("toolUseId", "") for tc in tool_calls
889941
],
@@ -960,7 +1012,11 @@ def get_tool_calls_from_response(
9601012
**kwargs: Any,
9611013
) -> List[ToolSelection]:
9621014
"""Predict and call the tool."""
963-
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
1015+
tool_calls = [
1016+
block
1017+
for block in response.message.blocks
1018+
if isinstance(block, ToolCallBlock)
1019+
]
9641020

9651021
if len(tool_calls) < 1:
9661022
if error_on_no_tool_call:
@@ -972,26 +1028,23 @@ def get_tool_calls_from_response(
9721028

9731029
tool_selections = []
9741030
for tool_call in tool_calls:
975-
if "toolUseId" not in tool_call or "name" not in tool_call:
976-
raise ValueError("Invalid tool call.")
977-
9781031
# handle empty inputs
9791032
argument_dict = {}
980-
if "input" in tool_call and isinstance(tool_call["input"], str):
1033+
if isinstance(tool_call.tool_kwargs, str):
9811034
# TODO parse_partial_json is not perfect
9821035
try:
983-
argument_dict = parse_partial_json(tool_call["input"])
1036+
argument_dict = parse_partial_json(tool_call.tool_kwargs)
9841037
except ValueError:
9851038
argument_dict = {}
986-
elif "input" in tool_call and isinstance(tool_call["input"], dict):
987-
argument_dict = tool_call["input"]
1039+
elif isinstance(tool_call.tool_kwargs, dict):
1040+
argument_dict = tool_call.tool_kwargs
9881041
else:
9891042
continue
9901043

9911044
tool_selections.append(
9921045
ToolSelection(
993-
tool_id=tool_call["toolUseId"],
994-
tool_name=tool_call["name"],
1046+
tool_id=tool_call.tool_call_id or "",
1047+
tool_name=tool_call.tool_name,
9951048
tool_kwargs=argument_dict,
9961049
)
9971050
)

0 commit comments

Comments
 (0)