1- __author__ = "aiXplain"
1+ """Agent module for aiXplain SDK.
2+
3+ This module provides the Agent class and related functionality for creating and managing
4+ AI agents that can execute tasks using various tools and models.
25
3- """
46Copyright 2024 The aiXplain SDK authors
57
68Licensed under the Apache License, Version 2.0 (the "License");
2022Description:
2123 Agentification Class
2224"""
25+
26+ __author__ = "aiXplain"
2327import json
2428import logging
2529import re
3337from aixplain .modules .model import Model
3438from aixplain .modules .agent .agent_task import WorkflowTask , AgentTask
3539from aixplain .modules .agent .output_format import OutputFormat
36- from aixplain .modules .agent .tool import Tool
40+ from aixplain .modules .agent .tool import Tool , DeployableTool
3741from aixplain .modules .agent .agent_response import AgentResponse
3842from aixplain .modules .agent .agent_response_data import AgentResponseData
3943from aixplain .modules .agent .utils import process_variables , validate_history
4852import warnings
4953
5054
51- class Agent (Model , DeployableMixin [Tool ]):
55+ class Agent (Model , DeployableMixin [Union [ Tool , DeployableTool ] ]):
5256 """An advanced AI system that performs tasks using specialized tools from the aiXplain marketplace.
5357
5458 This class represents an AI agent that can understand natural language instructions,
@@ -124,6 +128,8 @@ def __init__(
124128 Defaults to AssetStatus.DRAFT.
125129 tasks (List[AgentTask], optional): List of tasks the Agent can perform.
126130 Defaults to empty list.
131+ workflow_tasks (List[WorkflowTask], optional): List of workflow tasks
132+ the Agent can execute. Defaults to empty list.
127133 output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT.
128134 expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
129135 **additional_info: Additional configuration parameters.
@@ -144,7 +150,8 @@ def __init__(
144150 self .status = status
145151 if tasks :
146152 warnings .warn (
147- "The 'tasks' parameter is deprecated and will be removed in a future version. " "Use 'workflow_tasks' instead." ,
153+ "The 'tasks' parameter is deprecated and will be removed in a future version. "
154+ "Use 'workflow_tasks' instead." ,
148155 DeprecationWarning ,
149156 stacklevel = 2 ,
150157 )
@@ -171,9 +178,9 @@ def _validate(self) -> None:
171178 from aixplain .utils .llm_utils import get_llm_instance
172179
173180 # validate name
174- assert (
175- re . match ( r"^[a-zA-Z0-9 \-\(\)]*$" , self . name ) is not None
176- ), "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed."
181+ assert re . match ( r"^[a-zA-Z0-9 \-\(\)]*$" , self . name ) is not None , (
182+ "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed."
183+ )
177184
178185 llm = get_llm_instance (self .llm_id , api_key = self .api_key )
179186
@@ -233,6 +240,14 @@ def validate(self, raise_exception: bool = False) -> bool:
233240 return self .is_valid
234241
235242 def generate_session_id (self , history : list = None ) -> str :
243+ """Generate a unique session ID for agent conversations.
244+
245+ Args:
246+ history (list, optional): Previous conversation history. Defaults to None.
247+
248+ Returns:
249+ str: A unique session identifier based on timestamp and random components.
250+ """
236251 if history :
237252 validate_history (history )
238253 timestamp = datetime .now ().strftime ("%Y%m%d%H%M%S" )
@@ -307,6 +322,7 @@ def run(
307322 max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.
308323 output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
309324 expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
325+
310326 Returns:
311327 Dict: parsed output from model
312328 """
@@ -406,10 +422,10 @@ def run_async(
406422 expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
407423 output_format (ResponseFormat, optional): response format. Defaults to TEXT.
408424 evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the agent configuration. Can be a dictionary, EvolveParam instance, or None.
425+
409426 Returns:
410427 dict: polling URL in response
411428 """
412-
413429 if session_id is not None and history is not None :
414430 raise ValueError ("Provide either `session_id` or `history`, not both." )
415431
@@ -434,7 +450,9 @@ def run_async(
434450 assert data is not None or query is not None , "Either 'data' or 'query' must be provided."
435451 if data is not None :
436452 if isinstance (data , dict ):
437- assert "query" in data and data ["query" ] is not None , "When providing a dictionary, 'query' must be provided."
453+ assert "query" in data and data ["query" ] is not None , (
454+ "When providing a dictionary, 'query' must be provided."
455+ )
438456 query = data .get ("query" )
439457 if session_id is None :
440458 session_id = data .get ("session_id" )
@@ -447,7 +465,9 @@ def run_async(
447465
448466 # process content inputs
449467 if content is not None :
450- assert FileFactory .check_storage_type (query ) == StorageType .TEXT , "When providing 'content', query must be text."
468+ assert FileFactory .check_storage_type (query ) == StorageType .TEXT , (
469+ "When providing 'content', query must be text."
470+ )
451471
452472 if isinstance (content , list ):
453473 assert len (content ) <= 3 , "The maximum number of content inputs is 3."
@@ -511,6 +531,11 @@ def run_async(
511531 )
512532
513533 def to_dict (self ) -> Dict :
534+ """Convert the Agent instance to a dictionary representation.
535+
536+ Returns:
537+ Dict: Dictionary containing the agent's configuration and metadata.
538+ """
514539 from aixplain .factories .agent_factory .utils import build_tool_payload
515540
516541 return {
@@ -674,9 +699,9 @@ def delete(self) -> None:
674699 "referencing it."
675700 )
676701 else :
677- message = f"Agent Deletion Error (HTTP { r .status_code } ): " f" { error_message } ."
702+ message = f"Agent Deletion Error (HTTP { r .status_code } ): { error_message } ."
678703 except ValueError :
679- message = f"Agent Deletion Error (HTTP { r .status_code } ): " " There was an error in deleting the agent."
704+ message = f"Agent Deletion Error (HTTP { r .status_code } ): There was an error in deleting the agent."
680705 logging .error (message )
681706 raise Exception (message )
682707
@@ -701,7 +726,7 @@ def update(self) -> None:
701726 stack = inspect .stack ()
702727 if len (stack ) > 2 and stack [1 ].function != "save" :
703728 warnings .warn (
704- "update() is deprecated and will be removed in a future version. " " Please use save() instead." ,
729+ "update() is deprecated and will be removed in a future version. Please use save() instead." ,
705730 DeprecationWarning ,
706731 stacklevel = 2 ,
707732 )
0 commit comments