-
Notifications
You must be signed in to change notification settings - Fork 128
Expand file tree
/
Copy pathdefaults.yaml
More file actions
97 lines (82 loc) · 2.71 KB
/
defaults.yaml
File metadata and controls
97 lines (82 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Training config
model_tag: ??? # E.g., nvidia/esm2_t6_8M_UR50D, facebook/esm2_t6_8M_UR50D, or a local path (e.g ./example_8m_checkpoint)
num_train_steps: ???
# Initialize model weights on the meta device (fed into fully_shard_kwargs.init_model_with_meta_device below).
use_meta_device: true
# Whether to wrap the model in torch.compile. Note, this is currently not supported with mfsdp (BIONEMO-2977).
# We leave this off by default since we don't see much of a performance improvement with TE layers.
use_torch_compile: false
# NOTE(review): presumably the context-parallel size (1 = context parallelism disabled) — confirm against the trainer.
cp_size: 1
use_sequence_packing: false
# Dataset / dataloader config
dataset:
  tokenizer_name: ${model_tag}
  tokenizer_revision: null
  micro_batch_size: ???
  num_workers: 1
  max_seq_length: 1024
  mlm_probability: 0.15
  use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
  # Keyword arguments forwarded to the HF `datasets.load_dataset` call.
  load_dataset_kwargs:
    path: "nvidia/esm2_uniref_pretraining_data"
    split: "train"
    streaming: true # canonical lowercase boolean (was `True`)
# WandB config
wandb_init_args:
  # Mandatory (OmegaConf `???`): run name must be supplied via CLI overrides or a derived config.
  name: ???
# mFSDP config
fully_shard_kwargs:
  zero_dp_strategy: "optim_grads_params"
  calculate_per_token_loss: false
  init_model_with_meta_device: ${use_meta_device}
  check_for_nan_in_grad: true
  grad_reduce_in_fp32: false
  preserve_fp32_weights: true
  overlap_grad_reduce: true
  overlap_param_gather: true
  sync_model_each_microbatch: true
  average_in_collective: false
# TransformerEngine FP8 config. See
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on
# supported formats.
fp8_config:
  enabled: false
  # Dotted import path of the TE recipe class, instantiated at runtime with fp8_recipe_kwargs.
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID"
  fp8_recipe_kwargs: {}
  # NOTE(review): assumed nested under fp8_config (indentation was lost in extraction) — confirm against the trainer.
  fp8_model_init_kwargs:
    enabled: false # If this is set to true, fp8_config.enabled must also be set to true.
# TransformerEngine FP4 config (mirrors fp8_config above in structure).
fp4_config:
  enabled: false
  # Dotted import path of the TE recipe class, instantiated at runtime with fp4_recipe_kwargs.
  fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
  fp4_format: "E2M1"
  fp4_recipe_kwargs: {}
  # NOTE(review): assumed nested under fp4_config (indentation was lost in extraction) — confirm against the trainer.
  fp4_model_init_kwargs:
    enabled: false # If this is set to true, fp4_config.enabled must also be set to true.
# Optimizer config (keyword arguments for torch.optim.AdamW)
adamw_kwargs:
  # Written with an explicit mantissa dot: YAML 1.1 parsers (e.g. PyYAML) resolve `4e-4`
  # as a *string*, not a float, because their float regex requires a decimal point.
  lr: 4.0e-4
  fused: true
  betas: [0.9, 0.98]
  eps: 1.0e-8
  weight_decay: 0.01
# Learning rate scheduler config
lr_scheduler_kwargs:
  # Underscore digit separators (`2_000`) are a YAML 1.1 extension and parse as strings
  # under YAML 1.2 parsers; plain integers are portable.
  num_warmup_steps: 2000
  num_training_steps: 500000
# Checkpoint config
checkpoint:
  ckpt_dir: ???
  save_final_model: true
  resume_from_checkpoint: true
  save_every_n_steps: 1000 # plain integer: `1_000` is a string under YAML 1.2 parsers
  max_checkpoints: 5 # Keep only the latest 5 checkpoints
  async_save: true # Whether to save the checkpoint asynchronously, currently only supported with FSDP2.
# Logging config
logger:
  frequency: 100 # presumably "log every N steps" — confirm against the training loop
# Quantization-statistics debugging config (FP8/FP4 numerics logging).
quant_stats_config:
  enabled: false
  quant_stats_file: ./fp8_debugging_stats.yaml
  quant_log_dir: ./log_quant_stats
  # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
  fp8_layers: null
  fp4_layers: null
  use_fp32_master_weights: null