# SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
- import argparse
- import asyncio
import json
import logging
import re
- from typing import ClassVar, TypedDict
+ from abc import abstractmethod
+ from typing import ClassVar

- import aiofiles
import jsonschema
- from dotenv import find_dotenv, load_dotenv
from jinja2 import Environment
from jinja2.sandbox import SandboxedEnvironment
- from pydantic import BaseModel, Field, model_validator
- from pydantic_ai import Agent
- from typing_extensions import Self
+ from pydantic import Field

- from .base import BaseIOMapper, IOMapperInput, IOMapperOutput, IOModelSettings
- from .supported_agents import get_supported_agent
+ from .base import (
+     BaseIOMapper,
+     BaseIOMapperConfig,
+     BaseIOMapperInput,
+     BaseIOMapperOutput,
+ )

logger = logging.getLogger(__name__)


- class IOModelArgs(TypedDict, total=False):
-     base_url: str
-     api_version: str
-     azure_endpoint: str
-     azure_ad_token: str
-     project: str
-     organization: str
+ class AgentIOMapperInput(BaseIOMapperInput):
+     message_template: str | None = Field(
+         max_length=4096,
+         default=None,
+         description="Message (user) to send to LLM to effect translation.",
+     )


- class IOMapperConfig(BaseModel):
-     models: dict[str, IOModelArgs] = Field(
-         default={"azure:gpt-4o-mini": IOModelArgs()},
-         description="LLM configuration to use for translation",
-     )
-     default_model: str | None = Field(
-         default="azure:gpt-4o-mini",
-         description="Default arguments to LLM completion function by configured model.",
-     )
-     default_model_settings: dict[str, IOModelSettings] = Field(
-         default={"azure:gpt-4o-mini": IOModelSettings(seed=42, temperature=0.8)},
-         description="LLM configuration to use for translation",
-     )
-     validate_json_input: bool = Field(
-         default=False, description="Validate input against JSON schema."
-     )
-     validate_json_output: bool = Field(
-         default=False, description="Validate output against JSON schema."
-     )
+ AgentIOMapperOutput = BaseIOMapperOutput
+
+
+ class AgentIOMapperConfig(BaseIOMapperConfig):
    system_prompt_template: str = Field(
        max_length=4096,
        default="You are a translation machine. You translate both natural language and object formats for computers.",
@@ -61,19 +44,6 @@ class IOMapperConfig(BaseModel):
        description="Default user message template. This can be overridden by the message request.",
    )

-     @model_validator(mode="after")
-     def _validate_obj(self) -> Self:
-         if self.models and self.default_model not in self.models:
-             raise ValueError(
-                 f"default model {self.default_model} not present in configured models"
-             )
-         # Fill out defaults to eliminate need for checking.
-         for model_name in self.models.keys():
-             if model_name not in self.default_model_settings:
-                 self.default_model_settings[model_name] = IOModelSettings()
-
-         return self
-

class AgentIOMapper(BaseIOMapper):
    _json_search_pattern: ClassVar[re.Pattern] = re.compile(
@@ -82,11 +52,13 @@ class AgentIOMapper(BaseIOMapper):

    def __init__(
        self,
-         config: IOMapperConfig,
+         config: AgentIOMapperConfig | None = None,
        jinja_env: Environment | None = None,
        jinja_env_async: Environment | None = None,
    ):
-         self.config = config
+         if config is None:
+             config = AgentIOMapperConfig()
+         super().__init__(config)

        if jinja_env is not None and jinja_env.is_async:
            raise ValueError("Async Jinja env passed to jinja_env argument")
@@ -136,41 +108,19 @@ def _check_jinja_env(self, enable_async: bool):
            self.config.message_template
        )

-     def _get_render_env(self, input: IOMapperInput) -> dict[str, str]:
+     def _get_render_env(self, input: AgentIOMapperInput) -> dict[str, str]:
        return {
            "input": input.input,
            "output": input.output,
            "data": input.data,
        }

-     def _get_model_settings(self, input: IOMapperInput):
-         model_name = input.model or self.config.default_model
-         if model_name not in self.config.models:
-             raise ValueError(f"requested model {model_name} not found")
-         elif input.model_settings is None:
-             return self.config.default_model_settings[model_name]
-         else:
-             model_settings = self.config.default_model_settings[model_name].copy()
-             model_settings.update(input.model_settings)
-             return model_settings
-
-     def _get_agent(
-         self, is_async: bool, input: IOMapperInput, system_prompt: str
-     ) -> Agent:
-         model_name = input.model or self.config.default_model
-         if model_name not in self.config.models:
-             raise ValueError(f"requested model {model_name} not found")
-
-         return get_supported_agent(
-             model_name,
-             model_args=self.config.models[model_name],
-             system_prompt=system_prompt,
-         )
-
-     def _get_output(self, input: IOMapperInput, outputs: str) -> IOMapperOutput:
+     def _get_output(
+         self, input: AgentIOMapperInput, outputs: str
+     ) -> AgentIOMapperOutput:
        if input.output.json_schema is None:
            # If there is no schema, quote the chars for JSON.
-             return IOMapperOutput.model_validate_json(
+             return AgentIOMapperOutput.model_validate_json(
                f'{{"data": {json.dumps(outputs)}}}'
            )

@@ -179,9 +129,9 @@ def _get_output(self, input: IOMapperInput, outputs: str) -> IOMapperOutput:
        if matches:
            outputs = matches[-1]

-         return IOMapperOutput.model_validate_json(f'{{"data": {outputs}}}')
+         return AgentIOMapperOutput.model_validate_json(f'{{"data": {outputs}}}')

-     def _validate_input(self, input: IOMapperInput) -> None:
+     def _validate_input(self, input: AgentIOMapperInput) -> None:
        if self.config.validate_json_input and input.input.json_schema is not None:
            jsonschema.validate(
                instance=input.data,
@@ -190,7 +140,9 @@ def _validate_input(self, input: IOMapperInput) -> None:
                ),
            )

-     def _validate_output(self, input: IOMapperInput, output: IOMapperOutput) -> None:
+     def _validate_output(
+         self, input: AgentIOMapperInput, output: AgentIOMapperOutput
+     ) -> None:
        if self.config.validate_json_output and input.output.json_schema is not None:
            output_schema = input.output.json_schema.model_dump(
                exclude_none=True, mode="json"
@@ -201,34 +153,46 @@ def _validate_output(self, input: IOMapperInput, output: IOMapperOutput) -> None
                schema=output_schema,
            )

-     def invoke(self, input: IOMapperInput) -> IOMapperOutput:
+     def invoke(self, input: AgentIOMapperInput, **kwargs) -> AgentIOMapperOutput:
        self._validate_input(input)
        self._check_jinja_env(False)
        render_env = self._get_render_env(input)
        system_prompt = self.prompt_template.render(render_env)
-         agent = self._get_agent(False, input, system_prompt)

        if input.message_template is not None:
            logging.info(f"User template supplied on input: {input.message_template}")
            user_template = self.jinja_env.from_string(input.message_template)
        else:
            user_template = self.user_template
-         response = agent.run_sync(
-             user_prompt=user_template.render(render_env),
-             model_settings=self._get_model_settings(input),
+         user_prompt = user_template.render(render_env)
+
+         outputs = self._invoke(
+             input,
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt},
+             ],
+             **kwargs,
        )
-         outputs = response.data
        logging.debug(f"The LLM returned: {outputs}")
        output = self._get_output(input, outputs)
        self._validate_output(input, output)
        return output

-     async def ainvoke(self, input: IOMapperInput) -> IOMapperOutput:
+     @abstractmethod
+     def _invoke(
+         self, input: AgentIOMapperInput, messages: list[dict[str, str]], **kwargs
+     ) -> str:
+         """Invoke internal model to process messages.
+         Args:
+             messages: the messages to send to the LLM
+         """
+
+     async def ainvoke(self, input: AgentIOMapperInput, **kwargs) -> AgentIOMapperOutput:
        self._validate_input(input)
        self._check_jinja_env(True)
        render_env = self._get_render_env(input)
        system_prompt = await self.prompt_template_async.render_async(render_env)
-         agent = self._get_agent(True, input, system_prompt)

        if input.message_template is not None:
            logging.info(f"User template supplied on input: {input.message_template}")
@@ -237,53 +201,26 @@ async def ainvoke(self, input: IOMapperInput) -> IOMapperOutput:
            )
        else:
            user_template_async = self.user_template_async
-         response = await agent.run(
-             user_prompt=await user_template_async.render_async(render_env),
-             model_settings=self._get_model_settings(input),
+         user_prompt = await user_template_async.render_async(render_env)
+
+         outputs = await self._ainvoke(
+             input,
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt},
+             ],
+             **kwargs,
        )
-         outputs = response.data
        logging.debug(f"The LLM returned: {outputs}")
        output = self._get_output(input, outputs)
        self._validate_output(input, output)
        return output

-
- async def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--inputfile", help="Inputfile", required=True)
-     parser.add_argument("--configfile", help="Configuration file", required=True)
-     parser.add_argument("--outputfile", help="Output file", required=True)
-     args = parser.parse_args()
-     logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
-
-     jinja_env = SandboxedEnvironment(
-         loader=None,
-         enable_async=True,
-         autoescape=False,
-     )
-
-     async with aiofiles.open(args.configfile, "r") as fp:
-         configs = await fp.read()
-
-     config = IOMapperConfig.model_validate_json(configs)
-     logging.info(f"Loaded config from {args.configfile}: {config.model_dump_json()}")
-
-     async with aiofiles.open(args.inputfile, "r") as fp:
-         inputs = await fp.read()
-
-     input = IOMapperInput.model_validate_json(inputs)
-     logging.info(f"Loaded input from {args.inputfile}: {input.model_dump_json()}")
-
-     p = AgentIOMapper(config, jinja_env)
-     output = await p.ainvoke(input)
-     outputs = output.model_dump_json()
-
-     async with aiofiles.open(args.outputfile, "w") as fp:
-         await fp.write(outputs)
-
-     logging.info(f"Dumped output to {args.outputfile}: {outputs}")
-
-
- if __name__ == "__main__":
-     load_dotenv(dotenv_path=find_dotenv(usecwd=True))
-     asyncio.run(main())
+     @abstractmethod
+     async def _ainvoke(
+         self, input: AgentIOMapperInput, messages: list[dict[str, str]], **kwargs
+     ) -> str:
+         """Async invoke internal model to process messages.
+         Args:
+             messages: the messages to send to the LLM
+         """