|
| 1 | +import logging |
1 | 2 | import re
|
2 | 3 | import string
|
3 |
| -import logging |
| 4 | +from collections import defaultdict |
| 5 | +from enum import Enum |
4 | 6 | from typing import (
|
5 |
| - List, |
6 |
| - TypedDict, |
| 7 | + Annotated, |
| 8 | + Any, |
| 9 | + DefaultDict, |
7 | 10 | Dict,
|
| 11 | + Generator, |
| 12 | + Iterable, |
| 13 | + List, |
| 14 | + Literal, |
| 15 | + Mapping, |
8 | 16 | Optional,
|
| 17 | + TypedDict, |
9 | 18 | Union,
|
10 |
| - Literal, |
11 |
| - Iterable, |
12 |
| - Generator, |
13 |
| - Any, |
14 | 19 | )
|
15 |
| -from enum import Enum |
16 |
| -from pydantic import BaseModel, Field |
17 |
| -from collections import defaultdict |
| 20 | + |
| 21 | +from pydantic import BaseModel, Field, validator |
18 | 22 |
|
19 | 23 | logger = logging.getLogger(__name__)
|
20 | 24 |
|
@@ -314,16 +318,26 @@ class MessagesBuilder(BaseModel):
|
314 | 318 | instruction_first: bool = True
|
315 | 319 | extra_fields: Dict[str, Any] = Field(default_factory=dict)
|
316 | 320 | split_into_chunks: bool = False
|
317 |
| - input_field_types: Optional[Dict[str, MessageChunkType]] = Field(default=None) |
| 321 | + input_field_types: Optional[ |
| 322 | + DefaultDict[ |
| 323 | + str, |
| 324 | + Annotated[ |
| 325 | + MessageChunkType, Field(default_factory=lambda: MessageChunkType.TEXT) |
| 326 | + ], |
| 327 | + ] |
| 328 | + ] = Field(default_factory=lambda: defaultdict(lambda: MessageChunkType.TEXT)) |
| 329 | + |
| 330 | + @validator("input_field_types", pre=True) |
| 331 | + def set_default_input_field_types(cls, value): |
| 332 | + if value is None: |
| 333 | + return defaultdict(lambda: MessageChunkType.TEXT) |
| 334 | + return value |
318 | 335 |
|
319 | 336 | def get_messages(self, payload: Dict[str, Any]):
|
320 | 337 | if self.split_into_chunks:
|
321 |
| - input_field_types = self.input_field_types or defaultdict( |
322 |
| - lambda: MessageChunkType.TEXT |
323 |
| - ) |
324 | 338 | user_prompt = split_message_into_chunks(
|
325 | 339 | input_template=self.user_prompt_template,
|
326 |
| - input_field_types=input_field_types, |
| 340 | + input_field_types=self.input_field_types, |
327 | 341 | **payload,
|
328 | 342 | **self.extra_fields,
|
329 | 343 | )
|
|
0 commit comments