Skip to content

Commit 20f9320

Browse files
authored
Merge pull request #11 from agntcy/feat/langgraph-support
Feat/langgraph support
2 parents 6f6a814 + bbaeae5 commit 20f9320

20 files changed

+2506
-441
lines changed

Diff for: Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ setup: set_python_env
4444
# ============================
4545

4646
setup_test: set_python_env
47-
@poetry install --with=test
47+
@poetry install --with=test --all-extras
4848

4949
test: setup_test
5050
poetry run pytest -vvrx

Diff for: agntcy_iomapper/__init__.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
22
# SPDX-License-Identifier: Apache-2.0
33
# ruff: noqa: F401
4+
from .agent_iomapper import (
5+
AgentIOMapper,
6+
AgentIOMapperConfig,
7+
AgentIOMapperInput,
8+
AgentIOMapperOutput,
9+
)
410
from .base import (
511
BaseIOMapper,
6-
IOMapperInput,
7-
IOMapperOutput,
8-
IOModelSettings,
12+
BaseIOMapperConfig,
13+
BaseIOMapperInput,
14+
BaseIOMapperOutput,
915
)
10-
from .imperative import ImperativeIOMapper
11-
from .iomapper import (
12-
AgentIOMapper,
13-
IOMapperConfig,
14-
IOModelArgs,
16+
from .imperative import (
17+
ImperativeIOMapper,
18+
ImperativeIOMapperInput,
19+
ImperativeIOMapperOutput,
1520
)
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,38 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
22
# SPDX-License-Identifier: Apache-2.0
3-
import argparse
4-
import asyncio
53
import json
64
import logging
75
import re
8-
from typing import ClassVar, TypedDict
6+
from abc import abstractmethod
7+
from typing import ClassVar
98

10-
import aiofiles
119
import jsonschema
12-
from dotenv import find_dotenv, load_dotenv
1310
from jinja2 import Environment
1411
from jinja2.sandbox import SandboxedEnvironment
15-
from pydantic import BaseModel, Field, model_validator
16-
from pydantic_ai import Agent
17-
from typing_extensions import Self
12+
from pydantic import Field
1813

19-
from .base import BaseIOMapper, IOMapperInput, IOMapperOutput, IOModelSettings
20-
from .supported_agents import get_supported_agent
14+
from .base import (
15+
BaseIOMapper,
16+
BaseIOMapperConfig,
17+
BaseIOMapperInput,
18+
BaseIOMapperOutput,
19+
)
2120

2221
logger = logging.getLogger(__name__)
2322

2423

25-
class IOModelArgs(TypedDict, total=False):
26-
base_url: str
27-
api_version: str
28-
azure_endpoint: str
29-
azure_ad_token: str
30-
project: str
31-
organization: str
24+
class AgentIOMapperInput(BaseIOMapperInput):
25+
message_template: str | None = Field(
26+
max_length=4096,
27+
default=None,
28+
description="Message (user) to send to LLM to effect translation.",
29+
)
3230

3331

34-
class IOMapperConfig(BaseModel):
35-
models: dict[str, IOModelArgs] = Field(
36-
default={"azure:gpt-4o-mini": IOModelArgs()},
37-
description="LLM configuration to use for translation",
38-
)
39-
default_model: str | None = Field(
40-
default="azure:gpt-4o-mini",
41-
description="Default arguments to LLM completion function by configured model.",
42-
)
43-
default_model_settings: dict[str, IOModelSettings] = Field(
44-
default={"azure:gpt-4o-mini": IOModelSettings(seed=42, temperature=0.8)},
45-
description="LLM configuration to use for translation",
46-
)
47-
validate_json_input: bool = Field(
48-
default=False, description="Validate input against JSON schema."
49-
)
50-
validate_json_output: bool = Field(
51-
default=False, description="Validate output against JSON schema."
52-
)
32+
AgentIOMapperOutput = BaseIOMapperOutput
33+
34+
35+
class AgentIOMapperConfig(BaseIOMapperConfig):
5336
system_prompt_template: str = Field(
5437
max_length=4096,
5538
default="You are a translation machine. You translate both natural language and object formats for computers.",
@@ -61,19 +44,6 @@ class IOMapperConfig(BaseModel):
6144
description="Default user message template. This can be overridden by the message request.",
6245
)
6346

64-
@model_validator(mode="after")
65-
def _validate_obj(self) -> Self:
66-
if self.models and self.default_model not in self.models:
67-
raise ValueError(
68-
f"default model {self.default_model} not present in configured models"
69-
)
70-
# Fill out defaults to eliminate need for checking.
71-
for model_name in self.models.keys():
72-
if model_name not in self.default_model_settings:
73-
self.default_model_settings[model_name] = IOModelSettings()
74-
75-
return self
76-
7747

7848
class AgentIOMapper(BaseIOMapper):
7949
_json_search_pattern: ClassVar[re.Pattern] = re.compile(
@@ -82,11 +52,13 @@ class AgentIOMapper(BaseIOMapper):
8252

8353
def __init__(
8454
self,
85-
config: IOMapperConfig,
55+
config: AgentIOMapperConfig | None = None,
8656
jinja_env: Environment | None = None,
8757
jinja_env_async: Environment | None = None,
8858
):
89-
self.config = config
59+
if config is None:
60+
config = AgentIOMapperConfig()
61+
super().__init__(config)
9062

9163
if jinja_env is not None and jinja_env.is_async:
9264
raise ValueError("Async Jinja env passed to jinja_env argument")
@@ -136,41 +108,19 @@ def _check_jinja_env(self, enable_async: bool):
136108
self.config.message_template
137109
)
138110

139-
def _get_render_env(self, input: IOMapperInput) -> dict[str, str]:
111+
def _get_render_env(self, input: AgentIOMapperInput) -> dict[str, str]:
140112
return {
141113
"input": input.input,
142114
"output": input.output,
143115
"data": input.data,
144116
}
145117

146-
def _get_model_settings(self, input: IOMapperInput):
147-
model_name = input.model or self.config.default_model
148-
if model_name not in self.config.models:
149-
raise ValueError(f"requested model {model_name} not found")
150-
elif input.model_settings is None:
151-
return self.config.default_model_settings[model_name]
152-
else:
153-
model_settings = self.config.default_model_settings[model_name].copy()
154-
model_settings.update(input.model_settings)
155-
return model_settings
156-
157-
def _get_agent(
158-
self, is_async: bool, input: IOMapperInput, system_prompt: str
159-
) -> Agent:
160-
model_name = input.model or self.config.default_model
161-
if model_name not in self.config.models:
162-
raise ValueError(f"requested model {model_name} not found")
163-
164-
return get_supported_agent(
165-
model_name,
166-
model_args=self.config.models[model_name],
167-
system_prompt=system_prompt,
168-
)
169-
170-
def _get_output(self, input: IOMapperInput, outputs: str) -> IOMapperOutput:
118+
def _get_output(
119+
self, input: AgentIOMapperInput, outputs: str
120+
) -> AgentIOMapperOutput:
171121
if input.output.json_schema is None:
172122
# If there is no schema, quote the chars for JSON.
173-
return IOMapperOutput.model_validate_json(
123+
return AgentIOMapperOutput.model_validate_json(
174124
f'{{"data": {json.dumps(outputs)} }}'
175125
)
176126

@@ -179,9 +129,9 @@ def _get_output(self, input: IOMapperInput, outputs: str) -> IOMapperOutput:
179129
if matches:
180130
outputs = matches[-1]
181131

182-
return IOMapperOutput.model_validate_json(f'{{"data": {outputs} }}')
132+
return AgentIOMapperOutput.model_validate_json(f'{{"data": {outputs} }}')
183133

184-
def _validate_input(self, input: IOMapperInput) -> None:
134+
def _validate_input(self, input: AgentIOMapperInput) -> None:
185135
if self.config.validate_json_input and input.input.json_schema is not None:
186136
jsonschema.validate(
187137
instance=input.data,
@@ -190,7 +140,9 @@ def _validate_input(self, input: IOMapperInput) -> None:
190140
),
191141
)
192142

193-
def _validate_output(self, input: IOMapperInput, output: IOMapperOutput) -> None:
143+
def _validate_output(
144+
self, input: AgentIOMapperInput, output: AgentIOMapperOutput
145+
) -> None:
194146
if self.config.validate_json_output and input.output.json_schema is not None:
195147
output_schema = input.output.json_schema.model_dump(
196148
exclude_none=True, mode="json"
@@ -201,34 +153,46 @@ def _validate_output(self, input: IOMapperInput, output: IOMapperOutput) -> None
201153
schema=output_schema,
202154
)
203155

204-
def invoke(self, input: IOMapperInput) -> IOMapperOutput:
156+
def invoke(self, input: AgentIOMapperInput, **kwargs) -> AgentIOMapperOutput:
205157
self._validate_input(input)
206158
self._check_jinja_env(False)
207159
render_env = self._get_render_env(input)
208160
system_prompt = self.prompt_template.render(render_env)
209-
agent = self._get_agent(False, input, system_prompt)
210161

211162
if input.message_template is not None:
212163
logging.info(f"User template supplied on input: {input.message_template}")
213164
user_template = self.jinja_env.from_string(input.message_template)
214165
else:
215166
user_template = self.user_template
216-
response = agent.run_sync(
217-
user_prompt=user_template.render(render_env),
218-
model_settings=self._get_model_settings(input),
167+
user_prompt = user_template.render(render_env)
168+
169+
outputs = self._invoke(
170+
input,
171+
messages=[
172+
{"role": "system", "content": system_prompt},
173+
{"role": "user", "content": user_prompt},
174+
],
175+
**kwargs,
219176
)
220-
outputs = response.data
221177
logging.debug(f"The LLM returned: {outputs}")
222178
output = self._get_output(input, outputs)
223179
self._validate_output(input, output)
224180
return output
225181

226-
async def ainvoke(self, input: IOMapperInput) -> IOMapperOutput:
182+
@abstractmethod
183+
def _invoke(
184+
self, input: AgentIOMapperInput, messages: list[dict[str, str]], **kwargs
185+
) -> str:
186+
"""Invoke internal model to process messages.
187+
Args:
188+
messages: the messages to send to the LLM
189+
"""
190+
191+
async def ainvoke(self, input: AgentIOMapperInput, **kwargs) -> AgentIOMapperOutput:
227192
self._validate_input(input)
228193
self._check_jinja_env(True)
229194
render_env = self._get_render_env(input)
230195
system_prompt = await self.prompt_template_async.render_async(render_env)
231-
agent = self._get_agent(True, input, system_prompt)
232196

233197
if input.message_template is not None:
234198
logging.info(f"User template supplied on input: {input.message_template}")
@@ -237,53 +201,26 @@ async def ainvoke(self, input: IOMapperInput) -> IOMapperOutput:
237201
)
238202
else:
239203
user_template_async = self.user_template_async
240-
response = await agent.run(
241-
user_prompt=await user_template_async.render_async(render_env),
242-
model_settings=self._get_model_settings(input),
204+
user_prompt = await user_template_async.render_async(render_env)
205+
206+
outputs = await self._ainvoke(
207+
input,
208+
messages=[
209+
{"role": "system", "content": system_prompt},
210+
{"role": "user", "content": user_prompt},
211+
],
212+
**kwargs,
243213
)
244-
outputs = response.data
245214
logging.debug(f"The LLM returned: {outputs}")
246215
output = self._get_output(input, outputs)
247216
self._validate_output(input, output)
248217
return output
249218

250-
251-
async def main():
252-
parser = argparse.ArgumentParser()
253-
parser.add_argument("--inputfile", help="Inputfile", required=True)
254-
parser.add_argument("--configfile", help="Configuration file", required=True)
255-
parser.add_argument("--outputfile", help="Output file", required=True)
256-
args = parser.parse_args()
257-
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
258-
259-
jinja_env = SandboxedEnvironment(
260-
loader=None,
261-
enable_async=True,
262-
autoescape=False,
263-
)
264-
265-
async with aiofiles.open(args.configfile, "r") as fp:
266-
configs = await fp.read()
267-
268-
config = IOMapperConfig.model_validate_json(configs)
269-
logging.info(f"Loaded config from {args.configfile}: {config.model_dump_json()}")
270-
271-
async with aiofiles.open(args.inputfile, "r") as fp:
272-
inputs = await fp.read()
273-
274-
input = IOMapperInput.model_validate_json(inputs)
275-
logging.info(f"Loaded input from {args.inputfile}: {input.model_dump_json()}")
276-
277-
p = AgentIOMapper(config, jinja_env)
278-
output = await p.ainvoke(input)
279-
outputs = output.model_dump_json()
280-
281-
async with aiofiles.open(args.outputfile, "w") as fp:
282-
await fp.write(outputs)
283-
284-
logging.info(f"Dumped output to {args.outputfile}: {outputs}")
285-
286-
287-
if __name__ == "__main__":
288-
load_dotenv(dotenv_path=find_dotenv(usecwd=True))
289-
asyncio.run(main())
219+
@abstractmethod
220+
async def _ainvoke(
221+
self, input: AgentIOMapperInput, messages: list[dict[str, str]], **kwargs
222+
) -> str:
223+
"""Async invoke internal model to process messages.
224+
Args:
225+
messages: the messages to send to the LLM
226+
"""

0 commit comments

Comments
 (0)