diff --git a/adala/utils/parse.py b/adala/utils/parse.py
index 0a6f7880..7b7052d5 100644
--- a/adala/utils/parse.py
+++ b/adala/utils/parse.py
@@ -1,20 +1,24 @@
+import logging
 import re
 import string
-import logging
+from collections import defaultdict
+from enum import Enum
 from typing import (
-    List,
-    TypedDict,
+    Annotated,
+    Any,
+    DefaultDict,
     Dict,
+    Generator,
+    Iterable,
+    List,
+    Literal,
+    Mapping,
     Optional,
+    TypedDict,
     Union,
-    Literal,
-    Iterable,
-    Generator,
-    Any,
 )
-from enum import Enum
-from pydantic import BaseModel, Field
-from collections import defaultdict
+
+from pydantic import BaseModel, Field, validator
 
 logger = logging.getLogger(__name__)
 
@@ -314,16 +318,26 @@ class MessagesBuilder(BaseModel):
     instruction_first: bool = True
     extra_fields: Dict[str, Any] = Field(default_factory=dict)
     split_into_chunks: bool = False
-    input_field_types: Optional[Dict[str, MessageChunkType]] = Field(default=None)
+    input_field_types: Optional[
+        DefaultDict[
+            str,
+            Annotated[
+                MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT)
+            ],
+        ]
+    ] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT))
+
+    @validator("input_field_types", pre=True)
+    def set_default_input_field_types(cls, value):
+        if value is None:
+            return defaultdict(lambda: MessageChunkType.TEXT)
+        return value
 
     def get_messages(self, payload: Dict[str, Any]):
         if self.split_into_chunks:
-            input_field_types = self.input_field_types or defaultdict(
-                lambda: MessageChunkType.TEXT
-            )
             user_prompt = split_message_into_chunks(
                 input_template=self.user_prompt_template,
-                input_field_types=input_field_types,
+                input_field_types=self.input_field_types,
                 **payload,
                 **self.extra_fields,
             )