
Commit 6271ab7

add HuggingFace streaming support in data input pipeline (#117)
1 parent 975fdb7 commit 6271ab7

21 files changed: +793 -302 lines

docs/README.md

Lines changed: 4 additions & 0 deletions
@@ -15,3 +15,7 @@ This folder contains documentation for getting started with and using MaxDiffusi
 ## Training

 * **[Common Training Guide](train_README.md)** - Provides a comprehensive guide to training MaxDiffusion models, including script usage, configuration options, and sharding strategies.
+
+## Data Input
+
+* **[Common Data Input Guide](data_README.md)** - Provides a comprehensive guide to data input pipelines.

docs/data_README.md

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# Data Input Guide

## Overview
Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `dataset_type`:

| Pipeline | Dataset location | Dataset formats | Features or limitations |
| -------- | ---------------- | --------------- | ----------------------- |
| HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | formats supported in HF Hub: parquet, arrow, json, csv, txt | Data are not loaded into memory but streamed from the saved location; good for large datasets. |
| tf | dataset is downloaded from HuggingFace Hub to disk | formats supported in HF Hub: parquet, arrow, json, csv, txt | Reads the whole dataset into memory; works for small datasets. |
| tfrecord | local/Cloud Storage | tfrecord | Data are not loaded into memory but streamed from the saved location; good for large datasets. |

## Usage examples

### HuggingFace Streaming (dataset_type=hf)
#### Example config for streaming from HuggingFace Hub (no download needed):
```
dataset_type: hf
dataset_name: BleachNick/UltraEdit_500k # for using https://huggingface.co/datasets/BleachNick/UltraEdit_500k
image_column: source_image
caption_column: source_caption
train_split: FreeForm
hf_access_token: '' # provide token if using a gated dataset or tokenizer
```
#### Example config for streaming from downloaded data in a GCS bucket:
```
dataset_type: hf
dataset_name: parquet # or json, arrow, etc.
hf_train_files: gs://<bucket>/<folder>/*-train-*.parquet # match the train files
```
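Under the hood, these config values are passed to `datasets.load_dataset`. Below is a minimal standalone sketch of the equivalent call, assuming `gcsfs` is installed so `gs://` paths resolve; the bucket path is a placeholder:
```
from datasets import load_dataset

# Minimal sketch of what the config above maps to: dataset_name selects the
# builder ("parquet") and hf_train_files selects the files to stream.
# The bucket path is a placeholder; gcsfs must be installed for gs:// access.
ds = load_dataset(
    "parquet",
    data_files="gs://<bucket>/<folder>/*-train-*.parquet",
    split="train",
    streaming=True,  # records are streamed instead of downloading the whole dataset
)
print(next(iter(ds)))  # fetches only the first record
```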

### tf.data in-memory dataset (dataset_type=tf)
#### Example config
```
dataset_type: tf
dataset_name: diffusers/pokemon-gpt4-captions # will download https://huggingface.co/datasets/diffusers/pokemon-gpt4-captions
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
# If cache_latents_text_encoder_outputs=True, the VAE is applied to images and the text is encoded
# while the dataset is downloaded, so the saved dataset contains latents and text encoder outputs.
cache_latents_text_encoder_outputs: True
```
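Conceptually, the tf pipeline downloads the dataset, optionally pre-computes image latents and text encoder outputs, and writes the transformed dataset to `dataset_save_location`. The sketch below illustrates only the download-then-cache idea with the `datasets` library; the `add_caption_length` transform is a placeholder, not MaxDiffusion's actual VAE / text-encoder step:
```
from datasets import load_dataset

# Illustration of the download-then-cache idea only; the transform is a
# placeholder, not MaxDiffusion's actual VAE / text-encoder preprocessing.
ds = load_dataset("diffusers/pokemon-gpt4-captions", split="train")  # downloaded to the local HF cache

def add_caption_length(example):
  # In MaxDiffusion, this step would store image latents and text encoder outputs instead.
  example["caption_length"] = len(example["text"])
  return example

ds = ds.map(add_caption_length)
ds.save_to_disk("/tmp/pokemon-gpt4-captions_cached")  # analogous to dataset_save_location
```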

### tf.data.TFRecordDataset (dataset_type=tfrecord)
#### Example config
```
dataset_type: tfrecord
train_data_dir: gs://<bucket>/<folder> # will use all TFRecord files under the directory
```
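As a rough illustration (not MaxDiffusion's implementation) of how all TFRecord files under a directory can be read with `tf.data`; the bucket path is a placeholder, as in the config above:
```
import tensorflow as tf

# Rough illustration only, not the MaxDiffusion tfrecord pipeline.
train_data_dir = "gs://<bucket>/<folder>"  # placeholder path
filenames = tf.io.gfile.glob(f"{train_data_dir}/*")  # every file under the directory
ds = tf.data.TFRecordDataset(filenames)  # yields serialized tf.train.Example records
for raw_record in ds.take(1):
  print(len(raw_record.numpy()), "bytes in the first record")
```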

## Best Practice
### Multihost Dataloading
In a multihost environment, if you use a streaming input pipeline whose data format only supports sequential reads (in MaxDiffusion, dataset_type in (hf, tfrecord)), the most performant setup is for each data file to be accessed by only one host, with each host reading a subset of the data files (shuffling happens within that subset of files). This requires (# of data files) > (# of hosts loading data). We recommend resharding the dataset if this requirement is not met.
#### HuggingFace pipeline when streaming from Hub
* When (# of data files) >= (# of hosts loading data), files are assigned to hosts as evenly as possible; some hosts may end up with one more file than others (see the sketch after this list). When a host runs out of data, it automatically starts another epoch. Since hosts exhaust their data at different speeds, different hosts begin the next epoch at different times.
* When (# of data files) < (# of hosts loading data), files are read sequentially with multiple hosts accessing each file, and performance can degrade quickly as the number of hosts increases.
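The sketch below illustrates the even file-to-host assignment described above. It is not the MaxDiffusion implementation (the hf pipeline delegates sharding to `datasets.distributed.split_dataset_by_node`); the file paths and the `assign_files_to_hosts` helper are hypothetical:
```
# Hypothetical helper to illustrate even file-to-host assignment; not part of MaxDiffusion.
def assign_files_to_hosts(data_files, num_hosts):
  """Round-robin assignment so each file is read by exactly one host."""
  assignment = {host: [] for host in range(num_hosts)}
  for i, f in enumerate(sorted(data_files)):
    assignment[i % num_hosts].append(f)
  return assignment

files = [f"gs://my-bucket/data/{i:05d}-train.parquet" for i in range(10)]  # hypothetical paths
for host, shard in assign_files_to_hosts(files, num_hosts=4).items():
  print(host, shard)  # hosts 0 and 1 get 3 files each, hosts 2 and 3 get 2 files each
```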

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 jax>=0.4.30
 jaxlib>=0.4.30
+grain-nightly
 google-cloud-storage==2.17.0
 absl-py
 datasets
@@ -22,6 +23,6 @@ tensorflow>=2.17.0
 tensorflow-datasets>=4.9.6
 ruff>=0.1.5,<=0.2
 git+https://github.com/mlperf/logging.git
-opencv-python==4.10.0.84
+opencv-python-headless==4.10.0.84
 orbax-checkpoint>=0.5.20
 tokenizers==0.20.0

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def __init__(self, config, checkpoint_type):
     self.rng = jax.random.PRNGKey(self.config.seed)
     devices_array = max_utils.create_device_mesh(config)
     self.mesh = Mesh(devices_array, self.config.mesh_axes)
-    self.total_train_batch_size = max_utils.get_global_batch_size(self.config)
+    self.total_train_batch_size = self.config.total_train_batch_size

     self.checkpoint_manager = create_orbax_checkpoint_manager(
         self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type

src/maxdiffusion/configs/base14.yml

Lines changed: 11 additions & 6 deletions
@@ -125,11 +125,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time,
+# reducing memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_sd15'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
@@ -147,11 +157,6 @@ checkpoint_every: -1
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False

-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
-
 # Training loop
 learning_rate: 1.e-7
 scale_lr: False

src/maxdiffusion/configs/base21.yml

Lines changed: 11 additions & 1 deletion
@@ -127,11 +127,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time,
+# reducing memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_sd21'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 768

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 11 additions & 6 deletions
@@ -140,11 +140,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time,
+# reducing memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
@@ -162,11 +172,6 @@ checkpoint_every: -1
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False

-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
-
 # Training loop
 learning_rate: 1.e-7
 scale_lr: False

src/maxdiffusion/configs/base_xl.yml

Lines changed: 11 additions & 5 deletions
@@ -128,11 +128,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time,
+# reducing memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 1024
@@ -150,10 +160,6 @@ checkpoint_every: -1
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False

-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
 # Training loop
 learning_rate: 4.e-7
 scale_lr: False

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings
import datasets
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
import grain.python as grain

from maxdiffusion import max_logging
from maxdiffusion import multihost_dataloading


def make_hf_streaming_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
    tokenize_fn=None,
    image_transforms_fn=None,
    hf_batch_factor=4,
):
  """Streams data from the HF Hub or a GCS bucket.
  Nothing is downloaded, regardless of config.cache_latents_text_encoder_outputs."""
  ds = load_dataset(
      config.dataset_name,
      split=config.train_split,
      data_dir=config.hf_data_dir,
      data_files=config.hf_train_files,
      streaming=True,
      token=config.hf_access_token,
  )

  ds = ds.shuffle(seed=config.seed)
  ds = ds.select_columns([config.caption_column, config.image_column])

  if tokenize_fn:
    ds = ds.map(
        function=tokenize_fn,
        batched=True,
        batch_size=hf_batch_factor * config.total_train_batch_size,
        remove_columns=[config.caption_column],
    )

  if image_transforms_fn:
    ds = ds.map(
        function=image_transforms_fn,
        batched=True,
        batch_size=hf_batch_factor * config.total_train_batch_size,
        remove_columns=[config.image_column],
    )

  ds = HFDataSource(
      ds,
      dataloading_host_index,
      dataloading_host_count,
  )
  dummy_index_sampler = grain.IndexSampler(
      num_records=len(ds),
      num_epochs=1,
      shard_options=grain.ShardOptions(
          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
      ),
      shuffle=False,
      seed=0,
  )
  operations = [grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True)]
  dataloader = grain.DataLoader(
      data_source=ds,
      operations=operations,
      sampler=dummy_index_sampler,
      worker_count=1,  # only supports one worker for now, more workers result in duplicated data
      worker_buffer_size=1,
      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=hf_batch_factor * config.total_train_batch_size),
  )
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
  return train_iter


class HFDataSource(grain.RandomAccessDataSource):
  """Makes a HuggingFace IterableDataset usable as a grain data source, without random access support."""

  def __init__(
      self,
      dataset: datasets.IterableDataset,
      dataloading_host_index: int,
      dataloading_host_count: int,
  ):
    self.dataset = dataset
    self.dataloading_host_count = dataloading_host_count
    self.dataloading_host_index = dataloading_host_index
    self.n_shards = dataset.n_shards
    self._check_shard_count()
    self.current_shard = dataloading_host_index
    self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = None

  def _check_shard_count(self):
    if self.n_shards < self.dataloading_host_count:
      warnings.warn(
          f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, "
          "which is smaller than the number of hosts loading data. This is known to lead to inefficient dataloading. "
          "See https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/docs/data_README.md#best-practice"
      )
      self.n_shards = self.dataloading_host_count

  def _update_shard(self):
    new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards
    max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}")
    self.current_shard = new_shard
    self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = iter(self.dataset_shard)

  def __len__(self):
    """Returns the length of the HF dataset. Since a HuggingFace IterableDataset has no length,
    a fake length bigger than the dataset is returned."""
    return 10_000_000_000

  def __getitem__(self, index):
    """A HuggingFace IterableDataset does not support random access by index,
    so the next item in the iterator is returned instead."""
    if not self.data_iter:
      self.data_iter = iter(self.dataset_shard)

    while True:
      try:
        data = next(self.data_iter)
        return data
      except StopIteration:
        self._update_shard()
