From a27ac15d935258e33b173ffc4b65e56bb876c032 Mon Sep 17 00:00:00 2001 From: pakelley Date: Fri, 14 Mar 2025 16:26:04 -0700 Subject: [PATCH 1/3] fix: use defaultdict for input_field_types instead of dict --- adala/utils/parse.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/adala/utils/parse.py b/adala/utils/parse.py index 3a07dd3a..f5f7dd90 100644 --- a/adala/utils/parse.py +++ b/adala/utils/parse.py @@ -288,16 +288,14 @@ class MessagesBuilder(BaseModel): instruction_first: bool = True extra_fields: Dict[str, Any] = Field(default_factory=dict) split_into_chunks: bool = False - input_field_types: Optional[Dict[str, MessageChunkType]] = Field(default=None) + input_field_types: DefaultDict[str, Annotated[MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT)]] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT)) + def get_messages(self, payload: Dict[str, Any]): if self.split_into_chunks: - input_field_types = self.input_field_types or defaultdict( - lambda: MessageChunkType.TEXT - ) user_prompt = split_message_into_chunks( self.user_prompt_template, - input_field_types, + self.input_field_types, **payload, **self.extra_fields, ) From 9f06a058aefbfc88c44c50ce439f98c26ecf5fd3 Mon Sep 17 00:00:00 2001 From: pakelley Date: Fri, 14 Mar 2025 16:43:35 -0700 Subject: [PATCH 2/3] lint --- adala/utils/parse.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/adala/utils/parse.py b/adala/utils/parse.py index f5f7dd90..fcc3658c 100644 --- a/adala/utils/parse.py +++ b/adala/utils/parse.py @@ -1,20 +1,24 @@ +import logging import re import string -import logging +from collections import defaultdict +from enum import Enum from typing import ( - List, - TypedDict, + Annotated, + Any, + DefaultDict, Dict, + Generator, + Iterable, + List, + Literal, + Mapping, Optional, + TypedDict, Union, - Literal, - Iterable, - Generator, - Any, ) -from enum import Enum -from pydantic import BaseModel, Field -from collections import defaultdict + +from pydantic import BaseModel, Field, validator logger = logging.getLogger(__name__) @@ -288,8 +292,12 @@ class MessagesBuilder(BaseModel): instruction_first: bool = True extra_fields: Dict[str, Any] = Field(default_factory=dict) split_into_chunks: bool = False - input_field_types: DefaultDict[str, Annotated[MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT)]] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT)) - + input_field_types: DefaultDict[ + str, + Annotated[ + MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT) + ], + ] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT)) def get_messages(self, payload: Dict[str, Any]): if self.split_into_chunks: From 67762b3a98ed4a263292952dfb18bbf6207b80cb Mon Sep 17 00:00:00 2001 From: pakelley Date: Mon, 17 Mar 2025 10:24:02 -0700 Subject: [PATCH 3/3] make input_field_types optional again --- adala/utils/parse.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/adala/utils/parse.py b/adala/utils/parse.py index e6b8382d..7b7052d5 100644 --- a/adala/utils/parse.py +++ b/adala/utils/parse.py @@ -318,13 +318,21 @@ class MessagesBuilder(BaseModel): instruction_first: bool = True extra_fields: Dict[str, Any] = Field(default_factory=dict) split_into_chunks: bool = False - input_field_types: DefaultDict[ - str, - Annotated[ - MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT) - ], + input_field_types: Optional[ + DefaultDict[ + str, + Annotated[ + MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT) + ], + ] ] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT)) + @validator("input_field_types", pre=True) + def set_default_input_field_types(cls, value): + if value is None: + return defaultdict(lambda: MessageChunkType.TEXT) + return value + def get_messages(self, payload: Dict[str, Any]): if self.split_into_chunks: user_prompt = split_message_into_chunks(