Skip to content

Commit e9b27f7

Browse files
authored
add config overrides for llama recipe (#1343)
Updates the hydra config to let us pass model size information via the hydra configs. Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent b9b916e commit e9b27f7

File tree

6 files changed

+60
-67
lines changed

6 files changed

+60
-67
lines changed

bionemo-recipes/recipes/llama3_native_te/example_checkpoint/config.json

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,25 @@
1212
"AutoModelForSequenceClassification": "llama3_nv.NVLlamaForSequenceClassification",
1313
"AutoModelForTokenClassification": "llama3_nv.NVLlamaForTokenClassification"
1414
},
15+
"bos_token_id": 128000,
1516
"dtype": "bfloat16",
17+
"eos_token_id": [
18+
128001,
19+
128008,
20+
128009
21+
],
1622
"head_dim": 64,
1723
"hidden_act": "silu",
18-
"hidden_size": 384,
24+
"hidden_size": 2048,
1925
"initializer_range": 0.02,
20-
"intermediate_size": 1536,
26+
"intermediate_size": 8192,
2127
"max_position_embeddings": 131072,
2228
"mlp_bias": false,
2329
"model_type": "llama",
24-
"num_attention_heads": 6,
25-
"num_hidden_layers": 2,
26-
"num_key_value_heads": 6,
30+
"num_attention_heads": 32,
31+
"num_hidden_layers": 16,
32+
"num_key_value_heads": 8,
33+
"pretraining_tp": 1,
2734
"rms_norm_eps": 1e-05,
2835
"rope_scaling": {
2936
"factor": 32.0,
@@ -33,12 +40,8 @@
3340
"rope_type": "llama3"
3441
},
3542
"rope_theta": 500000.0,
36-
"tie_word_embeddings": false,
43+
"tie_word_embeddings": true,
3744
"transformers_version": "4.57.1",
3845
"use_cache": true,
39-
"vocab_size": 256,
40-
"bos_token_id": 2,
41-
"eos_token_id": 0,
42-
"pad_token_id": 1,
43-
"attn_input_format": "bshd"
46+
"vocab_size": 128256
4447
}

bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_convergence.yaml

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,65 +11,50 @@ defaults:
1111
# Use tiny Llama config for fast convergence testing
1212
model_tag: ./example_checkpoint
1313

14-
# Training steps - enough to see convergence on small dataset
15-
num_train_steps: 1000
14+
num_train_steps: 270_000
1615

17-
# Dataset configuration - use small test dataset
1816
dataset:
19-
tokenizer_path: ./example_checkpoint # Tokenizer included in checkpoint directory
20-
micro_batch_size: 1 # Conservative for single GPU
21-
num_workers: 2
22-
max_seq_length: 8192 # Full Llama3 context length
23-
stride: 400 # 400bp overlap for 8K context
24-
buffer_size: 10_000 # Smaller buffer for faster iteration
25-
use_lazy_tokenization: true
26-
use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
17+
micro_batch_size: 1 # Conservative for single GPU
2718
load_dataset_kwargs:
28-
path: "parquet"
29-
data_files: "genomic_sequences_2mb.parquet" # 2MB convergence test data in recipe directory
19+
path: "arcinstitute/opengenome2"
20+
data_dir: "json/pretraining_or_both_phases"
3021
split: "train"
31-
streaming: true # Use streaming to avoid loading entire dataset into memory
22+
streaming: true # Use streaming to avoid loading entire dataset into memory
3223

33-
# Optimizer - higher LR for faster convergence on small model
3424
adamw_kwargs:
35-
lr: 5e-4 # Higher than default for faster convergence
25+
lr: 5e-4
3626
fused: true
3727
betas: [0.9, 0.98]
3828
eps: 1e-8
3929
weight_decay: 0.01
4030

41-
# Learning rate scheduler
4231
lr_scheduler_kwargs:
43-
num_warmup_steps: 100 # Quick warmup (10% of training)
44-
num_training_steps: 1000
32+
num_warmup_steps: 20_000
33+
num_training_steps: 500_000
4534

46-
# Checkpoint configuration - disabled for fast convergence testing
4735
checkpoint:
48-
ckpt_dir: null # No checkpoints
49-
save_final_model: false # Don't save final model
50-
resume_from_checkpoint: false # Start fresh for convergence test
51-
save_every_n_steps: null # No intermediate checkpoints
36+
ckpt_dir: null # No checkpoints
37+
save_final_model: false # Don't save final model
38+
resume_from_checkpoint: false # Start fresh for convergence test
39+
save_every_n_steps: null # No intermediate checkpoints
5240

53-
# Logging - frequent logging to track convergence
5441
logger:
55-
frequency: 10 # Log every 10 steps
42+
frequency: 100
5643

5744
# WandB configuration
5845
wandb_init_args:
5946
project: "llama3-genomic-convergence"
6047
name: "tiny-llama-convergence-test"
61-
mode: "online" # Online mode for real-time dashboard
48+
mode: "online"
6249
tags:
6350
- convergence-test
6451
- tiny-model
6552
- 1M-params
6653
- 8192-context
6754

68-
# Meta device and torch compile
6955
use_meta_device: false
70-
use_torch_compile: false # Disable for debugging
56+
use_torch_compile: false
7157

72-
# FP8 configuration - disabled for convergence testing
7358
fp8_config:
7459
enabled: false
7560
fp8_recipe: transformer_engine.common.recipe.DelayedScaling

bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity.yaml

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,39 @@ defaults:
33
- _self_
44

55
# Training config
6-
model_tag: ./example_checkpoint # Use tiny Llama config for testing (4 layers, 384 hidden, ~9.6M params)
6+
model_tag: ./example_checkpoint  # Use tiny Llama config for testing (overridden below via config_kwargs: 2 layers, 384 hidden)
7+
8+
config_kwargs:
9+
num_hidden_layers: 2
10+
hidden_size: 384
11+
intermediate_size: 1536
12+
num_attention_heads: 6
13+
num_key_value_heads: 6
14+
715
num_train_steps: 250
816

917
# We want this on in CI/CD to validate that the script runs successfully with torch.compile.
10-
use_torch_compile: false # Disable for faster startup during testing
18+
use_torch_compile: true  # Enabled in CI/CD to validate the script runs successfully with torch.compile
1119

1220
dataset:
13-
tokenizer_path: ./example_checkpoint # Tokenizer included in checkpoint directory
14-
micro_batch_size: 1 # Small batch size for limited GPU memory
15-
num_workers: 1
16-
max_seq_length: 1024 # Smaller window for testing
17-
stride: 100 # Smaller stride for testing
18-
buffer_size: 10_000 # Smaller buffer for testing
19-
use_lazy_tokenization: true
20-
use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
21+
micro_batch_size: 1 # Small batch size for limited GPU memory
2122
load_dataset_kwargs:
2223
path: "parquet"
2324
split: "train"
2425
data_files: "test_genomic_sequences.parquet" # Use local test file in recipe directory
25-
26+
streaming: True
2627

2728
# WandB config
2829
wandb_init_args:
2930
name: "llama3_8B_genomic_sanity"
3031
mode: "offline"
31-
project: null # Set to null by default, override with +wandb_init_args.project=your-project
3232

3333
# Learning rate scheduler config
3434
lr_scheduler_kwargs:
35-
num_warmup_steps: 10 # Shorter warmup for quick testing
36-
num_training_steps: 250 # Match num_train_steps
35+
num_warmup_steps: 10 # Shorter warmup for quick testing
3736

3837
checkpoint:
3938
ckpt_dir: null
40-
resume_from_checkpoint: true
41-
save_every_n_steps: 50
4239
save_final_model: false
4340

4441
logger:

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Training config
2-
model_tag: ??? # E.g., meta-llama/Meta-Llama-3-8B or a local path
2+
model_tag: ??? # E.g., meta-llama/Llama-3.2-1B or a local path
3+
config_kwargs: # Arguments to pass to the AutoConfig.from_pretrained method
4+
trust_remote_code: true
5+
vocab_size: 256 # Overrides to the default config that comes from meta-llama/Llama-3.2-1B
6+
tie_word_embeddings: false
7+
eos_token_id: 0
8+
pad_token_id: 1
9+
bos_token_id: 2
10+
311
num_train_steps: ???
412

513
# TODO: Once BIONEMO-2583 and BIONEMO-2719 are fixed, enable this by default and simplify training scripts to remove the
@@ -14,23 +22,23 @@ use_torch_compile: false
1422
use_gradient_checkpointing: false
1523

1624
dataset:
17-
tokenizer_path: ./example_checkpoint # Set to the path of your tokenizer (e.g., ./example_checkpoint)
25+
tokenizer_path: ${model_tag} # Set to the path of your tokenizer (e.g., ./example_checkpoint)
1826
micro_batch_size: 8
1927
num_workers: 1
20-
max_seq_length: 8192 # Window size for genomic sequences
21-
stride: 200 # Overlap for windowing
22-
buffer_size: 500_000 # Shuffle buffer size
28+
max_seq_length: 8192 # Window size for genomic sequences
29+
stride: 200 # Overlap for windowing
30+
buffer_size: 500_000 # Shuffle buffer size
2331
use_lazy_tokenization: true
24-
use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
32+
use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
2533
load_dataset_kwargs:
26-
path: "parquet"
34+
path: ???
2735
split: "train"
2836
streaming: True
2937

3038
# WandB config
3139
wandb_init_args:
3240
name: ???
33-
project: null # Optional: set to your wandb project name
41+
project: null # Optional: set to your wandb project name
3442

3543
# mFSDP config
3644
fully_shard_kwargs:

bionemo-recipes/recipes/llama3_native_te/train_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def main(args: DictConfig) -> float | None:
5959
)
6060

6161
# Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
62-
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16)
62+
config = AutoConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16, **args.config_kwargs)
6363
# Use SDPA (Scaled Dot-Product Attention) to avoid materializing large causal masks
6464
# config.attn_implementation = "sdpa"
6565

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main(args: DictConfig) -> float | None: # noqa: C901
6666
)
6767

6868
# Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
69-
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16)
69+
config = AutoConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16, **args.config_kwargs)
7070
# Use SDPA (Scaled Dot-Product Attention) to avoid materializing large causal masks
7171
# config.attn_implementation = "sdpa"
7272

0 commit comments

Comments
 (0)