1919from dataclasses import asdict , field
2020from pathlib import Path
2121from typing import Any , Dict , List , Optional
22+ from itertools import zip_longest
2223
2324import yaml
2425
@@ -217,6 +218,7 @@ def fill(
217218 prefix_generation_to_response : bool = False ,
218219 continue_prefix_generation : bool = False ,
219220 multi_turn_key : str | None = None ,
221+ return_templated_dict : bool = False ,
220222 ) -> str | List [dict ]:
221223 """
222224 Fills the prompt with the input_dict.
@@ -230,6 +232,9 @@ def fill(
230232 multi_turn_key: If specified, will read the list from input_dict[multi_turn_key]
231233 and use it to construct the prompt. You input_dict should also have "assistant" key in all
232234 turns except last containing assistant reply.
235+ return_templated_dict: Indicates whether to return a messages list where the template is used
236+ to fill the prompt. If so, a list of dicts with 'role' and 'content' keys will be returned.
237+ In this case the final user and assistant messages will include special tokens.
233238
234239 Returns:
235240 The filled prompt - either a string or a list of dictionaries.
@@ -243,36 +248,61 @@ def fill(
243248
244249 if self .config .template :
245250 if multi_turn_key is None :
246- prompt_string = self .SYSTEM_FORMAT .format (
251+ prompt_string = ( system_string := self .SYSTEM_FORMAT .format (
247252 system = self .config .system .format (** input_dict ), ** asdict (self .config .template )
248- )
249- prompt_string += self .TURN_BEGIN_FORMAT .format (
253+ ))
254+ prompt_string += ( user_string := self .TURN_BEGIN_FORMAT .format (
250255 user = self .build_user_message (input_dict ), ** asdict (self .config .template )
251- )
256+ ))
257+ user_strings = [user_string ]
258+ assistant_strings = []
252259 if generation :
253260 # Generation can be part of the input in cases such as reward models
254261 if continue_prefix_generation :
255262 # Append generation without the closing tag.
256- prompt_string += generation
263+ prompt_string += ( assistant_string := generation )
257264 else :
258- prompt_string += self .TURN_END_FORMAT .format (
265+ prompt_string += ( assistant_string := self .TURN_END_FORMAT .format (
259266 assistant = generation , ** asdict (self .config .template )
260- )
267+ ))
268+ assistant_strings .append (assistant_string )
269+
261270 else :
262- prompt_string = self .SYSTEM_FORMAT .format (
271+ prompt_string = ( system_string := self .SYSTEM_FORMAT .format (
263272 system = self .config .system .format (** input_dict ), ** asdict (self .config .template )
264- )
273+ ))
274+ user_strings = []
275+ assistant_strings = []
265276 for turn in input_dict [multi_turn_key ][:- 1 ]:
266- prompt_string += self .TURN_BEGIN_FORMAT .format (
277+ prompt_string += ( user_string := self .TURN_BEGIN_FORMAT .format (
267278 user = self .build_user_message (turn ), ** asdict (self .config .template )
268- )
269- prompt_string += self .TURN_END_FORMAT .format (
279+ ))
280+ user_strings .append (user_string )
281+ prompt_string += (assistant_string := self .TURN_END_FORMAT .format (
270282 assistant = turn ["assistant" ], ** asdict (self .config .template )
271- )
272- prompt_string += self .TURN_BEGIN_FORMAT .format (
283+ ))
284+ assistant_strings .append (assistant_string )
285+
286+ prompt_string += (user_string := self .TURN_BEGIN_FORMAT .format (
273287 user = self .build_user_message (input_dict [multi_turn_key ][- 1 ]), ** asdict (self .config .template )
274- )
288+ ))
289+ user_strings .append (user_string )
275290 prompt_string += generation
291+ if generation :
292+ assistant_strings .append (generation )
293+
294+ if return_templated_dict :
295+ messages = [
296+ {'role' : 'system' , 'content' : system_string },
297+ ]
298+
299+ for user_msg , assistant_msg in zip_longest (user_strings , assistant_strings , fillvalue = None ):
300+ if user_msg is not None :
301+ messages .append ({'role' : 'user' , 'content' : user_msg })
302+ if assistant_msg is not None :
303+ messages .append ({'role' : 'assistant' , 'content' : assistant_msg })
304+
305+ return messages
276306 return prompt_string
277307 else :
278308 if multi_turn_key is None :
0 commit comments