Skip to content

Commit 79ab125

Browse files
gwarmstrongKipok
andauthored
Update nemo rl grpo templating (#501)
Co-authored-by: Igor Gitman <igitman@nvidia.com>
1 parent 3f2445d commit 79ab125

7 files changed

Lines changed: 271 additions & 210 deletions

File tree

dockerfiles/Dockerfile.nemo-rl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,5 @@ ENV NVIDIA_BUILD_ID=${NVIDIA_BUILD_ID:-<unknown>}
5858
ENV NVIDIA_BUILD_REF=${NVIDIA_BUILD_REF:-<unknown>}
5959
LABEL com.nvidia.build.id="${NVIDIA_BUILD_ID}"
6060
LABEL com.nvidia.build.ref="${NVIDIA_BUILD_REF}"
61+
62+
RUN git clone https://github.com/NVIDIA/NeMo-Skills.git /opt/NeMo-Skills && cd /opt/NeMo-Skills && uv pip install .

nemo_skills/prompt/utils.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from dataclasses import asdict, field
2020
from pathlib import Path
2121
from typing import Any, Dict, List, Optional
22+
from itertools import zip_longest
2223

2324
import yaml
2425

@@ -217,6 +218,7 @@ def fill(
217218
prefix_generation_to_response: bool = False,
218219
continue_prefix_generation: bool = False,
219220
multi_turn_key: str | None = None,
221+
return_templated_dict: bool = False,
220222
) -> str | List[dict]:
221223
"""
222224
Fills the prompt with the input_dict.
@@ -230,6 +232,9 @@ def fill(
230232
multi_turn_key: If specified, will read the list from input_dict[multi_turn_key]
231233
and use it to construct the prompt. You input_dict should also have "assistant" key in all
232234
turns except last containing assistant reply.
235+
return_templated_dict: Indicates whether to return a messages list where the template is used
236+
to fill the prompt. If so, a list of dicts with 'role' and 'content' keys will be returned.
237+
In this case the final user and assistant messages will include special tokens.
233238
234239
Returns:
235240
The filled prompt - either a string or a list of dictionaries.
@@ -243,36 +248,61 @@ def fill(
243248

244249
if self.config.template:
245250
if multi_turn_key is None:
246-
prompt_string = self.SYSTEM_FORMAT.format(
251+
prompt_string = (system_string := self.SYSTEM_FORMAT.format(
247252
system=self.config.system.format(**input_dict), **asdict(self.config.template)
248-
)
249-
prompt_string += self.TURN_BEGIN_FORMAT.format(
253+
))
254+
prompt_string += (user_string := self.TURN_BEGIN_FORMAT.format(
250255
user=self.build_user_message(input_dict), **asdict(self.config.template)
251-
)
256+
))
257+
user_strings = [user_string]
258+
assistant_strings = []
252259
if generation:
253260
# Generation can be part of the input in cases such as reward models
254261
if continue_prefix_generation:
255262
# Append generation without the closing tag.
256-
prompt_string += generation
263+
prompt_string += (assistant_string := generation)
257264
else:
258-
prompt_string += self.TURN_END_FORMAT.format(
265+
prompt_string += (assistant_string := self.TURN_END_FORMAT.format(
259266
assistant=generation, **asdict(self.config.template)
260-
)
267+
))
268+
assistant_strings.append(assistant_string)
269+
261270
else:
262-
prompt_string = self.SYSTEM_FORMAT.format(
271+
prompt_string = (system_string := self.SYSTEM_FORMAT.format(
263272
system=self.config.system.format(**input_dict), **asdict(self.config.template)
264-
)
273+
))
274+
user_strings = []
275+
assistant_strings = []
265276
for turn in input_dict[multi_turn_key][:-1]:
266-
prompt_string += self.TURN_BEGIN_FORMAT.format(
277+
prompt_string += (user_string := self.TURN_BEGIN_FORMAT.format(
267278
user=self.build_user_message(turn), **asdict(self.config.template)
268-
)
269-
prompt_string += self.TURN_END_FORMAT.format(
279+
))
280+
user_strings.append(user_string)
281+
prompt_string += (assistant_string := self.TURN_END_FORMAT.format(
270282
assistant=turn["assistant"], **asdict(self.config.template)
271-
)
272-
prompt_string += self.TURN_BEGIN_FORMAT.format(
283+
))
284+
assistant_strings.append(assistant_string)
285+
286+
prompt_string += (user_string := self.TURN_BEGIN_FORMAT.format(
273287
user=self.build_user_message(input_dict[multi_turn_key][-1]), **asdict(self.config.template)
274-
)
288+
))
289+
user_strings.append(user_string)
275290
prompt_string += generation
291+
if generation:
292+
assistant_strings.append(generation)
293+
294+
if return_templated_dict:
295+
messages = [
296+
{'role': 'system', 'content': system_string},
297+
]
298+
299+
for user_msg, assistant_msg in zip_longest(user_strings, assistant_strings, fillvalue=None):
300+
if user_msg is not None:
301+
messages.append({'role': 'user', 'content': user_msg})
302+
if assistant_msg is not None:
303+
messages.append({'role': 'assistant', 'content': assistant_msg})
304+
305+
return messages
276306
return prompt_string
277307
else:
278308
if multi_turn_key is None:

nemo_skills/training/nemo_rl/configs/grpo.yaml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,14 @@ policy:
110110

111111
data:
112112
max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
113-
prompt_file: "/nemo_run/code/nemo_skills/training/nemo_rl/prompts/math.txt"
114-
system_prompt_file: null
115-
dataset_name: "OpenMathInstruct-2"
113+
prompt:
114+
prompt_config: ???
115+
prompt_template: ???
116+
examples_type: null
117+
config_dir: null
118+
template_dir: null
119+
train_data_path: ???
120+
val_data_path: ???
116121

117122
env:
118123
math:

0 commit comments

Comments
 (0)