|
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Union
 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.runnables import Runnable
-from openapi_pydantic import Schema
+from llama_index.core.base.llms.base import BaseLLM
+from llama_index.core.tools import BaseTool
+from llama_index.core.workflow import (
+    Workflow,
+)
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self
 
-from agntcy_iomapper.agent.models import (
+from agntcy_iomapper.base.models import (
     AgentIOMapperInput,
+    ArgumentsDescription,
     FieldMetadata,
     IOMappingAgentMetadata,
 )
-from agntcy_iomapper.base import ArgumentsDescription
-from agntcy_iomapper.base.utils import create_type_from_schema, extract_nested_fields
+from agntcy_iomapper.base.utils import extract_nested_fields, get_io_types
+from agntcy_iomapper.imperative import (
+    ImperativeIOMapper,
+    ImperativeIOMapperInput,
+)
 from agntcy_iomapper.langgraph import LangGraphIOMapper, LangGraphIOMapperConfig
+from agntcy_iomapper.llamaindex.llamaindex import (
+    LLamaIndexIOMapper,
+)
 
 logger = logging.getLogger(__name__)
 
 
 class IOMappingAgent(BaseModel):
-    metadata: IOMappingAgentMetadata = Field(
+    metadata: Optional[IOMappingAgentMetadata] = Field(
         ...,
         description="Details about the fields to be used in the translation and about the output",
     )
-    llm: Optional[Union[BaseChatModel, str]] = (
-        Field(
-            None,
-            description="Model to use for translation as LangChain description or model class.",
-        ),
+    llm: Optional[Union[BaseChatModel, str]] = Field(
+        None,
+        description="Model to use for translation as LangChain description or model class.",
     )
 
     @model_validator(mode="after")
     def _validate_obj(self) -> Self:
+        if not self.metadata:
+            return self
 
         if not self.metadata.input_fields or len(self.metadata.input_fields) == 0:
             raise ValueError("input_fields not found in the metadata")
@@ -70,38 +81,10 @@ def _validate_obj(self) -> Self:
 
         return self
 
-    def _get_io_types(self, data: Any) -> Tuple[Schema, Schema]:
-        data_schema = None
-        if isinstance(data, BaseModel):
-            data_schema = data.model_json_schema()
-        # If input schema is provided it overwrites the data schema
-        input_schema = (
-            self.metadata.input_schema if self.metadata.input_schema else data_schema
-        )
-        # If output schema is provided it overwrites the data schema
-        output_schema = (
-            self.metadata.output_schema if self.metadata.output_schema else data_schema
-        )
-
-        if not input_schema or not output_schema:
-            raise ValueError(
-                "input_schema, and or output_schema are missing from the metadata, for a better accuracy you are required to provide them in this scenario, or we could not infer the type from the state"
-            )
-
-        input_type = Schema.model_validate(
-            create_type_from_schema(input_schema, self.metadata.input_fields)
-        )
-
-        output_type = Schema.model_validate(
-            create_type_from_schema(output_schema, self.metadata.output_fields)
-        )
-
-        return (input_type, output_type)
-
     def langgraph_node(self, data: Any, config: Optional[dict] = None) -> Runnable:
 
         # If there is a template for the output the output_schema is going to be ignored in the translation
-        input_type, output_type = self._get_io_types(data)
+        input_type, output_type = get_io_types(data, self.metadata)
 
         data_to_be_mapped = extract_nested_fields(
             data, fields=self.metadata.input_fields
@@ -132,3 +115,54 @@ def langgraph_node(self, data: Any, config: Optional[dict] = None) -> Runnable:
 
         iomapper_config = LangGraphIOMapperConfig(llm=self.llm)
         return LangGraphIOMapper(iomapper_config, input).as_runnable()
+
+    def langgraph_imperative(
+        self, data: Any, config: Optional[dict] = None
+    ) -> Runnable:
+
+        input_type, output_type = get_io_types(data, self.metadata)
+
+        data_to_be_mapped = extract_nested_fields(
+            data, fields=self.metadata.input_fields
+        )
+
+        input = ImperativeIOMapperInput(
+            input=ArgumentsDescription(
+                json_schema=input_type,
+            ),
+            output=ArgumentsDescription(json_schema=output_type),
+            data=data_to_be_mapped,
+        )
+
+        if not self.metadata.field_mapping:
+            raise ValueError(
+                "In order to use imperative mapping field_mapping must be provided in the metadata"
+            )
+
+        imperative_io_mapper = ImperativeIOMapper(
+            input=input, field_mapping=self.metadata.field_mapping
+        )
+        return imperative_io_mapper.as_runnable()
+
+    @staticmethod
+    def as_worfklow_step(workflow: Workflow) -> Callable:
+        io_mapper_step = LLamaIndexIOMapper.llamaindex_mapper(workflow)
+        return io_mapper_step
+
+    @staticmethod
+    def as_workflow_agent(
+        mapping_metadata: IOMappingAgentMetadata,
+        llm: BaseLLM,
+        name: str,
+        description: str,
+        can_handoff_to: Optional[List[str]] = None,
+        tools: Optional[List[Union[BaseTool, Callable]]] = [],
+    ):
+        return LLamaIndexIOMapper(
+            mapping_metadata=mapping_metadata,
+            llm=llm,
+            tools=tools,
+            name=name,
+            description=description,
+            can_handoff_to=can_handoff_to,
+        )
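
Taken together, this change moves schema inference into the shared get_io_types helper, makes metadata optional, and adds an imperative LangGraph path plus LlamaIndex entry points next to the existing langgraph_node. A minimal usage sketch of the LangGraph side follows; the state model, field names, LLM choice, and top-level import path are illustrative assumptions and are not part of this diff.

# Minimal sketch, assuming IOMappingAgent is exported from agntcy_iomapper.agent;
# the state model, field names, and LLM choice below are illustrative only.
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from agntcy_iomapper.agent import IOMappingAgent  # assumed export location
from agntcy_iomapper.base.models import IOMappingAgentMetadata


class GraphState(BaseModel):
    # Illustrative LangGraph-style state carrying the source and target fields.
    user_prompt: str = ""
    summary: str = ""


metadata = IOMappingAgentMetadata(
    input_fields=["user_prompt"],  # paths read from the state
    output_fields=["summary"],     # paths the mapper should populate
)

agent = IOMappingAgent(metadata=metadata, llm=ChatOpenAI(model="gpt-4o-mini"))

# langgraph_node(...) returns a Runnable that translates the extracted input
# fields into the declared output fields; it can be added as a graph node.
mapping_node = agent.langgraph_node(GraphState(user_prompt="Map me, please"))

The imperative path is reached the same way through langgraph_imperative, provided field_mapping is set in the metadata, while the LlamaIndex side is exposed through as_workflow_agent and as_worfklow_step.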