Skip to content

Commit 49fa971

Browse files
fix: DIA-2021: use defaultdict for input_field_types instead of dict (#358)
Co-authored-by: matt-bernstein <[email protected]>
1 parent eef8e0c commit 49fa971

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

Diff for: adala/utils/parse.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1+
import logging
12
import re
23
import string
3-
import logging
4+
from collections import defaultdict
5+
from enum import Enum
46
from typing import (
5-
List,
6-
TypedDict,
7+
Annotated,
8+
Any,
9+
DefaultDict,
710
Dict,
11+
Generator,
12+
Iterable,
13+
List,
14+
Literal,
15+
Mapping,
816
Optional,
17+
TypedDict,
918
Union,
10-
Literal,
11-
Iterable,
12-
Generator,
13-
Any,
1419
)
15-
from enum import Enum
16-
from pydantic import BaseModel, Field
17-
from collections import defaultdict
20+
21+
from pydantic import BaseModel, Field, validator
1822

1923
logger = logging.getLogger(__name__)
2024

@@ -314,16 +318,26 @@ class MessagesBuilder(BaseModel):
314318
instruction_first: bool = True
315319
extra_fields: Dict[str, Any] = Field(default_factory=dict)
316320
split_into_chunks: bool = False
317-
input_field_types: Optional[Dict[str, MessageChunkType]] = Field(default=None)
321+
input_field_types: Optional[
322+
DefaultDict[
323+
str,
324+
Annotated[
325+
MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT)
326+
],
327+
]
328+
] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT))
329+
330+
@validator("input_field_types", pre=True)
331+
def set_default_input_field_types(cls, value):
332+
if value is None:
333+
return defaultdict(lambda: MessageChunkType.TEXT)
334+
return value
318335

319336
def get_messages(self, payload: Dict[str, Any]):
320337
if self.split_into_chunks:
321-
input_field_types = self.input_field_types or defaultdict(
322-
lambda: MessageChunkType.TEXT
323-
)
324338
user_prompt = split_message_into_chunks(
325339
input_template=self.user_prompt_template,
326-
input_field_types=input_field_types,
340+
input_field_types=self.input_field_types,
327341
**payload,
328342
**self.extra_fields,
329343
)

0 commit comments

Comments
 (0)