Skip to content

Feat/langgraph support #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ setup: set_python_env
# ============================

setup_test: set_python_env
@poetry install --with=test
@poetry install --with=test --all-extras

test: setup_test
poetry run pytest -vvrx
21 changes: 13 additions & 8 deletions agntcy_iomapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: F401
from .agent_iomapper import (
AgentIOMapper,
AgentIOMapperConfig,
AgentIOMapperInput,
AgentIOMapperOutput,
)
from .base import (
BaseIOMapper,
IOMapperInput,
IOMapperOutput,
IOModelSettings,
BaseIOMapperConfig,
BaseIOMapperInput,
BaseIOMapperOutput,
)
from .imperative import ImperativeIOMapper
from .iomapper import (
AgentIOMapper,
IOMapperConfig,
IOModelArgs,
from .imperative import (
ImperativeIOMapper,
ImperativeIOMapperInput,
ImperativeIOMapperOutput,
)
203 changes: 70 additions & 133 deletions agntcy_iomapper/iomapper.py → agntcy_iomapper/agent_iomapper.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import json
import logging
import re
from typing import ClassVar, TypedDict
from abc import abstractmethod
from typing import ClassVar

import aiofiles
import jsonschema
from dotenv import find_dotenv, load_dotenv
from jinja2 import Environment
from jinja2.sandbox import SandboxedEnvironment
from pydantic import BaseModel, Field, model_validator
from pydantic_ai import Agent
from typing_extensions import Self
from pydantic import Field

from .base import BaseIOMapper, IOMapperInput, IOMapperOutput, IOModelSettings
from .supported_agents import get_supported_agent
from .base import (
BaseIOMapper,
BaseIOMapperConfig,
BaseIOMapperInput,
BaseIOMapperOutput,
)

logger = logging.getLogger(__name__)


class IOModelArgs(TypedDict, total=False):
base_url: str
api_version: str
azure_endpoint: str
azure_ad_token: str
project: str
organization: str
class AgentIOMapperInput(BaseIOMapperInput):
message_template: str | None = Field(
max_length=4096,
default=None,
description="Message (user) to send to LLM to effect translation.",
)


class IOMapperConfig(BaseModel):
models: dict[str, IOModelArgs] = Field(
default={"azure:gpt-4o-mini": IOModelArgs()},
description="LLM configuration to use for translation",
)
default_model: str | None = Field(
default="azure:gpt-4o-mini",
description="Default arguments to LLM completion function by configured model.",
)
default_model_settings: dict[str, IOModelSettings] = Field(
default={"azure:gpt-4o-mini": IOModelSettings(seed=42, temperature=0.8)},
description="LLM configuration to use for translation",
)
validate_json_input: bool = Field(
default=False, description="Validate input against JSON schema."
)
validate_json_output: bool = Field(
default=False, description="Validate output against JSON schema."
)
AgentIOMapperOutput = BaseIOMapperOutput


class AgentIOMapperConfig(BaseIOMapperConfig):
system_prompt_template: str = Field(
max_length=4096,
default="You are a translation machine. You translate both natural language and object formats for computers.",
Expand All @@ -61,19 +44,6 @@ class IOMapperConfig(BaseModel):
description="Default user message template. This can be overridden by the message request.",
)

@model_validator(mode="after")
def _validate_obj(self) -> Self:
if self.models and self.default_model not in self.models:
raise ValueError(
f"default model {self.default_model} not present in configured models"
)
# Fill out defaults to eliminate need for checking.
for model_name in self.models.keys():
if model_name not in self.default_model_settings:
self.default_model_settings[model_name] = IOModelSettings()

return self


class AgentIOMapper(BaseIOMapper):
_json_search_pattern: ClassVar[re.Pattern] = re.compile(
Expand All @@ -82,11 +52,13 @@ class AgentIOMapper(BaseIOMapper):

def __init__(
self,
config: IOMapperConfig,
config: AgentIOMapperConfig | None = None,
jinja_env: Environment | None = None,
jinja_env_async: Environment | None = None,
):
self.config = config
if config is None:
config = AgentIOMapperConfig()
super().__init__(config)

if jinja_env is not None and jinja_env.is_async:
raise ValueError("Async Jinja env passed to jinja_env argument")
Expand Down Expand Up @@ -136,41 +108,19 @@ def _check_jinja_env(self, enable_async: bool):
self.config.message_template
)

def _get_render_env(self, input: IOMapperInput) -> dict[str, str]:
def _get_render_env(self, input: AgentIOMapperInput) -> dict[str, str]:
return {
"input": input.input,
"output": input.output,
"data": input.data,
}

def _get_model_settings(self, input: IOMapperInput):
model_name = input.model or self.config.default_model
if model_name not in self.config.models:
raise ValueError(f"requested model {model_name} not found")
elif input.model_settings is None:
return self.config.default_model_settings[model_name]
else:
model_settings = self.config.default_model_settings[model_name].copy()
model_settings.update(input.model_settings)
return model_settings

def _get_agent(
self, is_async: bool, input: IOMapperInput, system_prompt: str
) -> Agent:
model_name = input.model or self.config.default_model
if model_name not in self.config.models:
raise ValueError(f"requested model {model_name} not found")

return get_supported_agent(
model_name,
model_args=self.config.models[model_name],
system_prompt=system_prompt,
)

def _get_output(self, input: IOMapperInput, outputs: str) -> IOMapperOutput:
def _get_output(
self, input: AgentIOMapperInput, outputs: str
) -> AgentIOMapperOutput:
if input.output.json_schema is None:
# If there is no schema, quote the chars for JSON.
return IOMapperOutput.model_validate_json(
return AgentIOMapperOutput.model_validate_json(
f'{{"data": {json.dumps(outputs)} }}'
)

Expand All @@ -179,9 +129,9 @@ def _get_output(self, input: IOMapperInput, outputs: str) -> IOMapperOutput:
if matches:
outputs = matches[-1]

return IOMapperOutput.model_validate_json(f'{{"data": {outputs} }}')
return AgentIOMapperOutput.model_validate_json(f'{{"data": {outputs} }}')

def _validate_input(self, input: IOMapperInput) -> None:
def _validate_input(self, input: AgentIOMapperInput) -> None:
if self.config.validate_json_input and input.input.json_schema is not None:
jsonschema.validate(
instance=input.data,
Expand All @@ -190,7 +140,9 @@ def _validate_input(self, input: IOMapperInput) -> None:
),
)

def _validate_output(self, input: IOMapperInput, output: IOMapperOutput) -> None:
def _validate_output(
self, input: AgentIOMapperInput, output: AgentIOMapperOutput
) -> None:
if self.config.validate_json_output and input.output.json_schema is not None:
output_schema = input.output.json_schema.model_dump(
exclude_none=True, mode="json"
Expand All @@ -201,34 +153,46 @@ def _validate_output(self, input: IOMapperInput, output: IOMapperOutput) -> None
schema=output_schema,
)

def invoke(self, input: IOMapperInput) -> IOMapperOutput:
def invoke(self, input: AgentIOMapperInput, **kwargs) -> AgentIOMapperOutput:
self._validate_input(input)
self._check_jinja_env(False)
render_env = self._get_render_env(input)
system_prompt = self.prompt_template.render(render_env)
agent = self._get_agent(False, input, system_prompt)

if input.message_template is not None:
logging.info(f"User template supplied on input: {input.message_template}")
user_template = self.jinja_env.from_string(input.message_template)
else:
user_template = self.user_template
response = agent.run_sync(
user_prompt=user_template.render(render_env),
model_settings=self._get_model_settings(input),
user_prompt = user_template.render(render_env)

outputs = self._invoke(
input,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
**kwargs,
)
outputs = response.data
logging.debug(f"The LLM returned: {outputs}")
output = self._get_output(input, outputs)
self._validate_output(input, output)
return output

async def ainvoke(self, input: IOMapperInput) -> IOMapperOutput:
@abstractmethod
def _invoke(
self, input: AgentIOMapperInput, messages: list[dict[str, str]], **kwargs
) -> str:
"""Invoke internal model to process messages.
Args:
messages: the messages to send to the LLM
"""

async def ainvoke(self, input: AgentIOMapperInput, **kwargs) -> AgentIOMapperOutput:
self._validate_input(input)
self._check_jinja_env(True)
render_env = self._get_render_env(input)
system_prompt = await self.prompt_template_async.render_async(render_env)
agent = self._get_agent(True, input, system_prompt)

if input.message_template is not None:
logging.info(f"User template supplied on input: {input.message_template}")
Expand All @@ -237,53 +201,26 @@ async def ainvoke(self, input: IOMapperInput) -> IOMapperOutput:
)
else:
user_template_async = self.user_template_async
response = await agent.run(
user_prompt=await user_template_async.render_async(render_env),
model_settings=self._get_model_settings(input),
user_prompt = await user_template_async.render_async(render_env)

outputs = await self._ainvoke(
input,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
**kwargs,
)
outputs = response.data
logging.debug(f"The LLM returned: {outputs}")
output = self._get_output(input, outputs)
self._validate_output(input, output)
return output


async def main():
parser = argparse.ArgumentParser()
parser.add_argument("--inputfile", help="Inputfile", required=True)
parser.add_argument("--configfile", help="Configuration file", required=True)
parser.add_argument("--outputfile", help="Output file", required=True)
args = parser.parse_args()
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

jinja_env = SandboxedEnvironment(
loader=None,
enable_async=True,
autoescape=False,
)

async with aiofiles.open(args.configfile, "r") as fp:
configs = await fp.read()

config = IOMapperConfig.model_validate_json(configs)
logging.info(f"Loaded config from {args.configfile}: {config.model_dump_json()}")

async with aiofiles.open(args.inputfile, "r") as fp:
inputs = await fp.read()

input = IOMapperInput.model_validate_json(inputs)
logging.info(f"Loaded input from {args.inputfile}: {input.model_dump_json()}")

p = AgentIOMapper(config, jinja_env)
output = await p.ainvoke(input)
outputs = output.model_dump_json()

async with aiofiles.open(args.outputfile, "w") as fp:
await fp.write(outputs)

logging.info(f"Dumped output to {args.outputfile}: {outputs}")


if __name__ == "__main__":
load_dotenv(dotenv_path=find_dotenv(usecwd=True))
asyncio.run(main())
@abstractmethod
async def _ainvoke(
self, input: AgentIOMapperInput, messages: list[dict[str, str]], **kwargs
) -> str:
"""Async invoke internal model to process messages.
Args:
messages: the messages to send to the LLM
"""
Loading