Skip to content

Commit 47e13f8

Browse files
committed
use recipe import system
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent c9098b6 commit 47e13f8

8 files changed

Lines changed: 213 additions & 89 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Default DFlashConfig values for DFlash training. Imported into the `dflash:`
5+
# section of recipes. ``dflash_mask_token_id`` is intentionally omitted; per-model
6+
# recipes should provide it explicitly, and main.py falls back to
7+
# ``tokenizer.mask_token_id`` when the recipe leaves it unset.
8+
9+
# modelopt-schema: modelopt.torch.speculative.config.DFlashConfig
10+
dflash_block_size: 8
11+
dflash_num_anchors: 512
12+
dflash_use_torch_compile: false
13+
dflash_self_logit_distillation: true
14+
dflash_loss_decay_factor: 4.0
15+
dflash_architecture_config:
16+
num_hidden_layers: 5
17+
# mask_token_id: auto-detected from model vocab (override for specific models)
18+
# sliding_window and layer_types are inherited from base model config automatically
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Default `training:` section values for DFlash training. Imported into the
5+
# `training:` section of recipes. HF trainer fields flow through SpecTrainingArgs
6+
# via ``extra='allow'`` and are re-validated by transformers.TrainingArguments
7+
# in main.py.
8+
9+
# modelopt-schema: modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments
10+
11+
# --- commonly modified ---
12+
output_dir:
13+
num_train_epochs: 10
14+
per_device_train_batch_size: 1
15+
learning_rate: 6.0e-4
16+
warmup_steps: 100
17+
training_seq_len: 4096
18+
logging_steps: 100
19+
save_steps: 5000
20+
cp_size: 1
21+
dp_shard_size: 1
22+
disable_tqdm: true
23+
estimate_ar: false
24+
ar_validate_steps: 0
25+
answer_only_loss: true
26+
27+
# --- rarely modified ---
28+
do_eval: false
29+
lr_scheduler_type: linear
30+
save_strategy: steps
31+
weight_decay: 0.0
32+
dataloader_drop_last: true
33+
bf16: true
34+
tf32: true
35+
remove_unused_columns: false
36+
ddp_find_unused_parameters: true
37+
ddp_timeout: 1800
38+
report_to: tensorboard
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Default EagleConfig values for EAGLE3 training. Imported into the `eagle:` section of recipes.
17+
# eagle_offline is derived from data.offline_data_path; do not set here.
18+
19+
# modelopt-schema: modelopt.torch.speculative.config.EagleConfig
20+
eagle_decoder_type: llama
21+
eagle_ttt_steps: 3
22+
eagle_mix_hidden_states: false
23+
eagle_use_torch_compile: true
24+
eagle_self_logit_distillation: true
25+
eagle_freeze_base_model: true
26+
eagle_loss_decay_factor: 0.9
27+
eagle_hidden_state_distillation: false
28+
eagle_reuse_base_decoder: false
29+
eagle_report_acc: true
30+
eagle_enable_nvtx: false
31+
# Rope scaling is disabled during training (default_config.py uses rope_type=default);
32+
# YaRN scaling is injected at export time for long-context inference.
33+
eagle_export_rope_scaling:
34+
rope_type: yarn
35+
factor: 32.0
36+
original_max_position_embeddings: 2048
37+
# Values set here override the defaults in modelopt/torch/speculative/eagle/default_config.py
38+
eagle_architecture_config: {}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Default `training:` section values for EAGLE3 training. Imported into the
5+
# `training:` section of recipes. HF trainer fields flow through SpecTrainingArgs
6+
# via ``extra='allow'`` and are re-validated by transformers.TrainingArguments
7+
# in main.py.
8+
9+
# modelopt-schema: modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments
10+
11+
# --- commonly modified ---
12+
output_dir:
13+
num_train_epochs: 1
14+
per_device_train_batch_size: 1
15+
learning_rate: 1.0e-4
16+
warmup_steps: 1000
17+
training_seq_len: 2048
18+
logging_steps: 100
19+
save_steps: 8192
20+
cp_size: 1
21+
disable_tqdm: false
22+
estimate_ar: false
23+
ar_validate_steps: -1
24+
answer_only_loss: false
25+
26+
# --- rarely modified ---
27+
do_eval: false
28+
lr_scheduler_type: linear
29+
save_strategy: steps
30+
weight_decay: 0.0
31+
dataloader_drop_last: true
32+
bf16: true
33+
tf32: true
34+
remove_unused_columns: false
Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
# DFlash speculative-decoding training recipe. Override fields via OmegaConf dotlist on the CLI.
1+
# DFlash speculative-decoding training recipe. Override fields via OmegaConf dotlist on the CLI
2+
# or by writing a per-model recipe in modelopt_recipes/models/ that imports the same defaults.
23

34
metadata:
45
recipe_type: speculative_dflash
56
description: DFlash training recipe (model/data/training/dflash bundled).
67

7-
# maps to ModelArguments (main.py)
8+
imports:
9+
dflash_default: configs/speculative_decoding/dflash/default
10+
dflash_training_default: configs/speculative_decoding/dflash/training_default
11+
12+
# maps to ModelArguments
813
model:
914
model_name_or_path:
1015
trust_remote_code: false
1116
use_fake_base_for_offline: false
1217

13-
# maps to DataArguments (main.py)
18+
# maps to DataArguments
1419
data:
1520
data_path:
1621
offline_data_path:
@@ -19,45 +24,13 @@ data:
1924
# Templates are in modelopt_recipes/general/speculative_decoding/chat_templates/
2025
chat_template:
2126

22-
# maps to TrainingArguments (main.py)
27+
# maps to TrainingArguments
2328
training:
24-
# --- commonly modified ---
25-
output_dir:
26-
num_train_epochs: 10
27-
per_device_train_batch_size: 1
28-
learning_rate: 6.0e-4
29-
warmup_steps: 100
30-
training_seq_len: 4096
31-
logging_steps: 100
32-
save_steps: 5000
33-
cp_size: 1
34-
dp_shard_size: 1
35-
disable_tqdm: true
36-
estimate_ar: false
37-
ar_validate_steps: 0
38-
answer_only_loss: true
39-
40-
# --- rarely modified ---
41-
do_eval: false
42-
lr_scheduler_type: linear
43-
save_strategy: steps
44-
weight_decay: 0.0
45-
dataloader_drop_last: true
46-
bf16: true
47-
tf32: true
48-
remove_unused_columns: false
49-
ddp_find_unused_parameters: true
50-
ddp_timeout: 1800
51-
report_to: tensorboard
29+
$import: dflash_training_default
5230

5331
# maps to DFlashConfig (modelopt/torch/speculative/config.py).
32+
# Per-model recipes should also set ``dflash_mask_token_id``; otherwise main.py
33+
# falls back to ``tokenizer.mask_token_id``, and DFlashConfig raises if neither
34+
# source provides one.
5435
dflash:
55-
dflash_block_size: 8
56-
dflash_num_anchors: 512
57-
dflash_use_torch_compile: false
58-
dflash_self_logit_distillation: true
59-
dflash_loss_decay_factor: 4.0
60-
dflash_architecture_config:
61-
num_hidden_layers: 5
62-
# mask_token_id: auto-detected from model vocab (override for specific models)
63-
# sliding_window and layer_types are inherited from base model config automatically
36+
$import: dflash_default
Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,32 @@
1-
# EAGLE3 speculative-decoding training recipe. Override fields via OmegaConf dotlist on the CLI.
1+
# EAGLE3 speculative-decoding training recipe. Override fields via OmegaConf dotlist on the CLI
2+
# or by writing a per-model recipe in modelopt_recipes/models/ that imports the same defaults.
23

34
metadata:
45
recipe_type: speculative_eagle
56
description: EAGLE3 training recipe (model/data/training/eagle bundled).
67

7-
# maps to ModelArguments (main.py)
8+
imports:
9+
eagle_default: configs/speculative_decoding/eagle/default
10+
eagle_training_default: configs/speculative_decoding/eagle/training_default
11+
12+
# maps to ModelArguments
813
model:
914
model_name_or_path:
1015
trust_remote_code: false
1116
use_fake_base_for_offline: false
1217

13-
# maps to DataArguments (main.py)
18+
# maps to DataArguments
1419
data:
1520
data_path: input_conversations/train.jsonl
1621
offline_data_path:
1722
draft_vocab_cache:
1823
vlm_img_dir:
1924
vlm_processor:
2025

21-
# maps to TrainingArguments (main.py)
26+
# maps to TrainingArguments
2227
training:
23-
# --- commonly modified ---
24-
output_dir:
25-
num_train_epochs: 1
26-
per_device_train_batch_size: 1
27-
learning_rate: 1.0e-4
28-
warmup_steps: 1000
29-
training_seq_len: 2048
30-
logging_steps: 100
31-
save_steps: 8192
32-
cp_size: 1
33-
disable_tqdm: false
34-
estimate_ar: false
35-
ar_validate_steps: -1
36-
answer_only_loss: false
37-
38-
# --- rarely modified ---
39-
do_eval: false
40-
lr_scheduler_type: linear
41-
save_strategy: steps
42-
weight_decay: 0.0
43-
dataloader_drop_last: true
44-
bf16: true
45-
tf32: true
46-
remove_unused_columns: false
28+
$import: eagle_training_default
4729

4830
# maps to EagleConfig (modelopt/torch/speculative/config.py).
4931
eagle:
50-
# eagle_offline is derived from data.offline_data_path; do not set here.
51-
eagle_decoder_type: llama
52-
eagle_ttt_steps: 3
53-
eagle_mix_hidden_states: false
54-
eagle_use_torch_compile: true
55-
eagle_self_logit_distillation: true
56-
eagle_freeze_base_model: true
57-
eagle_loss_decay_factor: 0.9
58-
eagle_hidden_state_distillation: false
59-
eagle_reuse_base_decoder: false
60-
eagle_report_acc: true
61-
eagle_enable_nvtx: false
62-
# Rope scaling: disable during training (default_config.py uses rope_type=default),
63-
# inject YaRN during export for long-context inference.
64-
eagle_export_rope_scaling:
65-
rope_type: yarn
66-
factor: 32.0
67-
original_max_position_embeddings: 2048
68-
# overwrite to modelopt/torch/speculative/eagle/default_config.py
69-
eagle_architecture_config: {}
32+
$import: eagle_default
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Per-model DFlash offline training recipe for Kimi-K2.5.
5+
6+
metadata:
7+
recipe_type: speculative_dflash
8+
description: DFlash offline training recipe for Kimi-K2.5.
9+
10+
imports:
11+
dflash_default: configs/speculative_decoding/dflash/default
12+
dflash_training_default: configs/speculative_decoding/dflash/training_default
13+
14+
model:
15+
model_name_or_path: moonshotai/Kimi-K2.5
16+
trust_remote_code: true
17+
use_fake_base_for_offline: true
18+
19+
data:
20+
offline_data_path: <path to offline data>
21+
22+
training:
23+
$import: dflash_training_default
24+
output_dir: ckpts/kimi-k25-dflash
25+
26+
dflash:
27+
$import: dflash_default
28+
# If unset, main.py falls back to tokenizer.mask_token_id; DFlashConfig
29+
# raises if neither this field nor the tokenizer provides one.
30+
# dflash_mask_token_id:
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Per-model EAGLE3 offline training recipe for Kimi-K2.5.
5+
# Mirrors examples/speculative_decoding/scripts/train_kimi_k25_offline.sh.
6+
7+
metadata:
8+
recipe_type: speculative_eagle
9+
description: EAGLE3 offline training recipe for Kimi-K2.5.
10+
11+
imports:
12+
eagle_default: configs/speculative_decoding/eagle/default
13+
eagle_training_default: configs/speculative_decoding/eagle/training_default
14+
15+
model:
16+
model_name_or_path: moonshotai/Kimi-K2.5
17+
trust_remote_code: true
18+
use_fake_base_for_offline: true
19+
20+
data:
21+
offline_data_path: <path to offline data>
22+
23+
training:
24+
$import: eagle_training_default
25+
output_dir: ckpts/kimi-k25-eagle3
26+
training_seq_len: 4096
27+
28+
eagle:
29+
$import: eagle_default
30+
eagle_decoder_type: kimik2

0 commit comments

Comments
 (0)