@@ -53,10 +53,12 @@ class GenerateSolutionsConfig:
5353
5454 input_file : str # Path to the input file with data
5555 output_file : str # Where to save the generations
56- prompt_config : str | None = None # How to format the data into prompts
56+ prompt_config : str | None = None # How to format the data into prompts
5757 prompt_template : str | None = None # not required for OpenAI server
58- prompt_format : str = "ns" # to specify the format of the prompt, "ns" for NeMo-Skills format or "openai" for OpenAI chat format
59- code_tags : str | None = None # required when using code execution
58+ # to specify the format of the prompt, "ns" for NeMo-Skills format or "openai" for OpenAI chat format
59+ prompt_format : str = "ns"
60+ system_message : str | None = None # can override the default system message in the config
61+ code_tags : str | None = None # required when using code execution
6062 examples_type : str | None = None # to be able to customize few-shot examples
6163
6264 # Inference server configuration {server_params}
@@ -150,10 +152,11 @@ def _post_init_validate_params(self):
150152 """Validate that certain parameters are restricted to certain values"""
151153 if self .prompt_format not in ["ns" , "openai" ]:
152154 raise ValueError (f"prompt_format must be either 'ns' or 'openai', got '{ self .prompt_format } '" )
153-
155+
154156 if self .prompt_format == "openai" :
155157 assert self .prompt_config is None , "prompt_config is not supported for prompt_format == 'openai'"
156158 assert self .prompt_template is None , "prompt_template is not supported for prompt_format == 'openai'"
159+ assert self .system_message is None , "system_message is not supported for prompt_format == 'openai'"
157160 else :
158161 assert self .prompt_config is not None , "prompt_config is required when prompt_format == 'ns'"
159162 for param , default_value in self ._get_disallowed_params ():
@@ -241,8 +244,7 @@ def __init__(self, cfg: GenerateSolutionsConfig):
241244 )
242245
243246 def setup_llm (self ):
244- if (self .cfg .prompt_template is None
245- and self .cfg .server ["server_type" ] not in ["openai" , "vllm" , "sglang" ]):
247+ if self .cfg .prompt_template is None and self .cfg .server ["server_type" ] not in ["openai" , "vllm" , "sglang" ]:
246248 with open_dict (self .cfg .server ):
247249 self .cfg .server ["server_type" ] = "openai"
248250 self .cfg .server ["model" ] = "model"
@@ -261,19 +263,23 @@ def setup_prompt(self):
261263
262264 if self .cfg .prompt_format == "openai" :
263265 return None
264-
265- prompt = get_prompt (self .cfg .prompt_config , self .cfg .prompt_template , self .cfg .code_tags , examples_type = self .cfg .examples_type )
266+
267+ prompt = get_prompt (
268+ self .cfg .prompt_config , self .cfg .prompt_template , self .cfg .code_tags , examples_type = self .cfg .examples_type
269+ )
270+ if self .cfg .system_message is not None :
271+ prompt .config .system = self .cfg .system_message
266272 LOG .info ("Prompt used: %s" , prompt )
267273 return prompt
268274
269275 def log_example_prompt (self , data ):
270276 data_point = deepcopy (data [0 ])
271277
272278 if self .cfg .prompt_format == "openai" :
273- #print the prompt in openai format
279+ # print the prompt in openai format
274280 LOG .info ("Example prompt in OpenAI format: \n Data dictionary: %s" , data_point )
275281 return
276-
282+
277283 if self .cfg .multi_turn_key is None :
278284 LOG .info (
279285 "Example prompt:\n Data dictionary: %s\n Prompt: %s" , data_point , self .fill_prompt (data_point , data )
@@ -374,7 +380,7 @@ def fill_prompt(self, data_point, data):
374380 """Passing in full data in case it's needed to fill the prompt in subclasses."""
375381 if self .cfg .prompt_format == "openai" :
376382 return data_point ["messages" ]
377-
383+
378384 total_code_executions_in_prompt = self .cfg .total_code_executions_in_prompt
379385 if total_code_executions_in_prompt is not None :
380386 if isinstance (total_code_executions_in_prompt , (list , tuple )):
@@ -394,8 +400,7 @@ def llm_generate(self, data_points, data, is_async=False):
394400 generation_params = {
395401 "prompts" : [self .fill_prompt (dp , data ) for dp in data_points ],
396402 "stop_phrases" : combine_stop_phrases (
397- self .prompt .stop_phrases if self .prompt is not None else None ,
398- self .extra_stop_phrases
403+ self .prompt .stop_phrases if self .prompt is not None else None , self .extra_stop_phrases
399404 ),
400405 ** asdict (self .cfg .inference ),
401406 ** self .extra_generate_params ,
0 commit comments