-
Notifications
You must be signed in to change notification settings - Fork 127
Expand file tree
/
Copy pathdataset.py
More file actions
266 lines (229 loc) · 10.6 KB
/
dataset.py
File metadata and controls
266 lines (229 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import datasets
import datasets.distributed
import torch
from torch.utils.data import DataLoader, DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling
from collator import (
ContextParallelDataLoaderWrapper,
DataCollatorForContextParallel,
DataCollatorWithFlattening,
TokenPackingDataset,
)
from distributed_config import DistributedConfig
logger = logging.getLogger(__name__)
def create_tokenized_dataset(
distributed_config: DistributedConfig,
tokenizer_name: str,
load_dataset_kwargs: dict,
max_seq_length: int = 1024,
buffer_size: int = 10_000,
use_lazy_tokenization: bool = True,
):
"""Create a tokenized dataset."""
logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
dataset = datasets.load_dataset(**load_dataset_kwargs)
logger.info(f"Loaded dataset: {dataset}")
if isinstance(dataset, datasets.IterableDataset):
dataset = datasets.distributed.split_dataset_by_node(
dataset,
rank=distributed_config.rank,
world_size=distributed_config.world_size,
)
dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def tokenize_function(examples):
"""Tokenize the protein sequences."""
return tokenizer(
examples["sequence"],
truncation=True,
max_length=max_seq_length,
)
if isinstance(dataset, datasets.Dataset) and use_lazy_tokenization:
# Using dataset.map on a non-streaming dataset will automatically perform and cache the transform, which can
# trigger an expensive tokenization.
tokenized_dataset = dataset.with_transform(tokenize_function)
else:
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset.column_names,
)
return tokenized_dataset, tokenizer
def create_bshd_dataloader(
distributed_config: DistributedConfig,
tokenizer_name: str,
load_dataset_kwargs: dict,
micro_batch_size: int,
num_workers: int,
max_seq_length: int = 1024,
seed: int = 42,
buffer_size: int = 10_000,
use_lazy_tokenization: bool = True,
use_stateful_dataloader: bool = False,
mlm_probability: float = 0.15,
):
"""Create a dataloader for the dataset.
Args:
distributed_config: The distributed configuration.
tokenizer_name: The name of the tokenizer to pull from the HuggingFace Hub.
load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the train dataset.
micro_batch_size: The batch size (number of sequences) per device.
num_workers: The number of workers to use for the dataloader.
max_seq_length: The maximum length of the protein sequences.
seed: The seed to use for the distributed sampler and data collator.
buffer_size: The buffer size to use for the distributed sampler.
use_lazy_tokenization: Whether to use datasets.set_transform for tokenization if the dataset is a
non-streaming datasets.Dataset. Defaults to True.
use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking.
Returns:
A dataloader that can be used for training.
"""
tokenized_dataset, tokenizer = create_tokenized_dataset(
distributed_config=distributed_config,
tokenizer_name=tokenizer_name,
load_dataset_kwargs=load_dataset_kwargs,
max_seq_length=max_seq_length,
buffer_size=buffer_size,
use_lazy_tokenization=use_lazy_tokenization,
)
if isinstance(tokenized_dataset, datasets.IterableDataset):
sampler = None
else:
sampler = DistributedSampler(
tokenized_dataset,
rank=distributed_config.rank,
num_replicas=distributed_config.world_size,
seed=seed,
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm_probability=mlm_probability,
pad_to_multiple_of=max_seq_length,
seed=seed,
)
# TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
train_dataloader = dataloader_class(
tokenized_dataset,
sampler=sampler,
batch_size=micro_batch_size,
collate_fn=data_collator,
num_workers=num_workers,
pin_memory=not use_stateful_dataloader,
)
return train_dataloader, tokenized_dataset if sampler is None else sampler
def create_thd_dataloader(
distributed_config: DistributedConfig,
tokenizer_name: str,
load_dataset_kwargs: dict,
micro_batch_size: int | None = None,
token_micro_batch_size: int | None = None,
num_workers: int = 1,
max_seq_length: int = 1024,
seed: int = 42,
buffer_size: int = 10_000,
use_stateful_dataloader: bool = False,
mlm_probability: float = 0.15,
pad_sequences_to_be_divisible_by: int | None = None,
):
"""Create a dataloader that packs up to the maximum number of tokens per batch.
Args:
distributed_config: The distributed configuration.
tokenizer_name: The name of the tokenizer to pull from the HuggingFace Hub.
load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the train dataset.
micro_batch_size: The batch size (number of sequences) per device. This will set the token_micro_batch_size to
micro_batch_size * max_seq_length. Defaults to None.
token_micro_batch_size: The maximum number of tokens per batch. If None, the micro_batch_size * max_seq_length
will be used. Defaults to None.
num_workers: The number of workers to use for the dataloader. For iterable datasets, this should be 1.
max_seq_length: The maximum length of the protein sequences.
seed: The seed to use for the distributed sampler and data collator.
buffer_size: The buffer size to use for the distributed sampler.
use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking.
pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.
This is useful for context parallelism. Defaults to None.
Returns:
A dataloader that can be used for training.
"""
tokenized_dataset, tokenizer = create_tokenized_dataset(
distributed_config=distributed_config,
tokenizer_name=tokenizer_name,
load_dataset_kwargs=load_dataset_kwargs,
max_seq_length=max_seq_length,
buffer_size=buffer_size,
)
assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
if token_micro_batch_size is None:
assert micro_batch_size is not None, "Only one of micro_batch_size or token_micro_batch_size can be provided."
token_micro_batch_size = micro_batch_size * max_seq_length
else:
assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided."
assert token_micro_batch_size >= max_seq_length, "token_micro_batch_size must be greater than max_seq_length."
# For THD, we pad out to the maximum number of tokens per batch for consistent array shapes.
mlm_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm_probability=mlm_probability,
seed=seed,
)
data_collator = DataCollatorWithFlattening(
collator=mlm_collator,
pad_to_multiple_of=token_micro_batch_size if pad_sequences_to_be_divisible_by is None else None,
pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by,
)
# TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
train_dataloader = dataloader_class(
TokenPackingDataset(tokenized_dataset, max_tokens_per_batch=token_micro_batch_size),
batch_size=None, # The TokenPackingDataset will handle the batching.
collate_fn=data_collator,
num_workers=num_workers,
pin_memory=not use_stateful_dataloader,
)
return train_dataloader, tokenized_dataset
def create_cp_dataloader(
*args,
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
**kwargs,
):
"""Create a Context-parallel aware dataloader that automatically handles sharding between ranks.
Wraps the output of `create_thd_dataloader` to make it context parallel aware.
Args:
*args: Arguments to pass to `create_thd_dataloader`.
cp_mesh: The context parallel mesh.
**kwargs: Keyword arguments to pass to `create_thd_dataloader`.
Returns:
A tuple of (dataloader, dataset_or_sampler).
"""
# Ensure pad_sequences_to_be_divisible_by is passed to create_thd_dataloader
if kwargs.get("pad_sequences_to_be_divisible_by", None) is None:
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
kwargs["pad_sequences_to_be_divisible_by"] = cp_mesh.size() * 2
if cp_mesh.get_local_rank() == 0:
train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
train_dataloader.collate_fn = DataCollatorForContextParallel(
collator=train_dataloader.collate_fn,
device_mesh=cp_mesh,
)
else:
train_dataloader = None
tokenized_dataset = None
return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset