Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f721763

Browse files
committedMar 25, 2025·
feat(llama_index): add support for llama-index workflow, refactor, improve examples
1 parent 349c5ba commit f721763

31 files changed

+1182
-677
lines changed
 

‎Makefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ setup: set_python_env
4343

4444
# ============================
4545

46-
setup_test:
46+
setup_test: setup
4747
@poetry install --with=test --all-extras
4848

4949
test: setup_test
@@ -57,5 +57,5 @@ test_langgraph_agent: setup_test
5757

5858
test_langgraph_software: setup_test
5959
poetry run pytest tests/test_langgraph_graph_with_io_mapper.py
60-
unittest:
60+
unittest: setup_test
6161
poetry run pytest tests/unittests -vvrx

‎agntcy_iomapper/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
22
# SPDX-License-Identifier: Apache-2.0
33
# ruff: noqa: F401
4-
from agntcy_iomapper.agent import FieldMetadata, IOMappingAgent, IOMappingAgentMetadata
4+
from agntcy_iomapper.agent import IOMappingAgent
5+
from agntcy_iomapper.base import FieldMetadata, IOMappingAgentMetadata
56
from agntcy_iomapper.imperative import (
67
ImperativeIOMapper,
78
ImperativeIOMapperInput,

‎agntcy_iomapper/agent/__init__.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,15 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from agntcy_iomapper.agent.agent_io_mapper import IOMappingAgent
5-
from agntcy_iomapper.agent.base import AgentIOMapper
6-
from agntcy_iomapper.agent.models import (
7-
AgentIOMapperConfig,
8-
AgentIOMapperInput,
9-
AgentIOMapperOutput,
5+
from agntcy_iomapper.base.models import (
6+
BaseIOMapperConfig,
107
FieldMetadata,
118
IOMappingAgentMetadata,
129
)
1310

1411
__all__ = [
15-
"AgentIOMapper",
16-
"AgentIOMapperOutput",
17-
"AgentIOMapperConfig",
12+
"BaseIOMapperConfig",
1813
"IOMappingAgent",
1914
"IOMappingAgentMetadata",
20-
"AgentIOMapperInput",
2115
"FieldMetadata",
2216
]

‎agntcy_iomapper/agent/agent_io_mapper.py

+74-40
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,51 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import logging
5-
from typing import Any, Optional, Tuple, Union
5+
from typing import Any, Callable, List, Optional, Union
66

77
from langchain_core.language_models import BaseChatModel
88
from langchain_core.runnables import Runnable
9-
from openapi_pydantic import Schema
9+
from llama_index.core.base.llms.base import BaseLLM
10+
from llama_index.core.tools import BaseTool
11+
from llama_index.core.workflow import (
12+
Workflow,
13+
)
1014
from pydantic import BaseModel, Field, model_validator
1115
from typing_extensions import Self
1216

13-
from agntcy_iomapper.agent.models import (
17+
from agntcy_iomapper.base.models import (
1418
AgentIOMapperInput,
19+
ArgumentsDescription,
1520
FieldMetadata,
1621
IOMappingAgentMetadata,
1722
)
18-
from agntcy_iomapper.base import ArgumentsDescription
19-
from agntcy_iomapper.base.utils import create_type_from_schema, extract_nested_fields
23+
from agntcy_iomapper.base.utils import extract_nested_fields, get_io_types
24+
from agntcy_iomapper.imperative import (
25+
ImperativeIOMapper,
26+
ImperativeIOMapperInput,
27+
)
2028
from agntcy_iomapper.langgraph import LangGraphIOMapper, LangGraphIOMapperConfig
29+
from agntcy_iomapper.llamaindex.llamaindex import (
30+
LLamaIndexIOMapper,
31+
)
2132

2233
logger = logging.getLogger(__name__)
2334

2435

2536
class IOMappingAgent(BaseModel):
26-
metadata: IOMappingAgentMetadata = Field(
37+
metadata: Optional[IOMappingAgentMetadata] = Field(
2738
...,
2839
description="Details about the fields to be used in the translation and about the output",
2940
)
30-
llm: Optional[Union[BaseChatModel, str]] = (
31-
Field(
32-
None,
33-
description="Model to use for translation as LangChain description or model class.",
34-
),
41+
llm: Optional[Union[BaseChatModel, str]] = Field(
42+
None,
43+
description="Model to use for translation as LangChain description or model class.",
3544
)
3645

3746
@model_validator(mode="after")
3847
def _validate_obj(self) -> Self:
48+
if not self.metadata:
49+
return self
3950

4051
if not self.metadata.input_fields or len(self.metadata.input_fields) == 0:
4152
raise ValueError("input_fields not found in the metadata")
@@ -70,38 +81,10 @@ def _validate_obj(self) -> Self:
7081

7182
return self
7283

73-
def _get_io_types(self, data: Any) -> Tuple[Schema, Schema]:
74-
data_schema = None
75-
if isinstance(data, BaseModel):
76-
data_schema = data.model_json_schema()
77-
# If input schema is provided it overwrites the data schema
78-
input_schema = (
79-
self.metadata.input_schema if self.metadata.input_schema else data_schema
80-
)
81-
# If output schema is provided it overwrites the data schema
82-
output_schema = (
83-
self.metadata.output_schema if self.metadata.output_schema else data_schema
84-
)
85-
86-
if not input_schema or not output_schema:
87-
raise ValueError(
88-
"input_schema, and or output_schema are missing from the metadata, for a better accuracy you are required to provide them in this scenario, or we could not infer the type from the state"
89-
)
90-
91-
input_type = Schema.model_validate(
92-
create_type_from_schema(input_schema, self.metadata.input_fields)
93-
)
94-
95-
output_type = Schema.model_validate(
96-
create_type_from_schema(output_schema, self.metadata.output_fields)
97-
)
98-
99-
return (input_type, output_type)
100-
10184
def langgraph_node(self, data: Any, config: Optional[dict] = None) -> Runnable:
10285

10386
# If there is a template for the output the output_schema is going to be ignored in the translation
104-
input_type, output_type = self._get_io_types(data)
87+
input_type, output_type = get_io_types(data, self.metadata)
10588

10689
data_to_be_mapped = extract_nested_fields(
10790
data, fields=self.metadata.input_fields
@@ -132,3 +115,54 @@ def langgraph_node(self, data: Any, config: Optional[dict] = None) -> Runnable:
132115

133116
iomapper_config = LangGraphIOMapperConfig(llm=self.llm)
134117
return LangGraphIOMapper(iomapper_config, input).as_runnable()
118+
119+
def langgraph_imperative(
120+
self, data: Any, config: Optional[dict] = None
121+
) -> Runnable:
122+
123+
input_type, output_type = self._get_io_types(data, self.metadata)
124+
125+
data_to_be_mapped = extract_nested_fields(
126+
data, fields=self.metadata.input_fields
127+
)
128+
129+
input = ImperativeIOMapperInput(
130+
input=ArgumentsDescription(
131+
json_schema=input_type,
132+
),
133+
output=ArgumentsDescription(json_schema=output_type),
134+
data=data_to_be_mapped,
135+
)
136+
137+
if not self.metadata.field_mapping:
138+
raise ValueError(
139+
"In order to use imperative mapping field_mapping must be provided in the metadata"
140+
)
141+
142+
imperative_io_mapper = ImperativeIOMapper(
143+
input=input, field_mapping=self.metadata.field_mapping
144+
)
145+
return imperative_io_mapper.as_runnable()
146+
147+
@staticmethod
148+
def as_worfklow_step(workflow: Workflow) -> Callable:
149+
io_mapper_step = LLamaIndexIOMapper.llamaindex_mapper(workflow)
150+
return io_mapper_step
151+
152+
@staticmethod
153+
def as_workflow_agent(
154+
mapping_metadata: IOMappingAgentMetadata,
155+
llm: BaseLLM,
156+
name: str,
157+
description: str,
158+
can_handoff_to: Optional[List[str]] = None,
159+
tools: Optional[List[Union[BaseTool, Callable]]] = [],
160+
):
161+
return LLamaIndexIOMapper(
162+
mapping_metadata=mapping_metadata,
163+
llm=llm,
164+
tools=tools,
165+
name=name,
166+
description=description,
167+
can_handoff_to=can_handoff_to,
168+
)

0 commit comments

Comments
 (0)
Please sign in to comment.