|
| 1 | +""" |
| 2 | + Copyright 2024 Google LLC |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | + """ |
| 16 | + |
| 17 | +import warnings |
| 18 | +import datasets |
| 19 | +from datasets import load_dataset |
| 20 | +from datasets.distributed import split_dataset_by_node |
| 21 | +import grain.python as grain |
| 22 | + |
| 23 | +from maxdiffusion import max_logging |
| 24 | +from maxdiffusion import multihost_dataloading |
| 25 | + |
| 26 | + |
| 27 | +def make_hf_streaming_iterator( |
| 28 | + config, |
| 29 | + dataloading_host_index, |
| 30 | + dataloading_host_count, |
| 31 | + mesh, |
| 32 | + global_batch_size, |
| 33 | + tokenize_fn=None, |
| 34 | + image_transforms_fn=None, |
| 35 | + hf_batch_factor=4, |
| 36 | +): |
| 37 | + """Streaming data from HF Hub or GCS buckect. |
| 38 | + No download regardless of config.cache_latents_text_encoder_outputs""" |
| 39 | + ds = load_dataset( |
| 40 | + config.dataset_name, |
| 41 | + split=config.train_split, |
| 42 | + data_dir=config.hf_data_dir, |
| 43 | + data_files=config.hf_train_files, |
| 44 | + streaming=True, |
| 45 | + token=config.hf_access_token, |
| 46 | + ) |
| 47 | + |
| 48 | + ds = ds.shuffle(seed=config.seed) |
| 49 | + ds = ds.select_columns([config.caption_column, config.image_column]) |
| 50 | + |
| 51 | + if tokenize_fn: |
| 52 | + ds = ds.map( |
| 53 | + function=tokenize_fn, |
| 54 | + batched=True, |
| 55 | + batch_size=hf_batch_factor * config.total_train_batch_size, |
| 56 | + remove_columns=[config.caption_column], |
| 57 | + ) |
| 58 | + |
| 59 | + if image_transforms_fn: |
| 60 | + ds = ds.map( |
| 61 | + function=image_transforms_fn, |
| 62 | + batched=True, |
| 63 | + batch_size=hf_batch_factor * config.total_train_batch_size, |
| 64 | + remove_columns=[config.image_column], |
| 65 | + ) |
| 66 | + |
| 67 | + ds = HFDataSource( |
| 68 | + ds, |
| 69 | + dataloading_host_index, |
| 70 | + dataloading_host_count, |
| 71 | + ) |
| 72 | + dummy_index_sampler = grain.IndexSampler( |
| 73 | + num_records=len(ds), |
| 74 | + num_epochs=1, |
| 75 | + shard_options=grain.ShardOptions( |
| 76 | + shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False |
| 77 | + ), |
| 78 | + shuffle=False, |
| 79 | + seed=0, |
| 80 | + ) |
| 81 | + operations = [grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True)] |
| 82 | + dataloader = grain.DataLoader( |
| 83 | + data_source=ds, |
| 84 | + operations=operations, |
| 85 | + sampler=dummy_index_sampler, |
| 86 | + worker_count=1, # only supports one worker for now, more workers results in duplicated data |
| 87 | + worker_buffer_size=1, |
| 88 | + read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=hf_batch_factor * config.total_train_batch_size), |
| 89 | + ) |
| 90 | + train_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh) |
| 91 | + return train_iter |
| 92 | + |
| 93 | + |
| 94 | +class HFDataSource(grain.RandomAccessDataSource): |
| 95 | + """A class that makes HuggingFace IterableDataset a grain datasource without random access support""" |
| 96 | + |
| 97 | + def __init__( |
| 98 | + self, |
| 99 | + dataset: datasets.IterableDataset, |
| 100 | + dataloading_host_index: int, |
| 101 | + dataloading_host_count: int, |
| 102 | + ): |
| 103 | + self.dataset = dataset |
| 104 | + self.dataloading_host_count = dataloading_host_count |
| 105 | + self.dataloading_host_index = dataloading_host_index |
| 106 | + self.n_shards = dataset.n_shards |
| 107 | + self._check_shard_count() |
| 108 | + self.current_shard = dataloading_host_index |
| 109 | + self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard) |
| 110 | + self.data_iter = None |
| 111 | + |
| 112 | + def _check_shard_count(self): |
| 113 | + if self.n_shards < self.dataloading_host_count: |
| 114 | + warnings.warn( |
| 115 | + f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, " |
| 116 | + "smaller than number of host loading data. This is known to lead to inefficient dataloading. " |
| 117 | + "see https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/docs/data_README.md#best-practice" |
| 118 | + ) |
| 119 | + self.n_shards = self.dataloading_host_count |
| 120 | + |
| 121 | + def _update_shard(self): |
| 122 | + new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards |
| 123 | + max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}") |
| 124 | + self.current_shard = new_shard |
| 125 | + self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard) |
| 126 | + self.data_iter = iter(self.dataset_shard) |
| 127 | + |
| 128 | + def __len__(self): |
| 129 | + """Return length of the HF dataset. Since HuggingFace IterableDataset does not have length, |
| 130 | + a fake length bigger than the dataset is returned""" |
| 131 | + return 10_000_000_000 |
| 132 | + |
| 133 | + def __getitem__(self, index): |
| 134 | + """Since HuggingFace IterableDataset does not support random access by index. |
| 135 | + The next item in the iterator is returned.""" |
| 136 | + if not self.data_iter: |
| 137 | + self.data_iter = iter(self.dataset_shard) |
| 138 | + |
| 139 | + while True: |
| 140 | + try: |
| 141 | + data = next(self.data_iter) |
| 142 | + return data |
| 143 | + except StopIteration: |
| 144 | + self._update_shard() |
0 commit comments