Commit 2b7fd92

i-vainn and Kipok authored
Azure OpenAI API and code execution (#511)
Signed-off-by: Igor Gitman <igitman@nvidia.com>
Co-authored-by: Igor Gitman <igitman@nvidia.com>
1 parent b14e9c7 commit 2b7fd92

35 files changed

Lines changed: 360 additions & 197 deletions

.github/workflows/gpu_tests.yml

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,7 @@ jobs:
         run: |
           cd ${{ github.run_id }}
           nvidia-smi
+          export DOCKER_CLIENT_TIMEOUT=120
           set -o pipefail # this will make sure next line returns non-0 exit code if tests fail
           ./tests/gpu-tests/run_llama.sh
       - name: Cleanup
@@ -86,6 +87,7 @@ jobs:
         run: |
           cd ${{ github.run_id }}
           nvidia-smi
+          export DOCKER_CLIENT_TIMEOUT=120
           set -o pipefail # this will make sure next line returns non-0 exit code if tests fail
           ./tests/gpu-tests/run_qwen.sh
       - name: Cleanup
@@ -122,6 +124,7 @@ jobs:
         run: |
           cd ${{ github.run_id }}
           nvidia-smi
+          export DOCKER_CLIENT_TIMEOUT=120
           set -o pipefail # this will make sure next line returns non-0 exit code if tests fail
           ./tests/gpu-tests/run_rm.sh
       - name: Cleanup

docs/basics/inference.md

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai

 sandbox = get_sandbox() # localhost by default
 llm = get_code_execution_model(server_type="vllm", sandbox=sandbox)
-prompt = get_prompt('generic/default', 'llama3-instruct') # (1)!
+prompt = get_prompt('generic/default', 'llama3-instruct', code_tags='llama3') # (1)!
 prompt.config.system = ( # (2)!
     "Environment: ipython\n\n"
     "Use Python to solve this math problem."

docs/basics/prompt-format.md

Lines changed: 35 additions & 12 deletions
@@ -1,11 +1,16 @@
 # Prompt utilities

-Our prompts are configured via two input yaml files: prompt template and prompt config.
+Our prompts are configured via three yaml files:
+
+1. **Prompt template** - defines model-specific chat format and special tokens
+2. **Prompt config** - contains the actual prompt text with placeholders
+3. **Code tags** - specifies code formatting tokens, required for code execution
+

 ## Prompt template

 The template file defines model-specific special tokens, e.g. bos, turn tokens,
-user/assistant/system message, special tokens for code execution, etc. All of the
+user/assistant/system message, etc. All of the
 templates that we support by default are available in
 [nemo_skills/prompt/template](https://github.com/NVIDIA/NeMo-Skills/tree/main/nemo_skills/prompt/template)
 folder. Here is an example template for
@@ -34,13 +39,6 @@ assistant_begin: "<|start_header_id|>assistant<|end_header_id|>\n\n"
 assistant_end: "<|eot_id|>"

 stop_phrases: ["<|eot_id|>"]
-
-# used to execute code within these tags
-code_begin: '<|python_tag|>'
-code_end: '<|eom_id|>'
-# used to extract the code output
-code_output_begin: '<|start_header_id|>ipython<|end_header_id|>'
-code_output_end: '<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
 ```

 You can specify a particular template with `++prompt_template=...`. If you don't add a .yaml extension (e.g.
@@ -96,22 +94,47 @@ prompt the `gsm8k_standard_few_shot` examples from
 [here](https://github.com/NVIDIA/NeMo-Skills/tree/main/nemo_skills/prompt/few_shot_examples/gsm8k.py) are used.


+## Code tags
+
+Code tags define the special tokens that models use to mark executable code blocks and their output. Code tags are required when using code execution.
+All code tags that we support by default are available in
+[nemo_skills/prompt/code_tags](https://github.com/NVIDIA/NeMo-Skills/tree/main/nemo_skills/prompt/code_tags).
+
+Here is an example code tags file for the [llama3](https://github.com/NVIDIA/NeMo-Skills/tree/main/nemo_skills/prompt/code_tags/llama3.yaml) family:
+
+```yaml
+# Code tags for llama3 family models
+
+# used to execute code within these tags
+code_begin: "<|python_tag|>"
+code_end: "<|eom_id|>"
+
+# used to extract the code output
+code_output_begin: "<|start_header_id|>ipython<|end_header_id|>"
+code_output_end: "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+
+# how to post-process the captured output (choices: llama, qwen)
+code_output_format: "llama"
+```
+
 ## Prompt API

-If you're running one of the pipeline scripts, you can control the prompt by using
+If you're running one of the pipeline scripts, you can control the prompt by using:

 ```bash
 ++prompt_template=...
 ++prompt_config=...
+++code_tags=...
 ++examples_type=...
 ```

-If you're implementing a new script, you can use the following code to create a prompt and then use it
+If you're implementing a new script, you can use the following code to create a prompt and then use it:

 ```python
 from nemo_skills.prompt.utils import get_prompt

-prompt = get_prompt('generic/math', 'llama3-instruct')
+# The third parameter is optional and only needed for code execution
+prompt = get_prompt('generic/math', 'llama3-instruct', code_tags='llama3')
 print(prompt.fill({'problem': "What's 2 + 2?"}))
 ```
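To make the role of the four code tags concrete, here is a minimal, self-contained sketch of how such tags could be used to pull a code block out of a generation and to wrap the execution output. The tag values are copied from the llama3 code tags file in the diff above. The helper names `extract_last_code_block` and `wrap_code_output`, and the exact output layout, are hypothetical and are not the NeMo-Skills implementation.

```python
from typing import Optional

# Tag values taken from the llama3 code tags file shown in the diff.
LLAMA3_CODE_TAGS = {
    "code_begin": "<|python_tag|>",
    "code_end": "<|eom_id|>",
    "code_output_begin": "<|start_header_id|>ipython<|end_header_id|>",
    "code_output_end": "<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
}


def extract_last_code_block(generation: str, tags: dict) -> Optional[str]:
    """Hypothetical helper: return the code inside the last complete
    code_begin/code_end pair, or None if there is no such pair."""
    end = generation.rfind(tags["code_end"])
    if end == -1:
        return None
    begin = generation.rfind(tags["code_begin"], 0, end)
    if begin == -1:
        return None
    return generation[begin + len(tags["code_begin"]):end]


def wrap_code_output(output: str, tags: dict) -> str:
    """Hypothetical helper: surround captured output with the output tags.
    The blank lines between tag and output are an assumption."""
    return f"{tags['code_output_begin']}\n\n{output}{tags['code_output_end']}"


generation = "Let me compute.<|python_tag|>print(2 + 2)<|eom_id|>"
print(extract_last_code_block(generation, LLAMA3_CODE_TAGS))
print(wrap_code_output("4", LLAMA3_CODE_TAGS))
```

Keeping these tokens in their own file, as the commit does, means the same chat template (for example `qwen-instruct`) can be paired with different code tags (for example `openmath`) without duplicating the template.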

docs/openmathreasoning1/evaluation.md

Lines changed: 4 additions & 2 deletions
@@ -188,7 +188,8 @@ ns eval \
     --server_gpus=1 \
     --num_jobs=1 \
     --with_sandbox \
-    ++prompt_template=openmath-instruct \
+    ++code_tags=openmath \
+    ++prompt_template=qwen-instruct \
     ++prompt_config=openmath/tir \
     ++inference.tokens_to_generate=32768 \
     ++inference.temperature=0.6 \
@@ -210,7 +211,8 @@ ns eval \
     --server_gpus=1 \
     --num_jobs=1 \
     --with_sandbox \
-    ++prompt_template=openmath-instruct \
+    ++code_tags=openmath \
+    ++prompt_template=qwen-instruct \
     ++prompt_config=generic/math \
     ++inference.tokens_to_generate=32768 \
     ++inference.temperature=0.6 \

docs/openmathreasoning1/training.md

Lines changed: 6 additions & 2 deletions
@@ -33,13 +33,15 @@ for inference_mode in ["cot", "tir", "genselect"]:
     dataset[inference_mode] = dataset[inference_mode].rename_column("problem", "input")
     dataset[inference_mode] = dataset[inference_mode].rename_column("generated_solution", "output")

+    code_tags = None
     if inference_mode == 'cot':
         prompt_config = 'generic/math'
     if inference_mode == 'tir':
         prompt_config = 'openmath/tir'
+        code_tags = 'openmath'
     if inference_mode == 'genselect': # already formatted
         prompt_config = {'user': '{problem}'}
-    prompt = get_prompt(prompt_config, 'qwen-instruct')
+    prompt = get_prompt(prompt_config, 'qwen-instruct', code_tags)
     func = partial(apply_format, prompt=prompt, is_tir=(inference_mode == 'tir'))
     dataset[inference_mode] = dataset[inference_mode].map(func, num_proc=20)

@@ -275,15 +277,17 @@ for inference_mode in ["cot", "tir", "genselect"]:
     dataset[inference_mode] = dataset[inference_mode].rename_column("problem", "input")
     dataset[inference_mode] = dataset[inference_mode].rename_column("generated_solution", "output")

+    code_tags = None
     if inference_mode == 'cot':
         prompt_config = 'generic/math'
     if inference_mode == 'tir':
         prompt_config = 'openmath/tir'
+        code_tags = 'openmath'
     if inference_mode == 'genselect': # already formatted
         prompt_config = {'user': '{problem}'}
     func = partial(filter_func, inference_mode=inference_mode)
     dataset[inference_mode] = dataset[inference_mode].filter(func, num_proc=20)
-    prompt = get_prompt(prompt_config, 'qwen-instruct')
+    prompt = get_prompt(prompt_config, 'qwen-instruct', code_tags)
     func = partial(apply_format, prompt=prompt, is_tir=(inference_mode == 'tir'))
     dataset[inference_mode] = dataset[inference_mode].map(func, num_proc=20)

nemo_skills/inference/chat_interface/core.py

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ class AppConfig:
     # Prompt configuration
     base_prompt_config: str = "generic/math"
     code_prompt_config: str = "openmath/tir"
-    prompt_template: str = "openmath-instruct"
+    prompt_template: str = "qwen-instruct"
+    code_tags: str = "openmath"

     # Code-execution related
     initial_code_execution_state: bool = False

nemo_skills/inference/generate.py

Lines changed: 2 additions & 1 deletion
@@ -55,6 +55,7 @@ class GenerateSolutionsConfig:
     output_file: str # Where to save the generations
     prompt_config: str # How to format the data into prompts
     prompt_template: str | None = None # not required for OpenAI server
+    code_tags: str | None = None # required when using code execution
     examples_type: str | None = None # to be able to customize few-shot examples

     # Inference server configuration {server_params}
@@ -245,7 +246,7 @@ def setup_llm(self):
         return llm

     def setup_prompt(self):
-        prompt = get_prompt(self.cfg.prompt_config, self.cfg.prompt_template, examples_type=self.cfg.examples_type)
+        prompt = get_prompt(self.cfg.prompt_config, self.cfg.prompt_template, self.cfg.code_tags, examples_type=self.cfg.examples_type)
         LOG.info("Prompt used: %s", prompt)
         return prompt

nemo_skills/inference/server/code_execution_model.py

Lines changed: 59 additions & 14 deletions
@@ -61,7 +61,7 @@ def _is_generation_cancelled(self, gen_id):

     def _generate_single(
         self,
-        prompt: str,
+        prompt: str | list,
         code_begin: str,
         code_end: str,
         code_output_begin: str,
@@ -81,8 +81,9 @@ def _generate_single(
         max_code_executions: int | None = None, # if not None, will override self.config.max_code_executions
         stream: bool = False,
     ):
-        if not isinstance(prompt, str):
-            raise NotImplementedError("OpenAI API is not supported yet.")
+        # Handle OpenAI-style dictionary prompts
+        is_openai_format = not isinstance(prompt, str)
+
         if top_logprobs is not None: # TODO: add this
             raise NotImplementedError("top_logprobs is not supported yet.")

@@ -106,18 +107,20 @@ def _generate_single(
                 max_code_executions=max_code_executions,
             )

-        if stop_phrases is None:
-            stop_phrases = []
-
         effective_max_code_executions = self.config.max_code_executions
         if max_code_executions is not None:
             effective_max_code_executions = max_code_executions

         # making a copy of prompts to not corrupt original data
-        new_prompt = copy.deepcopy(prompt)
+        if is_openai_format:
+            new_prompt = copy.deepcopy(prompt)
+        else:
+            new_prompt = copy.deepcopy(prompt)

         start_time = int(time.time())

+        stop_phrases = stop_phrases or []
+
         request = {
             "prompt": new_prompt,
             "tokens_to_generate": tokens_to_generate,
@@ -176,7 +179,19 @@ def _generate_single(
             output, num_generated_tokens = output_dict['generation'], output_dict.get('num_generated_tokens', 0)
             # no need to do anything with this as the code below should just exit, so that's only for logging
             stopped_on_repetition = output_dict.get('stopped_on_repetition', False)
-            request['prompt'] += output
+
+            # openai don't show what stop word was triggered, so we assume that it was `code_end`
+            # if there's an unfinished code block
+            if is_openai_format and output_dict.get('finish_reason') == 'stop':
+                if output.count(code_end) + 1 == output.count(code_begin):
+                    output += code_end
+            # Update the prompt based on format
+            if is_openai_format:
+                request['prompt'].append({'role': 'assistant', 'content': output})
+                request['prompt'].append({'role': 'user', 'content': "continue"})
+            else:
+                request['prompt'] += output
+
             # if it's the extra iteration, we don't execute the code block and just finish

             if generation_index == effective_max_code_executions:
@@ -204,17 +219,28 @@ def _generate_single(
                 if self.config.add_remaining_code_executions:
                     remaining_code_executions = effective_max_code_executions - generation_index - 1
                 # adding code output to the prompt
-                request['prompt'] += format_code_output(
+                code_output = format_code_output(
                     execution_dict, code_output_begin, code_output_end, code_output_format, remaining_code_executions
                 )
+
+                if is_openai_format:
+                    request['prompt'][-2]['content'] += code_output
+                else:
+                    request['prompt'] += code_output
+
                 code_execution_time += int(time.time() - code_execution_time_start)
                 code_rounds_executed += 1
             else: # if no code was generated, we need to finish
                 break

-        # removing original prompt
+        # removing original prompt and returning the generation
+        if is_openai_format:
+            generation = "\n".join(msg['content'] for msg in request['prompt'] if msg['role'] == 'assistant')
+        else:
+            generation = request['prompt'][len(prompt):]
+
         return {
-            'generation': request['prompt'][len(prompt) :],
+            'generation': generation,
             'code_rounds_executed': code_rounds_executed,
             'num_generated_tokens': total_num_generated_tokens,
             'generation_time': generation_time,
@@ -433,6 +459,9 @@ def _stream_single(
         """
         Helper method, that implements streaming generation.
         """
+        # Handle OpenAI-style dictionary prompts
+        is_openai_format = not isinstance(prompt, str)
+
         effective_max_code_executions = self.config.max_code_executions
         if max_code_executions is not None:
             effective_max_code_executions = max_code_executions
@@ -452,7 +481,7 @@ def _stream_single(
             'stream': True,
         }

-        current_full_prompt = prompt
+        current_full_prompt = copy.deepcopy(prompt)
         session_id = None # For sandbox state continuity
         for generation_index in range(effective_max_code_executions + 1):
             model_token_iterator = self.model._generate_single(prompt=current_full_prompt, **request)
@@ -470,7 +499,18 @@ def _stream_single(
             if not current_output_segment:
                 break

-            current_full_prompt += current_output_segment
+            # openai don't show what stop word was triggered, so we assume that it was `code_end`
+            # if there's an unfinished code block
+            if is_openai_format and chunk.get('finish_reason') == 'stop':
+                if current_output_segment.count(code_end) + 1 == current_output_segment.count(code_begin):
+                    current_output_segment += code_end
+
+            # Update the prompt based on format
+            if is_openai_format:
+                current_full_prompt.append({'role': 'assistant', 'content': current_output_segment})
+                current_full_prompt.append({'role': 'user', 'content': "continue"})
+            else:
+                current_full_prompt += current_output_segment

             if generation_index == effective_max_code_executions:
                 # This was the last iteration, intended for final text generation after all code executions.
@@ -496,7 +536,12 @@ def _stream_single(
                 )

                 yield {'generation': formatted_code_output} # Yield the entire formatted code output as one chunk
-                current_full_prompt += formatted_code_output # Append executed code's output to the prompt
+
+                # Append executed code's output to the prompt
+                if is_openai_format:
+                    current_full_prompt[-2]['content'] += formatted_code_output
+                else:
+                    current_full_prompt += formatted_code_output
             else:
                 break
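The message-list bookkeeping this diff adds for OpenAI-style prompts can be sketched in isolation as follows. The four steps mirror the changed lines: restore a `code_end` that the server may have consumed as a stop phrase, append an assistant turn plus a "continue" user turn, attach the code output to the assistant turn at index `[-2]`, and join all assistant turns into the final generation. The function names and the `<code>`/`<output>` tags are hypothetical placeholders, and no model or sandbox is called here.

```python
def close_unfinished_code_block(output, finish_reason, code_begin, code_end):
    """If generation stopped on a stop phrase and exactly one code block is
    still open, assume the stop phrase was code_end and put it back."""
    if finish_reason == 'stop' and output.count(code_end) + 1 == output.count(code_begin):
        return output + code_end
    return output


def append_model_turn(messages, output):
    """Record the model output, then add a 'continue' user turn for the next request."""
    messages.append({'role': 'assistant', 'content': output})
    messages.append({'role': 'user', 'content': "continue"})


def attach_code_output(messages, code_output):
    """messages[-1] is the 'continue' turn, so messages[-2] is the assistant
    turn that produced the code; the execution output is appended to it."""
    messages[-2]['content'] += code_output


def collect_generation(messages):
    """Join every assistant turn; the original user prompt is excluded by role."""
    return "\n".join(m['content'] for m in messages if m['role'] == 'assistant')


# Placeholder tags and hand-written outputs, for illustration only.
messages = [{'role': 'user', 'content': 'What is 2 + 2?'}]
raw = "Let me check.<code>print(2 + 2)"  # code_end was swallowed as a stop phrase
output = close_unfinished_code_block(raw, 'stop', "<code>", "</code>")
append_model_turn(messages, output)
attach_code_output(messages, "<output>4</output>")
append_model_turn(messages, "The answer is 4.")
print(collect_generation(messages))
```

Note that the count check only restores `code_end` when exactly one block is unclosed. If generation stopped on a different stop phrase while a code block was open, the same branch would still fire, which is the assumption the in-code comment in the diff states.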
