Skip to content

Commit f72780c

Browse files
authored
Customize system message + remove thinking from evaluation by default (#523)
Signed-off-by: Igor Gitman <igitman@nvidia.com>
1 parent 074f267 commit f72780c

3 files changed

Lines changed: 46 additions & 15 deletions

File tree

nemo_skills/evaluation/evaluate_results.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import logging
1617
import sys
1718
from dataclasses import field
@@ -20,7 +21,7 @@
2021
import hydra
2122

2223
from nemo_skills.evaluation.evaluator import evaluate
23-
from nemo_skills.utils import get_help_message, get_logger_name, nested_dataclass, setup_logging
24+
from nemo_skills.utils import get_help_message, get_logger_name, nested_dataclass, setup_logging, unroll_files
2425

2526
LOG = logging.getLogger(get_logger_name(__file__))
2627

@@ -39,12 +40,18 @@ class EvaluateResultsConfig:
3940
# check graders.py for the supported eval types and their parameters
4041
eval_config: dict = field(default_factory=dict)
4142

43+
# TODO: move lean-specific parameters to inner config
4244
# the escape phrase prior to a lean4 block to extract
4345
final_answer_key: str = field(default="### Final Answer")
44-
4546
# whether to restate the formal statement when constructing the final output proof
4647
restate_formal_statement: bool = True
4748

49+
# whether to remove the thinking part from the final output
50+
remove_thinking: bool = True
51+
52+
# thinking separator
53+
thinking_separator: str = "</think>"
54+
4855
def __post_init__(self):
4956
if isinstance(self.input_files, str):
5057
self.input_files = self.input_files.split(" ")
@@ -58,6 +65,24 @@ def __post_init__(self):
5865
def evaluate_results(cfg: EvaluateResultsConfig):
    """Optionally strip the thinking part from generations, then run evaluation.

    When ``cfg.remove_thinking`` is true, every file yielded by
    ``unroll_files(cfg.input_files)`` is rewritten so that each sample's
    "generation" only keeps the text after the last ``cfg.thinking_separator``.
    The untouched text is preserved under "_full_generation" and
    "_has_think_tags" records whether the separator was found.

    Args:
        cfg: mapping of config values used to build ``EvaluateResultsConfig``.
    """
    cfg = EvaluateResultsConfig(_init_nested=True, **cfg)
    LOG.info("Config used: %s", cfg)

    if cfg.remove_thinking:
        # lazy %-formatting: the message is only rendered if the record is emitted
        LOG.info(
            'Removing the thinking part from the "generation" key (splitting on %s). '
            'Original content will be stored in "_full_generation" key.',
            cfg.thinking_separator,
        )
        for jsonl_file in unroll_files(cfg.input_files):
            _remove_thinking_from_file(jsonl_file, cfg.thinking_separator)

    evaluate(cfg)


def _strip_thinking(sample, separator):
    """Strip the thinking part from ``sample["generation"]`` in place and return the sample."""
    generation = sample["generation"]
    # The flag has to be computed BEFORE the split: the text after the last
    # separator can never contain the separator, so checking afterwards is always False.
    has_think_tags = separator in generation
    if has_think_tags:
        sample["_full_generation"] = generation
        sample["generation"] = generation.split(separator)[-1].strip()
    else:
        # The file may already have been processed by an earlier run; in that
        # case the preserved original tells us whether tags were present.
        has_think_tags = separator in sample.get("_full_generation", "")
    sample["_has_think_tags"] = has_think_tags
    return sample


def _remove_thinking_from_file(jsonl_file, separator):
    """Rewrite one jsonl file with the thinking part removed from every sample."""
    import os  # local import so this block does not depend on file-level import changes

    with open(jsonl_file, encoding="utf-8") as fin:
        # tolerate blank lines (e.g. a trailing newline at the end of the file)
        samples = [json.loads(line) for line in fin if line.strip()]

    # Transform everything in memory first so that a bad sample cannot leave
    # a half-written file behind.
    output_lines = [json.dumps(_strip_thinking(sample, separator)) + "\n" for sample in samples]

    # Write to a sibling temp file and swap it in, instead of truncating the
    # input up front; an interrupted write then keeps the original data intact.
    tmp_file = f"{jsonl_file}.tmp"
    with open(tmp_file, "wt", encoding="utf-8") as fout:
        fout.writelines(output_lines)
    os.replace(tmp_file, jsonl_file)
6287

6388

nemo_skills/inference/generate.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ class GenerateSolutionsConfig:
5353

5454
input_file: str # Path to the input file with data
5555
output_file: str # Where to save the generations
56-
prompt_config: str | None = None # How to format the data into prompts
56+
prompt_config: str | None = None # How to format the data into prompts
5757
prompt_template: str | None = None # not required for OpenAI server
58-
prompt_format: str = "ns" # to specify the format of the prompt, "ns" for NeMo-Skills format or "openai" for OpenAI chat format
59-
code_tags: str | None = None # required when using code execution
58+
# to specify the format of the prompt, "ns" for NeMo-Skills format or "openai" for OpenAI chat format
59+
prompt_format: str = "ns"
60+
system_message: str | None = None # can override the default system message in the config
61+
code_tags: str | None = None # required when using code execution
6062
examples_type: str | None = None # to be able to customize few-shot examples
6163

6264
# Inference server configuration {server_params}
@@ -150,10 +152,11 @@ def _post_init_validate_params(self):
150152
"""Validate that certain parameters are restricted to certain values"""
151153
if self.prompt_format not in ["ns", "openai"]:
152154
raise ValueError(f"prompt_format must be either 'ns' or 'openai', got '{self.prompt_format}'")
153-
155+
154156
if self.prompt_format == "openai":
155157
assert self.prompt_config is None, "prompt_config is not supported for prompt_format == 'openai'"
156158
assert self.prompt_template is None, "prompt_template is not supported for prompt_format == 'openai'"
159+
assert self.system_message is None, "system_message is not supported for prompt_format == 'openai'"
157160
else:
158161
assert self.prompt_config is not None, "prompt_config is required when prompt_format == 'ns'"
159162
for param, default_value in self._get_disallowed_params():
@@ -241,8 +244,7 @@ def __init__(self, cfg: GenerateSolutionsConfig):
241244
)
242245

243246
def setup_llm(self):
244-
if (self.cfg.prompt_template is None
245-
and self.cfg.server["server_type"] not in ["openai", "vllm", "sglang"]):
247+
if self.cfg.prompt_template is None and self.cfg.server["server_type"] not in ["openai", "vllm", "sglang"]:
246248
with open_dict(self.cfg.server):
247249
self.cfg.server["server_type"] = "openai"
248250
self.cfg.server["model"] = "model"
@@ -261,19 +263,23 @@ def setup_prompt(self):
261263

262264
if self.cfg.prompt_format == "openai":
263265
return None
264-
265-
prompt = get_prompt(self.cfg.prompt_config, self.cfg.prompt_template, self.cfg.code_tags, examples_type=self.cfg.examples_type)
266+
267+
prompt = get_prompt(
268+
self.cfg.prompt_config, self.cfg.prompt_template, self.cfg.code_tags, examples_type=self.cfg.examples_type
269+
)
270+
if self.cfg.system_message is not None:
271+
prompt.config.system = self.cfg.system_message
266272
LOG.info("Prompt used: %s", prompt)
267273
return prompt
268274

269275
def log_example_prompt(self, data):
270276
data_point = deepcopy(data[0])
271277

272278
if self.cfg.prompt_format == "openai":
273-
#print the prompt in openai format
279+
# print the prompt in openai format
274280
LOG.info("Example prompt in OpenAI format: \nData dictionary: %s", data_point)
275281
return
276-
282+
277283
if self.cfg.multi_turn_key is None:
278284
LOG.info(
279285
"Example prompt:\nData dictionary: %s\nPrompt: %s", data_point, self.fill_prompt(data_point, data)
@@ -374,7 +380,7 @@ def fill_prompt(self, data_point, data):
374380
"""Passing in full data in case it's needed to fill the prompt in subclasses."""
375381
if self.cfg.prompt_format == "openai":
376382
return data_point["messages"]
377-
383+
378384
total_code_executions_in_prompt = self.cfg.total_code_executions_in_prompt
379385
if total_code_executions_in_prompt is not None:
380386
if isinstance(total_code_executions_in_prompt, (list, tuple)):
@@ -394,8 +400,7 @@ def llm_generate(self, data_points, data, is_async=False):
394400
generation_params = {
395401
"prompts": [self.fill_prompt(dp, data) for dp in data_points],
396402
"stop_phrases": combine_stop_phrases(
397-
self.prompt.stop_phrases if self.prompt is not None else None,
398-
self.extra_stop_phrases
403+
self.prompt.stop_phrases if self.prompt is not None else None, self.extra_stop_phrases
399404
),
400405
**asdict(self.cfg.inference),
401406
**self.extra_generate_params,

tests/test_configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_error_on_extra_params():
5454
" ++eval_type=math "
5555
" ++eval_config.sandbox.sandbox_type=local "
5656
" ++eval_config.sandbox.sandbox_host=123 "
57+
" ++remove_thinking=false "
5758
)
5859
try:
5960
subprocess.run(cmd, shell=True, check=True, capture_output=True)

0 commit comments

Comments
 (0)