Skip to content

Commit 0aff197

Browse files
authored
Merge pull request #13 from agntcy/feat/iomapper-llamaindex
Feat/iomapper llamaindex
2 parents c9fd94e + 783653f commit 0aff197

29 files changed

+2060
-242
lines changed

Diff for: Makefile

+6-1
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,13 @@ setup: set_python_env
4343

4444
# ============================
4545

46-
setup_test: set_python_env
46+
setup_test:
4747
@poetry install --with=test --all-extras
4848

4949
test: setup_test
5050
poetry run pytest -vvrx
51+
test_manifest: setup_test
52+
poetry run pytest tests/test_agent_iomapper_from_manifest.py
53+
54+
test_langgraph_agent: setup_test
55+
poetry run pytest tests/test_langgraph_agent_iomapper.py

Diff for: agntcy_iomapper/base/__init__.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,27 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
22
# SPDX-License-Identifier: Apache-2.0
3-
from .agent_iomapper import *
4-
from .base import *
3+
from .agent_iomapper import (
4+
AgentIOMapper,
5+
AgentIOMapperConfig,
6+
AgentIOMapperInput,
7+
AgentIOMapperOutput,
8+
)
9+
from .base import (
10+
ArgumentsDescription,
11+
BaseIOMapper,
12+
BaseIOMapperConfig,
13+
BaseIOMapperInput,
14+
BaseIOMapperOutput,
15+
)
16+
17+
__all__ = [
18+
"AgentIOMapperConfig",
19+
"ArgumentsDescription",
20+
"AgentIOMapperInput",
21+
"AgentIOMapperOutput",
22+
"AgentIOMapper",
23+
"BaseIOMapperInput",
24+
"BaseIOMapperOutput",
25+
"BaseIOMapperConfig",
26+
"BaseIOMapper",
27+
]

Diff for: agntcy_iomapper/base/agent_iomapper.py

+2
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def _get_output(
124124
f'{{"data": {json.dumps(outputs)} }}'
125125
)
126126

127+
logger.debug(f"{outputs}")
128+
127129
# Check if data is returned in JSON markdown text
128130
matches = self._json_search_pattern.findall(outputs)
129131
if matches:

Diff for: agntcy_iomapper/base/utils.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from typing import Any, Dict, List, Optional, Type
2+
3+
from openapi_pydantic import Schema
4+
5+
6+
def _create_type_from_schema(
    json_schema: Dict[str, Any], pick_fields: List[str]
) -> Optional[Schema]:
    """Build a reduced JSON schema containing only the selected fields.

    Args:
        json_schema: The JSON schema of the original object (may carry
            ``$defs`` definitions referenced via ``$ref``).
        pick_fields: Dotted field paths to keep, e.g. ``["user", "user.name"]``.
            Only the root segment of each path selects a property; pruning of
            nested fields is still a TODO (see below).

    Returns:
        An ``openapi_pydantic.Schema`` validated from the filtered properties.

    Raises:
        KeyError: If the root segment of a path is not present in the
            schema's ``properties``.
    """
    defs = json_schema.get("$defs", {})
    properties = json_schema.get("properties", {})
    filtered_properties: Dict[str, Any] = {}

    for path in pick_fields:
        root_item = path.split(".")[0]
        prop = properties[root_item]

        if "anyOf" in prop:
            # Union property: flatten every alternative, resolving $refs.
            final_schema = []
            _extract_schema(prop, defs, final_schema)
            filtered_properties[root_item] = {"anyOf": final_schema}

        elif "items" in prop:
            # Array property: keep the array shape, resolve the item schemas.
            final_schema = []
            _extract_schema(prop, defs, final_schema)
            filtered_properties[root_item] = {"type": "array", "items": final_schema}

        elif "$ref" in prop:
            # Direct reference: inline the referenced definition.
            filtered_properties[root_item] = resolve_ref(prop.get("$ref"), defs)

        else:
            final_schema = []
            _extract_schema(prop, defs, final_schema)
            filtered_properties[root_item] = final_schema
    # TODO - remove fields not selected from the output
    # filtered_properties = _refine_schema(filtered_properties, pick_fields)

    return Schema.model_validate(filtered_properties)
52+
53+
54+
def _extract_schema(json_schema, defs, schema):
55+
if "anyOf" in json_schema:
56+
for val in json_schema.get("anyOf"):
57+
_extract_schema(val, defs, schema)
58+
elif "items" in json_schema:
59+
item = json_schema.get("items")
60+
_extract_schema(item, defs, schema)
61+
elif "$ref" in json_schema:
62+
ref = json_schema.get("$ref")
63+
schema.append(resolve_ref(ref, defs))
64+
elif "type" in json_schema:
65+
schema.append(json_schema)
66+
else:
67+
return
68+
69+
70+
def _extract_nested_fields(data: Any, fields: List[str]) -> dict:
71+
"""Extracts specified fields from a potentially nested data structure
72+
Args:
73+
data: The input data (can be any type)
74+
fields: A list of fields path (e.g.. "fielda.fieldb")
75+
Returns:
76+
A dictionary containing the extracted fields and their values.
77+
Returns empty dictionary if there are errors
78+
"""
79+
if not fields:
80+
return {}
81+
82+
results = {}
83+
84+
for field_path in fields:
85+
try:
86+
value = _get_nested_value(data, field_path)
87+
results[field_path] = value
88+
except (KeyError, TypeError, AttributeError, ValueError) as e:
89+
print(f"Error extracting field {field_path}: {e}")
90+
return results
91+
92+
93+
def _get_nested_value(data: Any, field_path: str) -> Optional[Any]:
94+
"""
95+
Recursively retrieves a value from a nested data structure
96+
"""
97+
current = data
98+
parts = field_path.split(".")
99+
100+
for part in parts:
101+
if isinstance(current, dict):
102+
current = current[part]
103+
elif isinstance(current, list) and part.isdigit():
104+
current = current[int(part)]
105+
elif hasattr(current, part):
106+
current = getattr(current, part)
107+
else:
108+
current = None
109+
110+
return current
111+
112+
113+
def resolve_ref(ref, current_defs):
    """Follow a local JSON-schema pointer (e.g. ``#/$defs/Name``) into *current_defs*.

    The first two pointer segments (``#`` and the defs container name) are
    skipped because *current_defs* already is that container; each remaining
    segment is looked up with ``.get`` (so a dangling ref yields ``None``).
    """
    node = current_defs
    for segment in ref.split("/")[2:]:
        node = node.get(segment)
    return node
119+
120+
121+
def _refine_schema(schema, paths):
122+
filtered_schema = {}
123+
print(f"{paths}-{schema}")
124+
125+
for path in paths:
126+
path_parts = path.split(".")
127+
if path_parts[0] in schema:
128+
key = path_parts[0]
129+
root_schema = schema.get(key)
130+
properties = {}
131+
if "anyOf" in root_schema:
132+
filtered_schema[key] = {}
133+
sub_schemas = root_schema.get("anyOf")
134+
135+
for sub_schema in sub_schemas:
136+
if "properties" in sub_schema:
137+
curr_properties = sub_schema.get("properties")
138+
properties = _filter_properties(
139+
sub_schema, curr_properties, path_parts[1:]
140+
)
141+
if key in filtered_schema:
142+
if "anyOf" not in filtered_schema:
143+
filtered_schema[key]["anyOf"] = [
144+
{"properties": properties}
145+
]
146+
else:
147+
filtered_schema[key]["anyOf"]
148+
149+
elif "items" in root_schema:
150+
sub_schemas = root_schema.get("items")
151+
elif "properties" in root_schema:
152+
curr_properties = root_schema.get("properties")
153+
properties = _filter_properties(
154+
root_schema, curr_properties, path_parts[1:]
155+
)
156+
print(f" after {paths}-{filtered_schema}")
157+
return filtered_schema
158+
159+
160+
def _filter_properties(schema, properties, paths):
161+
filtered_schema = {}
162+
163+
if len(paths) == 0:
164+
return schema
165+
166+
for path in paths:
167+
if path in properties:
168+
filtered_schema[path] = properties.get(path)
169+
if "properties" in filtered_schema[path]:
170+
return _filter_properties(
171+
schema, filtered_schema[path].get("properties"), paths[1:]
172+
)
173+
elif path in schema:
174+
filtered_schema[path] = schema.get(path)
175+
else:
176+
continue
177+
178+
return filtered_schema

Diff for: agntcy_iomapper/langgraph/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
22
# SPDX-License-Identifier: Apache-2.0"
33

4+
from .create_langraph_iomapper import create_langraph_iomapper, io_mapper_node
45
from .langgraph import (
56
LangGraphIOMapper,
67
LangGraphIOMapperConfig,
78
LangGraphIOMapperInput,
89
LangGraphIOMapperOutput,
910
)
1011

11-
from .create_langraph_iomapper import create_langraph_iomapper
12+
__all__ = [
13+
"create_langraph_iomapper",
14+
"io_mapper_node",
15+
"LangGraphIOMapper",
16+
"LangGraphIOMapperConfig",
17+
"LangGraphIOMapperInput",
18+
"LangGraphIOMapperOutput",
19+
]

Diff for: agntcy_iomapper/langgraph/create_langraph_iomapper.py

+80
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
22
# SPDX-License-Identifier: Apache-2.0"
3+
from typing import Any
4+
35
from langchain_core.runnables import Runnable
6+
from pydantic import BaseModel
7+
8+
from agntcy_iomapper.base import AgentIOMapperInput, ArgumentsDescription
9+
from agntcy_iomapper.base.utils import _create_type_from_schema, _extract_nested_fields
410

511
from .langgraph import (
612
LangGraphIOMapper,
@@ -20,3 +26,77 @@ def create_langraph_iomapper(
2026
A runnable representing an agent. It returns as output the mapping result
2127
"""
2228
return LangGraphIOMapper(config).as_runnable()
29+
30+
31+
def io_mapper_node(data: Any, config: dict) -> "Runnable":
    """Creates a langgraph node.

    Args:
        data: represents the state of the graph
        config: is the runnable config injected by the langgraph framework;
            its metadata has the following structure
            - input_fields: Required, an array of fields to be used in the mapping; these fields must be in the state eg: ["name", "address.street"]
            - output_fields: Required, an array of fields to include in the mapping result, eg: ["full_name", "full_address"]
            - input_schema: Optional, defines the schema of the input_data; useful if your state is not a pydantic model, not required if your state is a pydantic model.
            - output_schema: Optional, defines the schema of the output_data; useful if your output is not a pydantic model, not required if your output model is a pydantic model.
            To understand better how and when to use any of these options check the examples folder
    Returns:
        A runnable, that can be included in the langgraph node
    Raises:
        ValueError: if metadata/configurable entries are missing, data is
            empty, or (for non-pydantic state) the schemas are not provided.
    """
    metadata = config.get("metadata", None)
    if not metadata:
        # BUG FIX: these errors were previously *returned* instead of raised,
        # so callers silently received a ValueError instance as the "node".
        raise ValueError(
            "A metadata must be present with at least the configuration for input_fields and output_fields"
        )
    if not data:
        raise ValueError("data is required. Invalid or no data was passed")

    input_fields = metadata.get("input_fields", None)
    if not input_fields:
        raise ValueError("input_fields not found in the metadata")

    output_fields = metadata.get("output_fields", None)
    if not output_fields:
        raise ValueError("output_fields not found in the metadata")

    configurable = config.get("configurable", None)
    if not configurable:
        raise ValueError(
            "to use io_mapper_node an llm config must be passed via langgraph runnable config"
        )

    llm = configurable.get("llm", None)
    if not llm:
        raise ValueError(
            "to use io_mapper_node an llm config must be passed via langgraph runnable config"
        )

    if isinstance(data, BaseModel):
        # A pydantic state schema describes both the mapped inputs and
        # the output fields, which also live in the state.
        input_schema = data.model_json_schema()
        output_schema = input_schema
    else:
        # BUG FIX: use .get so a missing key reaches the explicit error
        # below instead of raising a bare KeyError.
        input_schema = metadata.get("input_schema")
        output_schema = metadata.get("output_schema")
        if not input_schema or not output_schema:
            raise ValueError(
                "input_schema, and or output_schema are missing from the metadata, for a better accuracy you are required to provide them in this scenario"
            )

    # BUG FIX: the output type is now derived from output_schema; the
    # original always used input_schema, ignoring the validated output_schema.
    output_type = _create_type_from_schema(output_schema, output_fields)
    input_type = _create_type_from_schema(input_schema, input_fields)

    data_to_be_mapped = _extract_nested_fields(data, fields=input_fields)

    input = AgentIOMapperInput(
        input=ArgumentsDescription(
            json_schema=input_type,
        ),
        output=ArgumentsDescription(
            json_schema=output_type,
        ),
        data=data_to_be_mapped,
    )

    iomapper_config = LangGraphIOMapperConfig(llm=llm)
    return LangGraphIOMapper(iomapper_config, input).as_runnable()

Diff for: agntcy_iomapper/langgraph/langgraph.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424

2525
class LangGraphIOMapperConfig(AgentIOMapperConfig):
    """Configuration for the langgraph IO mapper."""

    # BUG FIX: the field declaration was wrapped in a trailing-comma tuple,
    # i.e. ``llm: ... = (Field(...),)``, which makes the field *default* a
    # 1-tuple containing a FieldInfo instead of declaring a required field.
    llm: BaseChatModel | str = Field(
        ...,
        description="Model to use for translation as LangChain description or model class.",
    )
3032

3133

@@ -67,20 +69,28 @@ async def _ainvoke(
6769

6870

6971
class LangGraphIOMapper:
70-
    def __init__(
        self,
        config: LangGraphIOMapperConfig,
        input: LangGraphIOMapperInput | None = None,
    ):
        """Wrap an agent IO mapper for use inside a langgraph graph.

        Args:
            config: Mapper configuration (includes the LLM to use).
            input: Optional pre-built mapper input; when provided it takes
                precedence over ``state["input"]`` in invoke/ainvoke.
        """
        self._iomapper = _LangGraphAgentIOMapper(config)
        self._input = input
7279

7380
    async def ainvoke(self, state: dict[str, Any], config: RunnableConfig) -> dict:
        """Asynchronously run the mapping and return the mapped data.

        Returns ``response.data`` when the mapper produced a response,
        otherwise an empty dict.
        """
        # NOTE(review): truthiness check — a falsy (non-None) bound input
        # would fall back to state["input"]; `is not None` may be intended.
        input = self._input if self._input else state["input"]
        response = await self._iomapper.ainvoke(input=input, config=config)
        if response is not None:
            return response.data
        else:
            return {}
7987

8088
    def invoke(self, state: dict[str, Any], config: RunnableConfig) -> dict:
        """Synchronously run the mapping and return the mapped data.

        Returns ``response.data`` when the mapper produced a response,
        otherwise an empty dict.
        """
        # NOTE(review): truthiness check — a falsy (non-None) bound input
        # would fall back to state["input"]; `is not None` may be intended.
        input = self._input if self._input else state["input"]
        response = self._iomapper.invoke(input=input, config=config)

        if response is not None:
            return response.data
        else:
            return {}

0 commit comments

Comments
 (0)