@@ -36,9 +36,10 @@ def create_tokenized_dataset(
     load_dataset_kwargs: dict,
     max_seq_length: int = 8192,
     stride: int = 200,
-    buffer_size: int = 500_000,
+    buffer_size: int = 5_000,
     use_lazy_tokenization: bool = True,
     text_column: str = "text",
+    tokenize_batch_size: int = 100,
 ):
     """Create a tokenized dataset with windowing.

@@ -51,20 +52,28 @@ def create_tokenized_dataset(
         buffer_size: The buffer size for shuffle.
         use_lazy_tokenization: Whether to use datasets.set_transform for tokenization.
         text_column: Name of the column containing genomic sequences (default: "text").
+        tokenize_batch_size: The batch size for tokenization.

     Returns:
         Tuple of (tokenized_dataset, tokenizer).
     """
     logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
     dataset = datasets.load_dataset(**load_dataset_kwargs)
-    logger.info(f"Loaded dataset: {dataset}")

     if isinstance(dataset, datasets.IterableDataset):
-        dataset = datasets.distributed.split_dataset_by_node(
-            dataset,
-            rank=distributed_config.rank,
-            world_size=distributed_config.world_size,
-        )
+        # Hugging Face's `split_dataset_by_node` is quite sensitive to the total number of shards -- if the number of
+        # shards is not perfectly divisible by the world size, it falls back to loading the same shards on every node
+        # and using strided sampling so that each node still yields different examples. This can be quite inefficient
+        # with large numbers of shards and workers, so we use `dataset.shard` whenever there are enough shards.
+        if distributed_config.world_size > dataset.num_shards:
+            logger.info(f"Sharding dataset with {dataset.num_shards} shards with split_dataset_by_node")
+            dataset = datasets.distributed.split_dataset_by_node(
+                dataset, rank=distributed_config.rank, world_size=distributed_config.world_size
+            )
+        else:
+            logger.info(f"Sharding dataset with {dataset.num_shards} shards with dataset.shard")
+            dataset = dataset.shard(num_shards=distributed_config.world_size, index=distributed_config.rank)
+
     dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)

     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
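
The sharding branch above is easier to see on a toy streaming dataset. The following is a minimal sketch, not part of the commit: the generator-backed `IterableDataset`, the shard names, and the `world_size`/`rank` values are made up for illustration, and it assumes a `datasets` release where `IterableDataset` exposes `num_shards` and `shard()` (as the new code does).

```python
import datasets
from datasets.distributed import split_dataset_by_node

def gen(shards):
    # each entry in `shards` behaves like one file-backed shard of the stream
    for shard in shards:
        for i in range(4):
            yield {"text": f"{shard}-{i}"}

# passing a list through gen_kwargs makes the stream shardable: num_shards == len(list)
ds = datasets.IterableDataset.from_generator(
    gen, gen_kwargs={"shards": [f"shard{i}" for i in range(6)]}
)

world_size, rank = 4, 0
if world_size > ds.num_shards:
    # too few shards: every rank reads the same shards and keeps every world_size-th example
    ds_rank = split_dataset_by_node(ds, rank=rank, world_size=world_size)
else:
    # enough shards: each rank reads a disjoint subset of the shards
    ds_rank = ds.shard(num_shards=world_size, index=rank)

print(ds.num_shards, [ex["text"] for ex in ds_rank][:4])
```
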
@@ -86,33 +95,11 @@ def tokenize_with_windowing(examples):
         # Using dataset.map on a non-streaming dataset will automatically perform and cache the transform
         tokenized_dataset = dataset.with_transform(tokenize_with_windowing)
     else:
-        # WORKAROUND for OpenGenome2 inconsistent schema:
-        # OpenGenome2 has inconsistent schemas across shards - some have 'record' column, some don't.
-        # This causes dataset.column_names to be None for streaming IterableDataset.
-        #
-        # For IterableDataset with None column_names (OpenGenome2):
-        #   - Must explicitly list columns to remove: [text_column, "record"]
-        #   - IterableDataset.map() handles missing columns gracefully
-        #
-        # For regular Dataset (non-streaming, or streaming with consistent schema like ESM2):
-        #   - Use dataset.column_names (which is available and accurate)
-        #   - Dataset.map() raises error if column doesn't exist
-        #
-        # TODO: Remove this workaround once Arc Institute fixes OpenGenome2 schema consistency.
-        # When all shards have the same columns, dataset.column_names will work for both cases.
-        if isinstance(dataset, datasets.IterableDataset) and dataset.column_names is None:
-            # Streaming dataset: column_names may be None due to inconsistent schema
-            columns_to_remove = [text_column, "record"]
-        else:
-            # Non-streaming dataset: use actual column names
-            columns_to_remove = dataset.column_names
-
-        logger.info(f"Applying dataset.map with columns to remove: {columns_to_remove}")
-
-        tokenized_dataset = dataset.map(
+        tokenized_dataset = dataset.select_columns(text_column).map(
             tokenize_with_windowing,
             batched=True,
-            remove_columns=columns_to_remove,
+            batch_size=tokenize_batch_size,
+            remove_columns=[text_column],
         )

     return tokenized_dataset, tokenizer
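
For context, the `tokenize_with_windowing` closure being mapped here is defined earlier in the file and is not shown in this diff. A plausible shape for it, as a hedged sketch only, is a batched transform that leans on the tokenizer's built-in overflow/stride windowing (this path requires a fast tokenizer); `tokenizer`, `text_column`, `max_seq_length`, and `stride` are the enclosing function's parameters.

```python
def tokenize_with_windowing(examples):
    # one input document can yield several windows of at most max_seq_length tokens,
    # with consecutive windows overlapping by `stride` tokens
    tokenized = tokenizer(
        examples[text_column],
        max_length=max_seq_length,
        stride=stride,
        truncation=True,
        return_overflowing_tokens=True,
    )
    # optional: drop the window -> source-row index if downstream code doesn't use it
    tokenized.pop("overflow_to_sample_mapping", None)
    return tokenized
```
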
@@ -124,6 +111,7 @@ def create_bshd_dataloader(
     load_dataset_kwargs: dict,
     micro_batch_size: int,
     num_workers: int = 1,
+    prefetch_factor: int = 4,
     max_seq_length: int = 8192,
     stride: int = 200,
     seed: int = 42,
@@ -142,6 +130,7 @@ def create_bshd_dataloader(
         load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
         micro_batch_size: The batch size per device.
         num_workers: The number of workers to use for the dataloader.
+        prefetch_factor: The prefetch factor to use for the dataloader.
         max_seq_length: The maximum length of sequences (window size).
         stride: The stride for windowing (overlap = stride tokens).
         seed: The seed to use for the distributed sampler and data collator.
@@ -164,6 +153,7 @@ def create_bshd_dataloader(
         buffer_size=buffer_size,
         use_lazy_tokenization=use_lazy_tokenization,
         text_column=text_column,
+        tokenize_batch_size=micro_batch_size * prefetch_factor,
     )

     if isinstance(tokenized_dataset, datasets.IterableDataset):
@@ -207,6 +197,7 @@ def create_bshd_dataloader(
         num_workers=num_workers,
         pin_memory=True if not use_stateful_dataloader else False,
         persistent_workers=num_workers > 0,
+        prefetch_factor=prefetch_factor if num_workers > 0 else None,
     )

     return train_dataloader, tokenized_dataset if sampler is None else sampler
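
One note on the new `prefetch_factor` plumbing: recent PyTorch versions only accept a non-None `prefetch_factor` when worker processes exist, which is why both dataloaders guard it with `num_workers > 0`. A minimal sketch of the pattern, with placeholder dataset and batch-size arguments:

```python
from torch.utils.data import DataLoader

def build_loader(dataset, batch_size, num_workers, prefetch_factor=4):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # each worker keeps `prefetch_factor` batches ready; with num_workers=0 it must stay None
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )
```
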
@@ -219,6 +210,7 @@ def create_thd_dataloader(
     micro_batch_size: int | None = None,
     token_micro_batch_size: int | None = None,
     num_workers: int = 1,
+    prefetch_factor: int = 4,
     max_seq_length: int = 8192,
     stride: int = 200,
     buffer_size: int = 500_000,
@@ -238,6 +230,7 @@ def create_thd_dataloader(
         token_micro_batch_size: The maximum number of tokens per batch. If None, the micro_batch_size * max_seq_length
             will be used. Defaults to None.
         num_workers: The number of workers to use for the dataloader.
+        prefetch_factor: The prefetch factor to use for the dataloader.
         max_seq_length: The maximum length of sequences (window size).
         stride: The stride for windowing (overlap = stride tokens).
         seed: The seed to use for the distributed sampler and data collator.
@@ -292,6 +285,7 @@ def create_thd_dataloader(
         num_workers=num_workers,
         pin_memory=True if not use_stateful_dataloader else False,
         persistent_workers=num_workers > 0,
+        prefetch_factor=prefetch_factor if num_workers > 0 else None,
     )

     return train_dataloader, tokenized_dataset