Skip to content

Commit bf10c4d

Browse files
authored
Pstjohn/llama3 lingua fixes (#1364)
Some WIP quality-of-life improvements for the llama3 recipe: * renames `wandb_init_args` to just `wandb`; * adds the option for PyTorch Kineto profiling through the perf logger, with wandb upload; * simplifies some of the dataloader column arguments. Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 1bfd599 commit bf10c4d

File tree

11 files changed

+262
-105
lines changed

11 files changed

+262
-105
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# syntax=docker/dockerfile:1.4
2+
FROM nvcr.io/nvidia/pytorch:25.11-py3
3+
4+
RUN --mount=type=cache,target=/root/.cache/pip \
5+
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
6+
PIP_CONSTRAINT= pip install -r /requirements.txt
7+
8+
WORKDIR /workspace/bionemo
9+
COPY . .

bionemo-recipes/recipes/llama3_native_te/dataset.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def create_tokenized_dataset(
3636
load_dataset_kwargs: dict,
3737
max_seq_length: int = 8192,
3838
stride: int = 200,
39-
buffer_size: int = 500_000,
39+
buffer_size: int = 5_000,
4040
use_lazy_tokenization: bool = True,
4141
text_column: str = "text",
42+
tokenize_batch_size: int = 100,
4243
):
4344
"""Create a tokenized dataset with windowing.
4445
@@ -51,20 +52,28 @@ def create_tokenized_dataset(
5152
buffer_size: The buffer size for shuffle.
5253
use_lazy_tokenization: Whether to tokenize lazily via `dataset.with_transform` instead of `dataset.map`.
5354
text_column: Name of the column containing genomic sequences (default: "text").
55+
tokenize_batch_size: The number of examples per batch passed as `batch_size` to `dataset.map` for tokenization.
5456
5557
Returns:
5658
Tuple of (tokenized_dataset, tokenizer).
5759
"""
5860
logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
5961
dataset = datasets.load_dataset(**load_dataset_kwargs)
60-
logger.info(f"Loaded dataset: {dataset}")
6162

6263
if isinstance(dataset, datasets.IterableDataset):
63-
dataset = datasets.distributed.split_dataset_by_node(
64-
dataset,
65-
rank=distributed_config.rank,
66-
world_size=distributed_config.world_size,
67-
)
64+
# Hugging Face's `split_dataset_by_node` is quite sensitive to the total number of shards -- if the number of
65+
# shards is not perfectly divisible by the world size, it falls back to reading every shard on every node and
66+
# keeping only each node's strided share of the examples. This can be quite inefficient with large
67+
# numbers of shards and workers, so we use `dataset.shard` whenever there are at least as many shards as ranks.
68+
if distributed_config.world_size > dataset.num_shards:
69+
logger.info(f"Sharding dataset with {dataset.num_shards} shards with split_dataset_by_node")
70+
dataset = datasets.distributed.split_dataset_by_node(
71+
dataset, rank=distributed_config.rank, world_size=distributed_config.world_size
72+
)
73+
else:
74+
logger.info(f"Sharding dataset with {dataset.num_shards} shards with dataset.shard")
75+
dataset = dataset.shard(num_shards=distributed_config.world_size, index=distributed_config.rank)
76+
6877
dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)
6978

7079
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
@@ -86,33 +95,11 @@ def tokenize_with_windowing(examples):
8695
# Using dataset.map on a non-streaming dataset will automatically perform and cache the transform
8796
tokenized_dataset = dataset.with_transform(tokenize_with_windowing)
8897
else:
89-
# WORKAROUND for OpenGenome2 inconsistent schema:
90-
# OpenGenome2 has inconsistent schemas across shards - some have 'record' column, some don't.
91-
# This causes dataset.column_names to be None for streaming IterableDataset.
92-
#
93-
# For IterableDataset with None column_names (OpenGenome2):
94-
# - Must explicitly list columns to remove: [text_column, "record"]
95-
# - IterableDataset.map() handles missing columns gracefully
96-
#
97-
# For regular Dataset (non-streaming, or streaming with consistent schema like ESM2):
98-
# - Use dataset.column_names (which is available and accurate)
99-
# - Dataset.map() raises error if column doesn't exist
100-
#
101-
# TODO: Remove this workaround once Arc Institute fixes OpenGenome2 schema consistency.
102-
# When all shards have the same columns, dataset.column_names will work for both cases.
103-
if isinstance(dataset, datasets.IterableDataset) and dataset.column_names is None:
104-
# Streaming dataset: column_names may be None due to inconsistent schema
105-
columns_to_remove = [text_column, "record"]
106-
else:
107-
# Non-streaming dataset: use actual column names
108-
columns_to_remove = dataset.column_names
109-
110-
logger.info(f"Applying dataset.map with columns to remove: {columns_to_remove}")
111-
112-
tokenized_dataset = dataset.map(
98+
tokenized_dataset = dataset.select_columns(text_column).map(
11399
tokenize_with_windowing,
114100
batched=True,
115-
remove_columns=columns_to_remove,
101+
batch_size=tokenize_batch_size,
102+
remove_columns=[text_column],
116103
)
117104

118105
return tokenized_dataset, tokenizer
@@ -124,6 +111,7 @@ def create_bshd_dataloader(
124111
load_dataset_kwargs: dict,
125112
micro_batch_size: int,
126113
num_workers: int = 1,
114+
prefetch_factor: int = 4,
127115
max_seq_length: int = 8192,
128116
stride: int = 200,
129117
seed: int = 42,
@@ -142,6 +130,7 @@ def create_bshd_dataloader(
142130
load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
143131
micro_batch_size: The batch size per device.
144132
num_workers: The number of workers to use for the dataloader.
133+
prefetch_factor: The prefetch factor to use for the dataloader.
145134
max_seq_length: The maximum length of sequences (window size).
146135
stride: The stride for windowing (overlap = stride tokens).
147136
seed: The seed to use for the distributed sampler and data collator.
@@ -164,6 +153,7 @@ def create_bshd_dataloader(
164153
buffer_size=buffer_size,
165154
use_lazy_tokenization=use_lazy_tokenization,
166155
text_column=text_column,
156+
tokenize_batch_size=micro_batch_size * prefetch_factor,
167157
)
168158

169159
if isinstance(tokenized_dataset, datasets.IterableDataset):
@@ -207,6 +197,7 @@ def create_bshd_dataloader(
207197
num_workers=num_workers,
208198
pin_memory=True if not use_stateful_dataloader else False,
209199
persistent_workers=num_workers > 0,
200+
prefetch_factor=prefetch_factor if num_workers > 0 else None,
210201
)
211202

212203
return train_dataloader, tokenized_dataset if sampler is None else sampler
@@ -219,6 +210,7 @@ def create_thd_dataloader(
219210
micro_batch_size: int | None = None,
220211
token_micro_batch_size: int | None = None,
221212
num_workers: int = 1,
213+
prefetch_factor: int = 4,
222214
max_seq_length: int = 8192,
223215
stride: int = 200,
224216
buffer_size: int = 500_000,
@@ -238,6 +230,7 @@ def create_thd_dataloader(
238230
token_micro_batch_size: The maximum number of tokens per batch. If None, the micro_batch_size * max_seq_length
239231
will be used. Defaults to None.
240232
num_workers: The number of workers to use for the dataloader.
233+
prefetch_factor: The prefetch factor to use for the dataloader.
241234
max_seq_length: The maximum length of sequences (window size).
242235
stride: The stride for windowing (overlap = stride tokens).
243236
seed: The seed to use for the distributed sampler and data collator.
@@ -292,6 +285,7 @@ def create_thd_dataloader(
292285
num_workers=num_workers,
293286
pin_memory=True if not use_stateful_dataloader else False,
294287
persistent_workers=num_workers > 0,
288+
prefetch_factor=prefetch_factor if num_workers > 0 else None,
295289
)
296290

297291
return train_dataloader, tokenized_dataset

bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_convergence.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ logger:
5252
frequency: 100
5353

5454
# WandB configuration
55-
wandb_init_args:
55+
wandb:
5656
project: "llama3-genomic-convergence"
5757
name: "tiny-llama-convergence-test"
5858
mode: "online"

bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dataset:
3434
streaming: True
3535

3636
# WandB config
37-
wandb_init_args:
37+
wandb:
3838
name: "llama3_8B_genomic_sanity"
3939
mode: "offline"
4040

bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_1b.yaml

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,23 @@ defaults:
66

77
config_name_or_path: ./model_configs/meta-llama/Llama-3.2-1B
88

9-
wandb_init_args:
9+
config_kwargs:
10+
attn_input_format: thd
11+
12+
use_sequence_packing: true
13+
14+
wandb:
1015
name: lingua-1b-te
1116
project: null # Optional: set to your wandb project name
1217

1318
num_train_steps: 60_000
1419

1520
dataset:
1621
tokenizer_name_or_path: nvidia/Llama-3.1-8B-Instruct-FP8
17-
micro_batch_size: 1
18-
num_workers: 1
19-
max_seq_length: 8192
20-
stride: 1024
22+
micro_batch_size: 4
23+
num_workers: 8
24+
max_seq_length: 4096
25+
stride: 512
2126
buffer_size: 5_000
2227
use_lazy_tokenization: true
2328
use_stateful_dataloader: false
@@ -30,19 +35,28 @@ dataset:
3035
streaming: True
3136

3237
adamw_kwargs:
33-
lr: 3e-4
38+
lr: .003
3439
fused: true
3540
betas: [0.9, 0.95]
36-
eps: 1e-5
37-
weight_decay: 0.1
41+
eps: 0.00000001
42+
weight_decay: 0.033
3843

3944
lr_scheduler_kwargs:
40-
num_warmup_steps: 2_000
41-
num_decay_steps: 60_000
45+
num_warmup_steps: 5_000
46+
num_decay_steps: 55_000 # total_steps - num_warmup_steps = 60_000 - 5_000
47+
min_lr_ratio: 0.000001
4248

4349
# Checkpoint config
4450
checkpoint:
4551
ckpt_dir: null
4652
save_final_model: true
4753
resume_from_checkpoint: true
48-
save_every_n_steps: 1_000
54+
save_every_n_steps: 10_000
55+
56+
profiler:
57+
enabled: false
58+
schedule:
59+
wait: 125
60+
warmup: 125
61+
active: 10
62+
repeat: 1

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,10 @@ dataset:
3333
streaming: True
3434

3535
# WandB config
36-
wandb_init_args:
36+
wandb:
3737
name: ???
3838
project: null # Optional: set to your wandb project name
3939

40-
# mFSDP config
41-
fully_shard_kwargs:
42-
zero_dp_strategy: "optim_grads_params"
43-
calculate_per_token_loss: false
44-
init_model_with_meta_device: ${use_meta_device}
45-
check_for_nan_in_grad: true
46-
grad_reduce_in_fp32: false
47-
preserve_fp32_weights: true
48-
overlap_grad_reduce: true
49-
overlap_param_gather: true
50-
sync_model_each_microbatch: true
51-
average_in_collective: false
52-
5340
# TransformerEngine FP8 config. See
5441
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on
5542
# supported formats.
@@ -63,7 +50,7 @@ fp8_config:
6350

6451
# Optimizer config
6552
adamw_kwargs:
66-
lr: 3e-4
53+
lr: 3e-3
6754
fused: true
6855
betas: [0.9, 0.95]
6956
eps: 1e-5
@@ -72,7 +59,8 @@ adamw_kwargs:
7259
# Learning rate scheduler config
7360
lr_scheduler_kwargs:
7461
num_warmup_steps: 2_000
75-
num_decay_steps: 500_000
62+
num_decay_steps: 498_000
63+
min_lr_ratio: 0.000001
7664

7765
# Checkpoint config
7866
checkpoint:
@@ -83,3 +71,11 @@ checkpoint:
8371

8472
logger:
8573
frequency: 100
74+
75+
profiler:
76+
enabled: false
77+
schedule:
78+
wait: 10
79+
warmup: 10
80+
active: 3
81+
repeat: 1

0 commit comments

Comments
 (0)