From 45b0bcee9b3e98b1efae2919bbf050ff84758e56 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Thu, 21 Nov 2024 19:05:26 -0500
Subject: [PATCH 01/79] Add distributed datasets

---
 .../stateful_dataloader/ibm_rescalable.py     | 604 ++++++++++++++++++
 1 file changed, 604 insertions(+)
 create mode 100644 torchdata/stateful_dataloader/ibm_rescalable.py

diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py
new file mode 100644
index 000000000..f457d9d0b
--- /dev/null
+++ b/torchdata/stateful_dataloader/ibm_rescalable.py
@@ -0,0 +1,604 @@
+import logging
+import math
+import os
+from copy import deepcopy
+from typing import Any, Callable, List
+
+import torch
+import torch.distributed as dist
+import torch.distributed.tensor as dtensor
+import torch.utils.data as data
+
+from .stateful_dataloader import StatefulDataLoader
+
+"""
+The following distributed dataloaders are designed around 4 main principles:
+
+1. Efficient, asynchronous operation. Workers on different devices do not communicate.
+2. Modularity. The data loading pipeline is composed of wrapped iterators: the base iterator
+   loads from disk, and additional layers add levels of post-processing (shuffling,
+   packing, padding, rescaling, etc.).
+3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal
+   state that can be written/read on disk via the implemented recursive `state_dict()` and
+   `load_state_dict()` calls. Any values that should be saved to state can be designated
+   'state_params' and will be automatically included in the state dict. States must be
+   valid targets of torch.tensor().
+4. Rescalability. Users can save and load checkpoints to/from different numbers of workers
+   without losing the global state. This is accomplished by splitting the global state over
+   a predefined large number of small partitions, each of which tracks its own individual
+   state. Rescaling is accomplished by re-distributing these shards over the physical workers.
+
+Our loaders obey the following type hierarchy:
+torch.utils.data.IterableDataset -> _StatefulDataset -> _WrapperDataset.
+`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a
+single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times,
+then applying some sort of post-processing and yielding the result. Users build data processing
+pipelines by wrapping a base `_StatefulDataset` in any number of `_WrapperDataset` layers,
+which is then passed to the torch DataLoader.
+
+It is likely that this can be merged into the existing Nodes structure, but we leave this
+for future work.
+"""
+
+
+def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
+    """
+    Partition itemlist into worldsize chunks, grab the chunk corresponding to rank, and return it.
+    """
+    return itemlist[(rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize]
+
+
+def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
+    """
+    In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items,
+    and return the span including all owned items, fractional or otherwise.
+    """
+    start = math.floor(len(itemlist) * rank / worldsize)
+    end = math.ceil(len(itemlist) * (rank + 1) / worldsize)
+    return itemlist[start:end]
+
+
+class _StatefulDataset(data.IterableDataset):
+    """
+    Stub for stateful datasets, extends data.IterableDataset with state_dict methods.
+    All subclasses should specify the params to be considered stateful via self.state_params.
+    """
+
+    def __init__(
+        self,
+        datapath: str,
+        rank: int,
+        worldsize: int,
+    ):
+        assert rank >= 0, f"Rank {rank} must be a non-negative integer"
+        assert worldsize > rank, f"Worldsize {worldsize} must be greater than rank {rank}"
+        assert datapath is None or (
+            os.path.isdir(datapath) and len(os.listdir(datapath)) > 0
+        ), f"Data path {datapath} must be a non-empty folder or None"
+        self.state_params: List[str] = []
+
+        # Default fields
+        self.datapath = datapath
+        self.rank = rank
+        self.worldsize = worldsize
+        self.local_worldsize = -1
+
+        # Setup / loading flags
+        self.is_setup = False
+
+    def setup(self):
+        """
+        This method should contain all setup depending on datapath or rank.
+        It is called after init, but immediately before any other operation.
+        Certain operations higher up in the pipeline may change rank or datapath
+        after init (for example, wrapping in a subdataset sampler layer, or copying
+        to worker processes), so all rank- and datapath-dependent ops are deferred to
+        this function.
+        Currently, this function simply adjusts rank/worldsize to account for
+        multiprocess dataloaders.
+        """
+        if not self.is_setup:
+            self.is_setup = True
+            # Perform adjustment only if not already adjusted (i.e. via _WrapperDataset)
+            if self.local_worldsize == -1:
+                info = data.get_worker_info()
+                if info is None or info.num_workers == 1:
+                    # No multi-worker rank adjustment needed
+                    self.local_worldsize = 1
+                else:
+                    self.local_worldsize = info.num_workers
+                    self.worldsize = self.worldsize * self.local_worldsize
+                    self.rank = self.local_worldsize * self.rank + info.id
+
+    def statename(self, x: str):
+        # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline
+        return self.__class__.__name__ + "." + x
+
+    def state_dict(self):
+        """
+        Retrieve all state_params (each worker/process produces its own state dict shard).
+        On the off chance that you're saving a checkpoint with zero steps, run setup first.
+        """
+        self.setup()
+        return {self.statename(flag): getattr(self, flag) for flag in self.state_params}
+
+    def load_state_dict(self, state_dict):
+        """
+        Run setup if needed, and apply all applicable state_params from the state_dict.
+        """
+        self.setup()
+        [setattr(self, flag, state_dict[self.statename(flag)]) for flag in self.state_params]
+
+
+class _WrapperDataset(_StatefulDataset):
+    """
+    Stub for nested wrappers of _StatefulDatasets. Extends state fns with recursion.
+    Requires a single instantiated sub-dataset (which may be replicated during setup fn).
+    """
+
+    def __init__(
+        self,
+        dataset: _StatefulDataset,
+    ):
+        self.dataset = dataset
+        # Inherit default flags from sub-dataset
+        super().__init__(self.dataset.datapath, self.dataset.rank, self.dataset.worldsize)
+
+    def setup(self):
+        """
+        Datapath/rank/worldsize percolate upwards recursively during initialization, so
+        now we project any desired changes downward, also recursively.
+        We also project local_worldsize downward to prevent subsequent layers from
+        further inflating the rank/worldsize - we only need to account for multiprocessing once!
+        Any code overriding this function should still include this functionality.
+ """ + if not self.is_setup: + super().setup() + self.dataset.datapath = self.datapath + self.dataset.rank = self.rank + self.dataset.worldsize = self.worldsize + self.dataset.local_worldsize = self.local_worldsize + self.dataset.setup() + + def load_state_dict(self, state_dict): + """ + Sets all specified flags at the current level, then recurses into wrapped dataset. + """ + self.setup() + super().load_state_dict(state_dict) + self.dataset.load_state_dict(state_dict) + + def state_dict(self): + """ + Fetches state dict recursively from wrapped layers, then adds specified flags. + Overlapping flags are overwritten with a warning. + """ + self.setup() + out = self.dataset.state_dict() + state = super().state_dict() + for flag in self.state_params: + if flag in out: + logging.warning( + f"Loader {self.rank}: flag {flag} already present in state_dict with value {out[flag]}. " + + f"Overwriting with value {state[flag]}" + ) + out.update(state) + return out + + +#### ------------------------- DATASET LAYERS ------------------------- #### + + +class PreprocessDataset(_WrapperDataset): + """ + Wrapper for a _StatefulDataset that applies a specified preprocessing + or augmentation function to dataset outputs. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset + aug_fn : function (any -> any) + The augmentation function to apply to each dataset item. + """ + + def __init__( + self, + dataset: _StatefulDataset, + aug_fn: Callable, + ): + super().__init__(dataset) + self.aug_fn = aug_fn + + def __iter__(self): + dataset = iter(self.dataset) + while True: + out = next(dataset) + yield self.aug_fn(out) + + +class SamplingDataset(_WrapperDataset): + """ + A _WrapperDataset implementing percentage-based sampling: weights can be floats, and the + number of tokens seen from each subdataset will match those weights as closely as possible. + This is accomplished by maintaining a _StatefulDataset for each subdataset, and tracking + the number of tokens emitted by each. Whichever loader is furthest from its target will be + the next to pass a document. + Relies on eos token to determine document boundaries, so must sit below BufferDataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects directory to contain subfolders, + which in turn contain shard files. + dataset : _StatefulDataset + Fully instantiated dataset. Cloned across desired subdatasets during setup. + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + datasets : list[str] | None + A list of subdatasets to draw from. If None, draws from all subfolders of datapath. + weights : list(float) | None + Weights describing what percent of emitted tokens should come from each subdataset. + Need not sum to 1. If None, tokens are drawn evenly. + verbose : bool + Track setup progress? 
+ """ + + def __init__( + self, + datapath: str, + dataset: _StatefulDataset, + delimiter_token: Any, + datasets=None, + weights=None, + verbose=False, + ): + super().__init__(dataset) + self.datapath = datapath + self.delimiter = delimiter_token + self.verbose = verbose + self.datasets = ( + datasets + if datasets is not None + else [f for f in os.listdir(datapath) if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f] + ) + assert len(self.datasets) > 0, "You must specify at least one dataset" + + if weights is not None: + assert len(weights) == len( + self.datasets + ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" + for w in weights: + assert w > 0, f"Sampling rate {w} must be positive" + self.weights = [1] * len(self.datasets) if weights is None else weights + self.weights = [w / sum(self.weights) for w in self.weights] + + self.tokens_seen = [0] * len(self.datasets) + + self.current_iterator = -1 + self.state_params = ["tokens_seen", "current_iterator"] + + def setup(self): + if not self.is_setup: + _StatefulDataset.setup(self) + # Build subdataset iterators + self.data = [] + for i, d in enumerate(self.datasets): + self.data.append(deepcopy(self.dataset)) + self.data[-1].datapath = os.path.join(self.datapath, d) + self.data[-1].rank = self.rank + self.data[-1].worldsize = self.worldsize + self.data[-1].local_worldsize = self.local_worldsize + if self.verbose: + logging.info( + f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" + ) + [d.setup() for d in self.data] + + def __iter__(self): + self.setup() + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + if self.current_iterator != -1: + # Finish current document + out = next(data[self.current_iterator]) + self.tokens_seen[self.current_iterator] += len(out) + if out[-1] == self.delimiter: + self.current_iterator = -1 + yield out + else: + # Choose new subdataset to draw from + # (whichever is currently most underrepresented compared to target rate) + offset = [ + self.weights[i] - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) + for i in range(len(self.datasets)) + ] + offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] + self.current_iterator = offset_argmax + + def state_dict(self): + self.setup() + # Manually add state of all subloaders to self state + iterator_states = [d.state_dict() for d in self.data] + assert len(iterator_states) > 0, f"Worker {self.rank} owns no datasets" + # Flip list[dict[any]] to dict[list[any]] + prefix = self.statename("states.") + out = {prefix + k: [d[k] for d in iterator_states] for k in iterator_states[0].keys()} + out.update(_StatefulDataset.state_dict(self)) + return out + + def load_state_dict(self, state_dict): + self.setup() + # Load stats + _StatefulDataset.load_state_dict(self, state_dict) + # Load sub-iterator states + prefix = self.statename("states.") + # Flip dict[list[any]] to list[dict[any]] + iterator_states = [ + {k[k.find(prefix) + len(prefix) :]: v[i] for k, v in state_dict.items() if prefix in k} + for i in range(len(self.data)) + ] + # Load individual state sub-dicts + [self.data[i].load_state_dict(iterator_states[i]) for i in range(len(self.data))] + + +class DummyDataset(_StatefulDataset): + """ + A dummy base dataset for demo purposes. 
+ + Normally this dataset would be responsible for using rank, datapath and worldsize arguments + to perform dataset partitioning, and implement repeating iteration over its particular data shard. + + Spits out random sequences of desired vocab size / seq length as lists. + Places delimiter token at end of each sequence (used by SamplingDataset). + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + delimiter_token: Any, + seed: int = 42, + vocab: int = 100, + seqlen: int = 64, + ): + super().__init__(datapath, rank, worldsize) + self.vocab = vocab + self.seqlen = seqlen + self.delimiter = delimiter_token + # Ensure different seeds across ranks and datasets, for demo purposes + seed = seed + self.rank + len(datapath) * 100 + self.generator = torch.Generator().manual_seed(seed) + self.g_state = None + self.state_params = ["g_state"] + + def __iter__(self): + while True: + out = torch.rand(self.seqlen, generator=self.generator) + out = out.mul(self.vocab).int().tolist() + out[-1] = self.delimiter + yield out + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state().tolist() + return super().state_dict()() + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) + + +class ScalableShardDataset(_WrapperDataset): + """ + A _WrapperDataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of smaller StatefulDatasets, cloned from the + original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. + Rescaling only works when this layer wraps all other layers that contribute to state_dict. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset. Cloned into logical workers during setup fn. + n_logical_shards : int + Total number of logical shards. Must be a multiple of world size. + verbose : bool + Track setup progress? 
+ """ + + def __init__( + self, + dataset: _StatefulDataset, + n_logical_shards: int = 2048, + verbose=False, + ): + super().__init__(dataset) + assert ( + n_logical_shards % self.worldsize == 0 + ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert n_logical_shards > 0, f"n_logical_shards {n_logical_shards} must be a positive integer" + + self.total_shards = n_logical_shards + self.verbose = verbose + + # Fields to be populated during setup / subdataset setup + self.data: List[_StatefulDataset] = [] + self.logicals_owned: List[int] = [] + self.n_logicals = 0 + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = 0 + self.load_worldsize = self.worldsize + + self.state_params = ["current_reader"] # self.data states are handled manually + + def setup(self): + if not self.is_setup: + _StatefulDataset.setup(self) + n_logical_shards = self.total_shards + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + self.n_logicals = n_logical_shards // self.worldsize + assert ( + len(self.logicals_owned) == self.n_logicals + ), "(world size * num workers) does not divide logical shards evenly" + + # Build logical shards + for i in range(self.n_logicals): + self.data.append(deepcopy(self.dataset)) + self.data[-1].worldsize = n_logical_shards + self.data[-1].rank = self.logicals_owned[i] + self.data[-1].local_worldsize = 1 + self.data[-1].datapath = self.datapath + self.data[-1].verbose = self.rank == 0 + if self.verbose: + logging.info( + f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + [d.setup() for d in self.data] + + def __iter__(self): + self.setup() + # Grab one item at a time, iterating over owned logical shards + data = [iter(d) for d in self.data] + while True: + ind = self.current_reader + # Read doc + out = next(data[ind]) + # Update state + self.current_reader = (self.current_reader + 1) % self.n_logicals + yield out + + def state_dict(self): + self.setup() + # Recursive fetch + logical_shard_states = [d.state_dict() for d in self.data] + assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???" + # Flip list[dict[Any]] to dict[list[Any]] + state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()} + state_dict.update(_StatefulDataset.state_dict(self)) + + # Convert to tensor form + out = {} + for k, v in state_dict.items(): + v = torch.tensor(v) + if len(v.shape) == 0: + k = k + ".scalar" + v = v.unsqueeze(0) + out[k] = v + + return out + + def load_state_dict(self, state_dict): + self.setup() + + # Convert back to lists and scalars + def detorchify(k, v): + v = v.tolist() + if ".scalar" in k: + k = k[:-7] + v = v[0] + return k, v + + plain_dict = {} + for k, v in state_dict.items(): + k, v = detorchify(k, v) + plain_dict[k] = v + state_dict = plain_dict + + # Assemble logical shard states + # TODO: how is this handling non-resharding state_params when resharding??? 
+ _StatefulDataset.load_state_dict(self, state_dict) + # Remove all non-resharding state + [state_dict.pop(self.statename(n)) for n in self.state_params] + # Flip dict[list[any]] to list[dict[any]] + logical_shard_states = [{k: v[i] for k, v in state_dict.items()} for i in range(self.n_logicals)] + + # Load values + for i in range(self.n_logicals): + self.data[i].load_state_dict(logical_shard_states[i]) + + +#### ------------------------- CHECKPOINT FUNCTIONS ------------------------- #### + + +def __pop_dstate(state, device_mesh, placements): + """ + Removes worker states from the StatefulDataLoader state dict, and assembles them + into a separate list of dicts for distributed checkpointing. + """ + dstate = state["_snapshot"]["_worker_snapshots"] + dstate = [dstate[f"worker_{i}"].pop("dataset_state") for i in range(len(dstate))] + # Flip list[dict[tensor]] to dict[list[tensor]], and concat + dstate = {k: torch.cat([d[k] for d in dstate], 0) for k in dstate[0]} + # Construct dtensors from tensors + dstate = { + k: dtensor.DTensor.from_local( + v, + device_mesh, + placements, + ) + for k, v in dstate.items() + } + return dstate + + +def save_distributed_state_dict( + loader: StatefulDataLoader, + path: str, + device_mesh=None, +): + rank = loader.dataset.rank + state = deepcopy(loader.state_dict()) + dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) + # Write distributed state dict + writer = dist.checkpoint.FileSystemWriter(path) + dist.checkpoint.save( + dstate, + writer, + ) + # Write nondistributed state dict + torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) + + +def load_distributed_state_dict( + loader: StatefulDataLoader, + path: str, + device_mesh=None, +): + base = loader.state_dict() + nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] + rank = loader.dataset.rank + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) + # Read nondistributed state dict + ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x]) + # Check that number of loaders matches + if ckp_ws == loader.dataset.worldsize: + state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth")) + # Check that number of workers matches + if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: + state = base + else: + # On mismatch, discard saved non-reshardable loader state and start fresh + state = base + # Read distributed state dict + reader = dist.checkpoint.FileSystemReader(path) + dist.checkpoint.load_state_dict( + dstate, + reader, + ) + # Get local tensors from dtensors, and slice over workers + dstate = {k: v.to_local().chunk(nworkers) for k, v in dstate.items()} + # Flip dict[list[tensor]] to list[dict[tensor]] + dstate = [{k: v[i] for k, v in dstate.items()} for i in range(nworkers)] + # Re-insert worker states into loader state + for i in range(nworkers): + state["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"] = dstate[i] + # Load into loader + loader.load_state_dict(state) From e486614c757aee390c57d073e1b14fc341b669f5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 21 Nov 2024 19:09:19 -0500 Subject: [PATCH 02/79] Formatting, commenting --- torchdata/stateful_dataloader/ibm_rescalable.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index f457d9d0b..99328e1d7 100644 --- 
a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -554,6 +554,13 @@ def save_distributed_state_dict( path: str, device_mesh=None, ): + """ + Retrieves dataloader state dict, and separates worker states from loader state. + Loader state is not rescalable, and is saved using normal torch.save. + It is discarded when rescaling. + Rescalable worker states are compiled into a dtensor across ranks, and saved + using pytorch distributed checkpointing. + """ rank = loader.dataset.rank state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) @@ -572,6 +579,13 @@ def load_distributed_state_dict( path: str, device_mesh=None, ): + """ + Retrieves dataloader state dict, and separates worker states from loader state. + If not rescaling, load saved dataloader state. + Rescalable worker states are retrieved using pytorch distributed checkpointing. + States are distributed over workers, and ScalableShardDataset will handle + partitioning and re-assignment of available states into logical ranks. + """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] rank = loader.dataset.rank From 10e45b9ff0e4d65d07dcb940153d708f5fa31416 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 19:58:27 -0500 Subject: [PATCH 03/79] Add demo script --- examples/ibm_rescaling/rescaling_demo.py | 146 +++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 examples/ibm_rescaling/rescaling_demo.py diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py new file mode 100644 index 000000000..8d71b35d7 --- /dev/null +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -0,0 +1,146 @@ +import argparse +import os + +import torch +from torch import distributed as dist + +from torchdata.stateful_dataloader import StatefulDataLoader +from torchdata.stateful_dataloader.ibm_rescalable import ( + DummyDataset, + PreprocessDataset, + SamplingDataset, + ScalableShardDataset, + load_distributed_state_dict, + save_distributed_state_dict, +) + +# This example script validates the rescaling behavior of the ibm rescalable distributed datasets. +# On first run, saves a distributed checkpoint to the desired location. +# On subsequent runs, loads the checkpoint (possibly on a different world size / num workers) +# and verifies that previous data is not revisited, while upcoming data is. + +# Example usage: +# torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6 + + +parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints") +parser.add_argument("--ckpt_path", type=str, default="./rescale_test") +parser.add_argument( + "--logical_shards", + type=int, + default=96, + help="Total number of data partitions. 
(worldsize * n_workers) must divide this evenly.", +) +parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers per device") +parser.add_argument("--b_size", type=int, default=1, help="Number of data points per step per device") +parser.add_argument("--seed", type=int, default=42) + +args = parser.parse_args() + +# Setup +rank = int(os.getenv("RANK", 0)) +world_size = int(os.getenv("WORLD_SIZE", 1)) +dist.init_process_group() +mesh = dist.device_mesh.init_device_mesh("cpu", [world_size]) +placement = [dist.tensor.placement_types.Shard(0)] + +# Build dataloader +data = DummyDataset("not_a_real_datapath", rank, world_size, delimiter_token=-1, seed=args.seed) +# Pretend that we're sampling over multiple sub-datasets +data = SamplingDataset( + "not_a_real_datapath", + data, + delimiter_token=-1, + datasets=["sub_dataset", "second_subdataset", "small_subdataset"], + weights=[12, 17, 5], +) +# Apply rescalability layer +data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) +# Statelessly convert all outputs to tensors +data = PreprocessDataset(data, torch.tensor) +# Wrap in StatefulDataLoader +data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) + +# If checkpoint does not exist, create it +if not os.path.exists(args.ckpt_path) or len(os.listdir(cfg.ckpt_save_path)) == 0: + os.makedirs(args.ckpt_path, exist_ok=True) + # Iterate, assemble values to exclude + if rank == 0: + print("No existing checkpoint. Processing 100 steps.") + + avoid = [] + for i, inp in enumerate(data): + if i == 100: + if rank == 0: + print("Iteration complete!") + save_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + break + avoid.append(inp) + avoid = torch.cat(avoid) + # Get all vals onto each rank + avoid = dist.tensor.DTensor.from_local( + avoid, + mesh, + placement, + ).full_tensor() + + # Continue, assemble values to include + load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + if rank == 0: + print("DCP state loaded!") + + include = [] + for i, inp in enumerate(data): + if i == 10: + break + include.append(inp) + include = torch.cat(include) + if rank == 0: + print("Iteration round 2 complete!") + # Get all vals onto each rank + include = dist.tensor.DTensor.from_local(include, mesh, placement).full_tensor() + + if rank == 0: + torch.save(avoid, os.path.join(args.ckpt_path, "avoid.pth")) + torch.save(include, os.path.join(args.ckpt_path, "include.pth")) + print( + "Generation complete! Please rerun (with different world size / workers if desired) to complete the check." + ) + +# If checkpoint does exist, load and take 100 steps. +# Ensure avoid values are avoided, and all include values are included. 
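+# (The `_in` helper defined below avoids a Python loop: subtracting v from every row
+# of m leaves an all-zero row exactly where v matches, so the product of the signs of
+# the per-row absolute sums is zero iff v is a row of m; logical_not then turns that
+# into the boolean membership result.)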
+else: + if rank == 0: + print("Checkpoint detected!") + load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + + vals = [] + for i, inp in enumerate(data): + if i == 100: + break + vals.append(inp) + vals = torch.cat(vals) + # Get all vals onto each rank + vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor() + + # Perform avoid/include checks on rank 0 only + if rank == 0: + avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) + include = torch.load(os.path.join(args.ckpt_path, "include.pth")) + + def _in(v, m): + # Returns whether vector v is a row of matrix m (both tensors) + return m.sub(v[None]).abs().sum(1).sign().prod().bool().logical_not().item() + + # Avoid check + for i, x in enumerate(avoid.split(1)): + assert not _in(x[0], vals), i + print("Check passed: seen data was not revisited!") + + # Include check + for i, x in enumerate(include.split(1)): + assert _in(x[0], vals), i + print("Check passed: upcoming data appears as expected!") + +dist.barrier() +dist.destroy_process_group() From 10a6f66ec857fb0a2b92d155ddfedad2da9a45d3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:13:34 -0500 Subject: [PATCH 04/79] Datapath None --- examples/ibm_rescaling/rescaling_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 8d71b35d7..bd9885800 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -45,10 +45,10 @@ placement = [dist.tensor.placement_types.Shard(0)] # Build dataloader -data = DummyDataset("not_a_real_datapath", rank, world_size, delimiter_token=-1, seed=args.seed) +data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets data = SamplingDataset( - "not_a_real_datapath", + None, data, delimiter_token=-1, datasets=["sub_dataset", "second_subdataset", "small_subdataset"], From 02818977f358dd48a0ed8ae1ec84c7012e24bfcd Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:20:00 -0500 Subject: [PATCH 05/79] Shift dummydata seeding to setup, dummy path handling --- examples/ibm_rescaling/rescaling_demo.py | 2 +- torchdata/stateful_dataloader/ibm_rescalable.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index bd9885800..5e99068b8 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -48,7 +48,7 @@ data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets data = SamplingDataset( - None, + "", data, delimiter_token=-1, datasets=["sub_dataset", "second_subdataset", "small_subdataset"], diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 99328e1d7..ee8571cac 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -369,11 +369,15 @@ def __init__( self.seqlen = seqlen self.delimiter = delimiter_token # Ensure different seeds across ranks and datasets, for demo purposes - seed = seed + self.rank + len(datapath) * 100 - self.generator = torch.Generator().manual_seed(seed) + self.seed = seed + self.generator = None self.g_state = None self.state_params = ["g_state"] + def 
setup(self): + super().setup() + self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) + def __iter__(self): while True: out = torch.rand(self.seqlen, generator=self.generator) From a175c3c8d7745c4e843bdb9e5f85e55820e58730 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:23:53 -0500 Subject: [PATCH 06/79] Actually create dummy data folders --- examples/ibm_rescaling/rescaling_demo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 5e99068b8..9d05bb3ee 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -47,11 +47,13 @@ # Build dataloader data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets +subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] +[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] data = SamplingDataset( - "", + os.path.join(args.ckpt_path, "data"), data, delimiter_token=-1, - datasets=["sub_dataset", "second_subdataset", "small_subdataset"], + datasets=subdatas, weights=[12, 17, 5], ) # Apply rescalability layer From 957a5bf7af392d63088921e2bb4ac4e788c39a6f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:26:04 -0500 Subject: [PATCH 07/79] Remove cfg ref --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 9d05bb3ee..241bc3848 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -64,7 +64,7 @@ data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) # If checkpoint does not exist, create it -if not os.path.exists(args.ckpt_path) or len(os.listdir(cfg.ckpt_save_path)) == 0: +if not os.path.exists(args.ckpt_path) or len(os.listdir(args.ckpt_path)) == 0: os.makedirs(args.ckpt_path, exist_ok=True) # Iterate, assemble values to exclude if rank == 0: From 2e9bdf09014e2c6a86877279e415b4d9d3a86183 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:28:01 -0500 Subject: [PATCH 08/79] Remove double () call --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index ee8571cac..60c0ec331 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -388,7 +388,7 @@ def __iter__(self): def state_dict(self): # Write generator state manually self.g_state = self.generator.get_state().tolist() - return super().state_dict()() + return super().state_dict() def load_state_dict(self, state_dict): super().load_state_dict(state_dict) From e475eeca02fe3b9889147a1a9b6a8d53c07b9bcf Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:31:25 -0500 Subject: [PATCH 09/79] Fix dist checkpoint import --- torchdata/stateful_dataloader/ibm_rescalable.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 60c0ec331..05641446e 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ 
b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -5,7 +5,7 @@ from typing import Any, Callable, List import torch -import torch.distributed as dist +from torch.distributed import checkpoint import torch.distributed.tensor as dtensor import torch.utils.data as data @@ -569,8 +569,8 @@ def save_distributed_state_dict( state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) # Write distributed state dict - writer = dist.checkpoint.FileSystemWriter(path) - dist.checkpoint.save( + writer = checkpoint.FileSystemWriter(path) + checkpoint.save( dstate, writer, ) @@ -606,8 +606,8 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Read distributed state dict - reader = dist.checkpoint.FileSystemReader(path) - dist.checkpoint.load_state_dict( + reader = checkpoint.FileSystemReader(path) + checkpoint.load_state_dict( dstate, reader, ) From eac8ef61382ff2b701b94a39d237cb0d9571038a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:35:48 -0500 Subject: [PATCH 10/79] Check ckp subfolder existence, not working folder --- examples/ibm_rescaling/rescaling_demo.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 241bc3848..73e63d4c3 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -64,8 +64,9 @@ data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) # If checkpoint does not exist, create it -if not os.path.exists(args.ckpt_path) or len(os.listdir(args.ckpt_path)) == 0: - os.makedirs(args.ckpt_path, exist_ok=True) +ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state") +if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0: + os.makedirs(ckpt_path, exist_ok=True) # Iterate, assemble values to exclude if rank == 0: print("No existing checkpoint. 
Processing 100 steps.") @@ -75,7 +76,7 @@ if i == 100: if rank == 0: print("Iteration complete!") - save_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + save_distributed_state_dict(data, ckpt_path, mesh) break avoid.append(inp) avoid = torch.cat(avoid) @@ -87,7 +88,7 @@ ).full_tensor() # Continue, assemble values to include - load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + load_distributed_state_dict(data, ckpt_path, mesh) if rank == 0: print("DCP state loaded!") @@ -114,7 +115,7 @@ else: if rank == 0: print("Checkpoint detected!") - load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + load_distributed_state_dict(data, ckpt_path, mesh) vals = [] for i, inp in enumerate(data): From afd01699c906d366e5902c2a8cae7f515eeedae2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:43:15 -0500 Subject: [PATCH 11/79] Save vals for checking --- examples/ibm_rescaling/rescaling_demo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 73e63d4c3..1589d5325 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -130,6 +130,7 @@ if rank == 0: avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) include = torch.load(os.path.join(args.ckpt_path, "include.pth")) + torch.save(vals, os.path.join(args.ckpt_path, "vals.pth")) def _in(v, m): # Returns whether vector v is a row of matrix m (both tensors) From 031d67cb4e3fe568419e4b50f1684d986f8d175d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:50:54 -0500 Subject: [PATCH 12/79] Load dummy gen state always --- torchdata/stateful_dataloader/ibm_rescalable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 05641446e..fd62469c6 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -392,9 +392,8 @@ def state_dict(self): def load_state_dict(self, state_dict): super().load_state_dict(state_dict) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) + # Manually set generator state + self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) class ScalableShardDataset(_WrapperDataset): From d9a575bac58080eb8be2db28052734a19dd1b4d4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:58:51 -0500 Subject: [PATCH 13/79] Setup calls in dummy --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index fd62469c6..95ecb4c09 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -379,6 +379,7 @@ def setup(self): self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) def __iter__(self): + self.setup() while True: out = torch.rand(self.seqlen, generator=self.generator) out = out.mul(self.vocab).int().tolist() @@ -386,6 +387,7 @@ def __iter__(self): yield out def state_dict(self): + self.setup() # Write generator state manually self.g_state = self.generator.get_state().tolist() return 
super().state_dict() From 157f90b2c43631f5abe9bd9f498666d959512078 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:19:40 -0500 Subject: [PATCH 14/79] Diag print --- torchdata/stateful_dataloader/ibm_rescalable.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 95ecb4c09..c3695f191 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -48,16 +48,6 @@ def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any return itemlist[(rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize] -def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: - """ - In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, - and return the span including all owned items, fractional or otherwise. - """ - start = math.floor(len(itemlist) * rank / worldsize) - end = math.ceil(len(itemlist) * (rank + 1) / worldsize) - return itemlist[start:end] - - class _StatefulDataset(data.IterableDataset): """ Stub for stateful datasets, extends data.IterableDataset with state_dict methods. @@ -384,6 +374,8 @@ def __iter__(self): out = torch.rand(self.seqlen, generator=self.generator) out = out.mul(self.vocab).int().tolist() out[-1] = self.delimiter + if self.rank==0: + print(out) yield out def state_dict(self): From 91f1b148211a691a5bcea98242d542626a32c572 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:23:14 -0500 Subject: [PATCH 15/79] Remove sampling --- examples/ibm_rescaling/rescaling_demo.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 1589d5325..5af0f51c9 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -49,13 +49,13 @@ # Pretend that we're sampling over multiple sub-datasets subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] [os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] -data = SamplingDataset( - os.path.join(args.ckpt_path, "data"), - data, - delimiter_token=-1, - datasets=subdatas, - weights=[12, 17, 5], -) +# data = SamplingDataset( +# os.path.join(args.ckpt_path, "data"), +# data, +# delimiter_token=-1, +# datasets=subdatas, +# weights=[12, 17, 5], +# ) # Apply rescalability layer data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) # Statelessly convert all outputs to tensors From b3569e34f9337f365ae47bbea0a8c451541af8b1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:28:48 -0500 Subject: [PATCH 16/79] Path in dummy build --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 5af0f51c9..b11e9fd43 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -45,7 +45,7 @@ placement = [dist.tensor.placement_types.Shard(0)] # Build dataloader -data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) +data = DummyDataset("data", rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets subdatas = ["sub_dataset", 
"second_subdataset", "small_subdataset"] [os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] From 0faea8c27a7d704f948aada11d9dcf02f9bd78c4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:31:20 -0500 Subject: [PATCH 17/79] Path in dummy build --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index b11e9fd43..43aa0099f 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -45,7 +45,7 @@ placement = [dist.tensor.placement_types.Shard(0)] # Build dataloader -data = DummyDataset("data", rank, world_size, delimiter_token=-1, seed=args.seed) +data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] [os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] From 0be44e40021831e99e669f4ddb5a2229f800b9ca Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:34:14 -0500 Subject: [PATCH 18/79] Scalable off --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 43aa0099f..ed98fa518 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -57,7 +57,7 @@ # weights=[12, 17, 5], # ) # Apply rescalability layer -data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) +# data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) # Statelessly convert all outputs to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader From c54aed2eba3fc5bdaa1ec914d7724e8cada57660 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:37:28 -0500 Subject: [PATCH 19/79] Build data folder early --- examples/ibm_rescaling/rescaling_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index ed98fa518..a29f86fba 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -43,12 +43,12 @@ dist.init_process_group() mesh = dist.device_mesh.init_device_mesh("cpu", [world_size]) placement = [dist.tensor.placement_types.Shard(0)] +subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] +[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] # Build dataloader data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets -subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] -[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] # data = SamplingDataset( # os.path.join(args.ckpt_path, "data"), # data, From a16ffb17042b6e472cf0ba1844bb7f03a5b6102d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:46:14 -0500 Subject: [PATCH 20/79] Avoid resetting gen each state dict call --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff 
--git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index c3695f191..ea53c92dc 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -366,7 +366,8 @@ def __init__( def setup(self): super().setup() - self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) + if self.generator is None: + self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) def __iter__(self): self.setup() From b645aeaa0e4dc80130d846c065da73565b0165f2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:48:42 -0500 Subject: [PATCH 21/79] Diag print off, all datasets on --- examples/ibm_rescaling/rescaling_demo.py | 16 ++++++++-------- torchdata/stateful_dataloader/ibm_rescalable.py | 2 -- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index a29f86fba..c8b15aaac 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -49,15 +49,15 @@ # Build dataloader data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets -# data = SamplingDataset( -# os.path.join(args.ckpt_path, "data"), -# data, -# delimiter_token=-1, -# datasets=subdatas, -# weights=[12, 17, 5], -# ) +data = SamplingDataset( + os.path.join(args.ckpt_path, "data"), + data, + delimiter_token=-1, + datasets=subdatas, + weights=[12, 17, 5], +) # Apply rescalability layer -# data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) +data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) # Statelessly convert all outputs to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index ea53c92dc..cbe5dd17c 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -375,8 +375,6 @@ def __iter__(self): out = torch.rand(self.seqlen, generator=self.generator) out = out.mul(self.vocab).int().tolist() out[-1] = self.delimiter - if self.rank==0: - print(out) yield out def state_dict(self): From ceffd247f48459a33bb104da6724dd8ea6500652 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:57:27 -0500 Subject: [PATCH 22/79] Stop saving vals --- examples/ibm_rescaling/rescaling_demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index c8b15aaac..2bb4a6bbb 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -130,7 +130,6 @@ if rank == 0: avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) include = torch.load(os.path.join(args.ckpt_path, "include.pth")) - torch.save(vals, os.path.join(args.ckpt_path, "vals.pth")) def _in(v, m): # Returns whether vector v is a row of matrix m (both tensors) From d2eb12ef48b6240eddcf184bf6c83859fd7403ed Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:18:56 -0500 Subject: [PATCH 23/79] Attempt single blob save --- torchdata/stateful_dataloader/ibm_rescalable.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index cbe5dd17c..23338ba81 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -560,14 +560,15 @@ def save_distributed_state_dict( rank = loader.dataset.rank state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) + out = {"state":state, "dstate":dstate} # Write distributed state dict writer = checkpoint.FileSystemWriter(path) checkpoint.save( - dstate, + out, writer, ) - # Write nondistributed state dict - torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) + # # Write nondistributed state dict + # torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) def load_distributed_state_dict( From ada91ec02647a85575b91390d470f4e8fab64a13 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:28:14 -0500 Subject: [PATCH 24/79] Attempt single blob load --- .../stateful_dataloader/ibm_rescalable.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 23338ba81..060395904 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -587,23 +587,24 @@ def load_distributed_state_dict( nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] rank = loader.dataset.rank dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) + inp = {"state":base, "dstate":dstate} # Read nondistributed state dict - ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x]) + ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) + # Read distributed state dict + reader = checkpoint.FileSystemReader(path) + checkpoint.load_state_dict( + inp, + reader, + ) + dstate = inp["dstate"] # Check that number of loaders matches if ckp_ws == loader.dataset.worldsize: - state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth")) # Check that number of workers matches if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: - state = base + state = inp["state"] else: # On mismatch, discard saved non-reshardable loader state and start fresh state = base - # Read distributed state dict - reader = checkpoint.FileSystemReader(path) - checkpoint.load_state_dict( - dstate, - reader, - ) # Get local tensors from dtensors, and slice over workers dstate = {k: v.to_local().chunk(nworkers) for k, v in dstate.items()} # Flip dict[list[tensor]] to list[dict[tensor]] From 9bf8f3d61927e42158955160ce2b818ff383cbb6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:37:56 -0500 Subject: [PATCH 25/79] Prevent loading in place --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 060395904..4d789541f 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -585,9 +585,8 @@ def load_distributed_state_dict( """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - rank = loader.dataset.rank dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) - inp = {"state":base, 
"dstate":dstate} + inp = {"state":deepcopy(base), "dstate":dstate} # Read nondistributed state dict ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) # Read distributed state dict From 934d37b5995e025004dc9692fa4ecc8c45482e41 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:55:39 -0500 Subject: [PATCH 26/79] Cleanup --- torchdata/stateful_dataloader/ibm_rescalable.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 4d789541f..f08f9c6ef 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -557,7 +557,6 @@ def save_distributed_state_dict( Rescalable worker states are compiled into a dtensor across ranks, and saved using pytorch distributed checkpointing. """ - rank = loader.dataset.rank state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) out = {"state":state, "dstate":dstate} @@ -567,8 +566,6 @@ def save_distributed_state_dict( out, writer, ) - # # Write nondistributed state dict - # torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) def load_distributed_state_dict( @@ -587,8 +584,6 @@ def load_distributed_state_dict( nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) inp = {"state":deepcopy(base), "dstate":dstate} - # Read nondistributed state dict - ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) # Read distributed state dict reader = checkpoint.FileSystemReader(path) checkpoint.load_state_dict( @@ -597,6 +592,7 @@ def load_distributed_state_dict( ) dstate = inp["dstate"] # Check that number of loaders matches + ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) if ckp_ws == loader.dataset.worldsize: # Check that number of workers matches if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: From 8d0cfd8451106c54f1e1e015904b5d8f05564253 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 12:13:38 -0500 Subject: [PATCH 27/79] ScalableReader changes --- examples/ibm_rescaling/rescaling_demo.py | 121 ++-- .../stateful_dataloader/ibm_rescalable.py | 666 +++++++++--------- 2 files changed, 401 insertions(+), 386 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 2bb4a6bbb..f42b070fc 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -1,23 +1,23 @@ import argparse +import math import os - +import pyarrow as pa import torch from torch import distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.ibm_rescalable import ( - DummyDataset, + ArrowHandler, PreprocessDataset, - SamplingDataset, - ScalableShardDataset, + ScalableReader, load_distributed_state_dict, save_distributed_state_dict, ) # This example script validates the rescaling behavior of the ibm rescalable distributed datasets. -# On first run, saves a distributed checkpoint to the desired location. +# On first run, creates a dummy dataset and saves a distributed checkpoint at the desired location. # On subsequent runs, loads the checkpoint (possibly on a different world size / num workers) -# and verifies that previous data is not revisited, while upcoming data is. 
+# and verifies that all remaining data is covered by the time the epoch finishes.

# Example usage:
# torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6

@@ -28,36 +28,50 @@ parser.add_argument(
     "--logical_shards",
     type=int,
-    default=96,
-    help="Total number of data partitions. (worldsize * n_workers) must divide this evenly.",
+    default=350,
+    help="Total number of data partitions. Must be at least (worldsize * n_workers) and at most n_docs (1000).",
 )
 parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers per device")
-parser.add_argument("--b_size", type=int, default=1, help="Number of data points per step per device")
+parser.add_argument("--b_size", type=int, default=2, help="Number of data points per step per device")
+parser.add_argument("--n_steps", type=int, default=50, help="Number of steps to take before saving. (n_steps * b_size * worldsize) must be less than the number of items in the epoch (3000)")
 parser.add_argument("--seed", type=int, default=42)
 args = parser.parse_args()
+
 # Setup
 rank = int(os.getenv("RANK", 0))
 world_size = int(os.getenv("WORLD_SIZE", 1))
 dist.init_process_group()
 mesh = dist.device_mesh.init_device_mesh("cpu", [world_size])
 placement = [dist.tensor.placement_types.Shard(0)]
-subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"]
-[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas]
+
+# Check input args
+assert args.logical_shards >= world_size*args.num_workers, f"Logical shards {args.logical_shards} cannot be less than total workers {world_size*args.num_workers}"
+assert args.logical_shards <= 1000, f"Logical shards {args.logical_shards} cannot exceed number of documents 1000"
+assert args.n_steps*args.b_size*world_size < 3000, f"Number of items drawn before saving {args.n_steps*args.b_size*world_size} cannot exceed number of document chunks 3000."
+
+# Build dataset
+datapath = os.path.join(args.ckpt_path, "dataset")
+if not os.path.exists(datapath):
+    os.mkdir(datapath)
+schema = pa.schema([pa.field("tokens", pa.uint32())])
+with pa.ipc.new_file(
+    os.path.join(datapath, "fileshard_1.arrow"), schema
+) as writer:
+    for i in range(500):
+        out = list(range(i * 100, i * 100 + 100))
+        writer.write(pa.record_batch([out], schema=schema))
+
+with pa.ipc.new_file(
+    os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema
+) as writer:
+    for i in range(500):
+        out = list(range(50000 + i * 100, 50000 + i * 100 + 100))
+        writer.write(pa.record_batch([out], schema=schema))

 # Build dataloader
-data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed)
-# Pretend that we're sampling over multiple sub-datasets
-data = SamplingDataset(
-    os.path.join(args.ckpt_path, "data"),
-    data,
-    delimiter_token=-1,
-    datasets=subdatas,
-    weights=[12, 17, 5],
-)
-# Apply rescalability layer
-data = ScalableShardDataset(data, n_logical_shards=args.logical_shards)
+data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards)
 # Statelessly convert all outputs to tensors
 data = PreprocessDataset(data, torch.tensor)
 # Wrap in StatefulDataLoader
@@ -69,16 +83,16 @@ os.makedirs(ckpt_path, exist_ok=True)
 # Iterate, assemble values to exclude
 if rank == 0:
-    print("No existing checkpoint. Processing 100 steps.")
+    print(f"No existing checkpoint. 
Processing {args.n_steps} steps.") avoid = [] for i, inp in enumerate(data): - if i == 100: + if i == args.n_steps: if rank == 0: print("Iteration complete!") save_distributed_state_dict(data, ckpt_path, mesh) break - avoid.append(inp) + avoid.append(inp[:,0]) avoid = torch.cat(avoid) # Get all vals onto each rank avoid = dist.tensor.DTensor.from_local( @@ -87,63 +101,46 @@ placement, ).full_tensor() - # Continue, assemble values to include - load_distributed_state_dict(data, ckpt_path, mesh) - if rank == 0: - print("DCP state loaded!") - - include = [] - for i, inp in enumerate(data): - if i == 10: - break - include.append(inp) - include = torch.cat(include) - if rank == 0: - print("Iteration round 2 complete!") - # Get all vals onto each rank - include = dist.tensor.DTensor.from_local(include, mesh, placement).full_tensor() - if rank == 0: torch.save(avoid, os.path.join(args.ckpt_path, "avoid.pth")) - torch.save(include, os.path.join(args.ckpt_path, "include.pth")) print( "Generation complete! Please rerun (with different world size / workers if desired) to complete the check." ) -# If checkpoint does exist, load and take 100 steps. -# Ensure avoid values are avoided, and all include values are included. +# If checkpoint does exist, load and finish epoch. +# Ensure all expected values are covered once epoch concludes. else: if rank == 0: print("Checkpoint detected!") load_distributed_state_dict(data, ckpt_path, mesh) + avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() + # Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times) vals = [] + n_steps = ( + math.ceil((3000 - len(avoid)) / (world_size * args.num_workers)) + + 2 * math.ceil(1000/args.logical_shards) + ) for i, inp in enumerate(data): - if i == 100: + if i == n_steps: break vals.append(inp) vals = torch.cat(vals) # Get all vals onto each rank vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor() - # Perform avoid/include checks on rank 0 only + # Perform data coverage check on rank 0 only if rank == 0: - avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) - include = torch.load(os.path.join(args.ckpt_path, "include.pth")) - - def _in(v, m): - # Returns whether vector v is a row of matrix m (both tensors) - return m.sub(v[None]).abs().sum(1).sign().prod().bool().logical_not().item() - - # Avoid check - for i, x in enumerate(avoid.split(1)): - assert not _in(x[0], vals), i - print("Check passed: seen data was not revisited!") - - # Include check - for i, x in enumerate(include.split(1)): - assert _in(x[0], vals), i - print("Check passed: upcoming data appears as expected!") + # Invert avoid to get expected vals + expect = [] + for i in range(1000): + for offset in [0,40,80]: + if i*100+offset not in avoid: + expect.append(i*100+offset) + + for x in expect: + assert x in vals, x + print("Check passed: upcoming data is covered as expected!") dist.barrier() dist.destroy_process_group() diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index f08f9c6ef..31c8cbd45 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -1,17 +1,21 @@ import logging import math import os +import pyarrow as pa from copy import deepcopy -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional, Set import torch from torch.distributed import checkpoint import torch.distributed.tensor as dtensor +import 
torch.distributed as dist
 import torch.utils.data as data

 from .stateful_dataloader import StatefulDataLoader

 """
+TODO: UPDATE THIS FOR SCALABLEREADER
+
 The following distributed dataloaders are designed around 3 main principles:

 1. Efficient, asynchronous operation. Workers on different devices do not communicate.
@@ -41,13 +45,6 @@
 """


-def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
-    """
-    Partition itemlist into worldsize chunks, grab chunk corresponding to rank and return.
-    """
-    return itemlist[(rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize]
-
-
 class _StatefulDataset(data.IterableDataset):
     """
     Stub for stateful datasets, extends data.IterableDataset with state_dict methods.
@@ -176,6 +173,94 @@ def state_dict(self):
         return out
+#### ------------------------- FILE READERS ------------------------- ####
+
+
+class _ShardFileHandler:
+    """
+    Stub for shard file readers of different formats.
+    Must implement open, length, indexing, and slicing functions.
+    """
+
+    def is_legal(self, filepath: str):
+        """
+        Given a file path, determine if it qualifies for this handler.
+        Ideally does not involve opening the file.
+        """
+        return os.path.isfile(filepath)
+
+    def open(self, path: str):
+        """
+        Open the file, to be indexed via self.get() method.
+        Avoid reading entire multi-Gb files when possible!
+        """
+        raise NotImplementedError
+
+    def length(self, path: str):
+        """
+        Calculate the number of documents in the given file.
+        Avoid reading entire multi-Gb files when possible!
+        """
+        raise NotImplementedError
+
+    def get(self, reader, index: int, drop_tokens: Set):
+        """
+        Given the output of self.open() and an index, return the document at that index.
+        Then, remove the first and/or last items if they appear in drop_tokens.
+        Try to avoid reading entire documents at a time in case of long documents,
+        but this is less important than avoiding reading entire files as above.
+        Output must support len() method.
+        """
+        raise NotImplementedError
+
+    def slice(self, doc, index: int, n_pull: int) -> List:
+        """
+        Given a long document, retrieve n_pull consecutive items starting from index.
+        Again, try to be memory-efficient when doing so, but efficiency in self.get()
+        and self.open() is far more important.
+        Must return a python list.
+        """
+        raise NotImplementedError
+
+
+class ArrowHandler(_ShardFileHandler):
+    """
+    Reader for indexable, pre-tokenized PyArrow shard files.
+    PyArrow shard files are expected to hold multiple RecordBatches,
+    where each RecordBatch has a "tokens" field consisting of
+    a single token list (i.e. each document is a single sequence
+    under a "tokens" field, and the file is a list of such sequences).
+
+    A preferred format, as it allows loading document chunks without ever pulling
+    the entire document or shard file, enabling graceful handling of long documents.
+    It is a non-standard data format, though.
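+
+    (For a concrete example of writing files in this format, see the dataset
+    construction block in examples/ibm_rescaling/rescaling_demo.py, which builds
+    shard files via pa.ipc.new_file with a single uint32 "tokens" field.)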
+ """ + + def __init__(self, col_name: str = "tokens"): + self.col_name = col_name + + def is_legal(self, filepath: str): + return "arrow" in os.path.splitext(filepath)[1] + + def open(self, path: str): + return pa.ipc.open_file(pa.memory_map(path)) + + def length(self, path: str): + return self.open(path).num_record_batches + + def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): + doc = reader.get_batch(index)[self.col_name] + if len(doc) > 0 and doc[0].as_py() in drop_tokens: + doc = doc.slice(1, len(doc) - 1) + # Recheck len for edge case where doc=[eos] + if len(doc) > 0 and doc[-1].as_py() in drop_tokens: + doc = doc.slice(0, len(doc) - 1) + return doc + + def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: + return doc.slice(index, n_pull).to_pylist() + + #### ------------------------- DATASET LAYERS ------------------------- #### @@ -207,358 +292,295 @@ def __iter__(self): yield self.aug_fn(out) -class SamplingDataset(_WrapperDataset): +class ScalableReader(_StatefulDataset): """ - A _WrapperDataset implementing percentage-based sampling: weights can be floats, and the - number of tokens seen from each subdataset will match those weights as closely as possible. - This is accomplished by maintaining a _StatefulDataset for each subdataset, and tracking - the number of tokens emitted by each. Whichever loader is furthest from its target will be - the next to pass a document. - Relies on eos token to determine document boundaries, so must sit below BufferDataset. - ... - Args - ---- - datapath : str - Absolute path to the dataset directory. Expects directory to contain subfolders, - which in turn contain shard files. - dataset : _StatefulDataset - Fully instantiated dataset. Cloned across desired subdatasets during setup. - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - datasets : list[str] | None - A list of subdatasets to draw from. If None, draws from all subfolders of datapath. - weights : list(float) | None - Weights describing what percent of emitted tokens should come from each subdataset. - Need not sum to 1. If None, tokens are drawn evenly. - verbose : bool - Track setup progress? + Maintains shared logical shards but opens them one at a time. Completely repartitions + unseen shards only when rescaling. 
""" def __init__( - self, - datapath: str, - dataset: _StatefulDataset, - delimiter_token: Any, - datasets=None, - weights=None, - verbose=False, - ): - super().__init__(dataset) - self.datapath = datapath - self.delimiter = delimiter_token - self.verbose = verbose - self.datasets = ( - datasets - if datasets is not None - else [f for f in os.listdir(datapath) if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f] - ) - assert len(self.datasets) > 0, "You must specify at least one dataset" - - if weights is not None: - assert len(weights) == len( - self.datasets - ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" - for w in weights: - assert w > 0, f"Sampling rate {w} must be positive" - self.weights = [1] * len(self.datasets) if weights is None else weights - self.weights = [w / sum(self.weights) for w in self.weights] - - self.tokens_seen = [0] * len(self.datasets) - - self.current_iterator = -1 - self.state_params = ["tokens_seen", "current_iterator"] - - def setup(self): - if not self.is_setup: - _StatefulDataset.setup(self) - # Build subdataset iterators - self.data = [] - for i, d in enumerate(self.datasets): - self.data.append(deepcopy(self.dataset)) - self.data[-1].datapath = os.path.join(self.datapath, d) - self.data[-1].rank = self.rank - self.data[-1].worldsize = self.worldsize - self.data[-1].local_worldsize = self.local_worldsize - if self.verbose: - logging.info( - f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" - ) - [d.setup() for d in self.data] - - def __iter__(self): - self.setup() - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - if self.current_iterator != -1: - # Finish current document - out = next(data[self.current_iterator]) - self.tokens_seen[self.current_iterator] += len(out) - if out[-1] == self.delimiter: - self.current_iterator = -1 - yield out - else: - # Choose new subdataset to draw from - # (whichever is currently most underrepresented compared to target rate) - offset = [ - self.weights[i] - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) - for i in range(len(self.datasets)) - ] - offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] - self.current_iterator = offset_argmax - - def state_dict(self): - self.setup() - # Manually add state of all subloaders to self state - iterator_states = [d.state_dict() for d in self.data] - assert len(iterator_states) > 0, f"Worker {self.rank} owns no datasets" - # Flip list[dict[any]] to dict[list[any]] - prefix = self.statename("states.") - out = {prefix + k: [d[k] for d in iterator_states] for k in iterator_states[0].keys()} - out.update(_StatefulDataset.state_dict(self)) - return out - - def load_state_dict(self, state_dict): - self.setup() - # Load stats - _StatefulDataset.load_state_dict(self, state_dict) - # Load sub-iterator states - prefix = self.statename("states.") - # Flip dict[list[any]] to list[dict[any]] - iterator_states = [ - {k[k.find(prefix) + len(prefix) :]: v[i] for k, v in state_dict.items() if prefix in k} - for i in range(len(self.data)) - ] - # Load individual state sub-dicts - [self.data[i].load_state_dict(iterator_states[i]) for i in range(len(self.data))] - - -class DummyDataset(_StatefulDataset): - """ - A dummy base dataset for demo purposes. 
- - Normally this dataset would be responsible for using rank, datapath and worldsize arguments - to perform dataset partitioning, and implement repeating iteration over its particular data shard. - - Spits out random sequences of desired vocab size / seq length as lists. - Places delimiter token at end of each sequence (used by SamplingDataset). - """ - - def __init__( - self, - datapath: str, - rank: int, + self, + datapath: str, + rank: int, worldsize: int, + filehandler: _ShardFileHandler, delimiter_token: Any, + bos_token: Optional[Any] = None, + strip_tokens: Optional[Set[Any]] = set(), seed: int = 42, - vocab: int = 100, - seqlen: int = 64, + min_length: int = 1, + max_chunksize: int = 1024, + n_logical_shards: int = 30720, + verbose: bool = False, ): super().__init__(datapath, rank, worldsize) - self.vocab = vocab - self.seqlen = seqlen - self.delimiter = delimiter_token - # Ensure different seeds across ranks and datasets, for demo purposes self.seed = seed - self.generator = None - self.g_state = None - self.state_params = ["g_state"] - - def setup(self): - super().setup() - if self.generator is None: - self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) - - def __iter__(self): - self.setup() - while True: - out = torch.rand(self.seqlen, generator=self.generator) - out = out.mul(self.vocab).int().tolist() - out[-1] = self.delimiter - yield out - - def state_dict(self): - self.setup() - # Write generator state manually - self.g_state = self.generator.get_state().tolist() - return super().state_dict() - - def load_state_dict(self, state_dict): - super().load_state_dict(state_dict) - # Manually set generator state - self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) - - -class ScalableShardDataset(_WrapperDataset): - """ - A _WrapperDataset implementing rescalability: loading from checkpoint into a different - number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of smaller StatefulDatasets, cloned from the - original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. - Rescaling only works when this layer wraps all other layers that contribute to state_dict. - ... - Args - ---- - dataset : _StatefulDataset - Fully instantiated dataset. Cloned into logical workers during setup fn. - n_logical_shards : int - Total number of logical shards. Must be a multiple of world size. - verbose : bool - Track setup progress? 
- """ - - def __init__( - self, - dataset: _StatefulDataset, - n_logical_shards: int = 2048, - verbose=False, - ): - super().__init__(dataset) - assert ( - n_logical_shards % self.worldsize == 0 - ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" - assert n_logical_shards > 0, f"n_logical_shards {n_logical_shards} must be a positive integer" - - self.total_shards = n_logical_shards + self.datapath = datapath + self.filehandler = filehandler() + self.min_length = min_length + assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" + self.chunksize = max_chunksize + self.eos = delimiter_token + self.bos = bos_token + self.drop = strip_tokens + self.n_logical_shards = n_logical_shards self.verbose = verbose + + # Position + self.reader = None + self.cur_file = None - # Fields to be populated during setup / subdataset setup - self.data: List[_StatefulDataset] = [] - self.logicals_owned: List[int] = [] - self.n_logicals = 0 - - # Position "state", used only for maintaining order when n_workers is unchanged - # For scaling up or down, logical position is meaningless, and reset - self.current_reader = 0 - self.load_worldsize = self.worldsize - - self.state_params = ["current_reader"] # self.data states are handled manually + # Setup flags + self.is_setup = False + self.filesizes = None # [[filenames], [filesizes]] CONSTRUCTED PRE ITER IF NOT LOADED + self.shard_states = None # shardid, file pos, doc pos, chunk pos, epoch RESHARD + + # TODO: add handling to prevent zero-length allocations + + def _get_shard_breakdown(self, rank, nshards): + # Find highest fileid still smaller than start + sizelist = torch.tensor(self.filesizes[1]) + sizelist = sizelist/sizelist.float().mean() + cum_sizelist = sizelist.cumsum(0) + start_frac = rank/nshards*len(sizelist) + start_id = len(sizelist) - cum_sizelist.gt(start_frac).sum().item() + # For each doc, assign relevant fractional ownership + start = start_frac + end = (rank+1)/nshards*len(sizelist) + my_files = [] # fileid, start%, end% + for i, (size, cumsize_incl) in enumerate( + zip(sizelist[start_id:].tolist(), cum_sizelist[start_id:].tolist()) + ): + id = start_id + i + cumsize = cumsize_incl - size + if cumsize > end: + break + elif cumsize <= end and cumsize_incl >= start: + my_files.append([ + id, + min(max((start - cumsize) / size, 0), 1), + min(max((end - cumsize) / size, 0), 1), + ]) + return my_files def setup(self): if not self.is_setup: - _StatefulDataset.setup(self) - n_logical_shards = self.total_shards - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - self.n_logicals = n_logical_shards // self.worldsize - assert ( - len(self.logicals_owned) == self.n_logicals - ), "(world size * num workers) does not divide logical shards evenly" - - # Build logical shards - for i in range(self.n_logicals): - self.data.append(deepcopy(self.dataset)) - self.data[-1].worldsize = n_logical_shards - self.data[-1].rank = self.logicals_owned[i] - self.data[-1].local_worldsize = 1 - self.data[-1].datapath = self.datapath - self.data[-1].verbose = self.rank == 0 - if self.verbose: - logging.info( - f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - [d.setup() for d in self.data] + # Get your adjusted rank and worldsize + super().setup() + # Get logical shard partitions + my_shards = list(range( + (self.n_logical_shards * self.rank) // self.worldsize, + (self.n_logical_shards * 
(self.rank + 1)) // self.worldsize,
+            ))
+
+            # Set up logical shard states (may be overwritten later by ckp load)
+            self.shard_states = torch.zeros(math.ceil(self.n_logical_shards / self.worldsize), 5, dtype=torch.int)
+            self.shard_states[:len(my_shards), 0] = torch.tensor(my_shards)
+            self.shard_states[len(my_shards):, 0] = -1
+            self.shard_states[len(my_shards):, 4] = torch.iinfo(torch.int).max
+
+    def _pre_iter(self):
+        # Run after loading checkpoint, before iterating
+
+        # Assemble set of available shard files, if nonexistent
+        if self.filesizes is None:
+            # Find all legal files
+            shards = [
+                [os.path.join(root,name)[len(self.datapath)+1:], os.path.getsize(os.path.join(root, name))]
+                for root, dirs, files in os.walk(self.datapath, topdown=False)
+                for name in files
+                if self.filehandler.is_legal(os.path.join(root, name))
+            ]
+            shards.sort()
+            # Flip list of (shard,size) tuples into (shardlist,sizelist)
+            self.filesizes = list(zip(*shards))
+
+    def _get_reader(self, fileid, reader, ndocs):
+        """
+        If new fileid does not match the current one, open a new reader on
+        the corresponding filepath. Also return the number of docs in the file.
+        """
+        if self.cur_file == fileid:
+            return reader, ndocs
+        else:
+            self.cur_file = fileid
+            filepath = os.path.join(self.datapath, self.filesizes[0][fileid])
+            return self.filehandler.open(filepath), self.filehandler.length(filepath)
+
+    def _construct_chunk(self, j, doc, n_chunks):
+        """
+        Grab a chunk of the desired size from the document, with eos/bos handling
+        """
+        start_index = j * self.chunksize
+        n_pull = self.chunksize
+        if self.bos is not None:
+            if j == 0:
+                n_pull -= 1
+            else:
+                start_index -= 1
+        chunk = self.filehandler.slice(doc, start_index, n_pull)
+        # Add bos/eos tokens if needed
+        if self.bos is not None and j == 0:
+            chunk = [self.bos] + chunk
+        if j == n_chunks - 1:
+            chunk = chunk + [self.eos]
+        return chunk
+
     def __iter__(self):
-        self.setup()
-        # Grab one item at a time, iterating over owned logical shards
-        data = [iter(d) for d in self.data]
+        if not self.is_setup:
+            self.setup()
+        self._pre_iter()
+        reader = None
+        ndocs = -1
         while True:
-            ind = self.current_reader
-            # Read doc
-            out = next(data[ind])
-            # Update state
-            self.current_reader = (self.current_reader + 1) % self.n_logicals
-            yield out
+            # Isolate undervisited shards
+            epoch_count = self.shard_states[:,4].min().item()
+            shardset = self.shard_states[:,4].eq(epoch_count).nonzero().squeeze(-1)
+            for i in shardset:
+                shardid = self.shard_states[i][0].item()
+                files = self._get_shard_breakdown(shardid, self.n_logical_shards) # list([fileid, start%, end%])
+                file_offset = self.shard_states[i][1].item()
+                for file_pos in range(file_offset, len(files)):
+                    # Update position
+                    self.shard_states[i][1] = file_pos
+                    # Calculate doc range
+                    file = files[file_pos]
+                    fileid = file[0]
+                    reader, ndocs = self._get_reader(fileid, reader, ndocs)
+                    doc_start = round(ndocs * file[1])
+                    doc_end = round(ndocs * file[2])
+                    doc_offset = self.shard_states[i][2].item()
+                    for doc_pos in range(doc_offset, doc_end - doc_start):
+                        # Update position
+                        self.shard_states[i][2] = doc_pos
+                        # Fetch doc
+                        doc = self.filehandler.get(reader, doc_start + doc_pos, self.drop)
+                        doclen = len(doc)
+                        nchunks = math.ceil(doclen/self.chunksize)
+                        chunk_offset = self.shard_states[i][3].item()
+                        for chunk_pos in range(chunk_offset, nchunks):
+                            # Update position
+                            self.shard_states[i][3] = chunk_pos+1
+                            # Yield chunk
+                            yield torch.tensor(self._construct_chunk(chunk_pos, doc, nchunks)) # TODO: REMOVE TENSOR CALL!!!
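+                            # (Positions are committed before the yield: chunk pos is stored
+                            # as chunk_pos+1 above, so a state_dict captured at this point
+                            # resumes from the first unvisited chunk.)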
+ # Reset chunk_pos after finishing doc + self.shard_states[i][3] = 0 + # Reset doc_pos after finishing file + self.shard_states[i][2] = 0 + # Reset file_pos after finishing shard + self.shard_states[i][1] = 0 + # Increase epoch count after finishing shard + self.shard_states[i][4] += 1 + # Begin new epoch def state_dict(self): self.setup() - # Recursive fetch - logical_shard_states = [d.state_dict() for d in self.data] - assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???" - # Flip list[dict[Any]] to dict[list[Any]] - state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()} - state_dict.update(_StatefulDataset.state_dict(self)) - - # Convert to tensor form - out = {} - for k, v in state_dict.items(): - v = torch.tensor(v) - if len(v.shape) == 0: - k = k + ".scalar" - v = v.unsqueeze(0) - out[k] = v - - return out - + # Values to save: shard states (shard/repl), filesizes (single/repl) + # Deepcopy required to prevent in-place modification from later prefetches + return deepcopy({ + self.statename("shard_states"): self.shard_states.unsqueeze(0), + self.statename("file_info"): self.filesizes if self.rank == 0 else None + }) + def load_state_dict(self, state_dict): self.setup() - - # Convert back to lists and scalars - def detorchify(k, v): - v = v.tolist() - if ".scalar" in k: - k = k[:-7] - v = v[0] - return k, v - - plain_dict = {} - for k, v in state_dict.items(): - k, v = detorchify(k, v) - plain_dict[k] = v - state_dict = plain_dict - - # Assemble logical shard states - # TODO: how is this handling non-resharding state_params when resharding??? - _StatefulDataset.load_state_dict(self, state_dict) - # Remove all non-resharding state - [state_dict.pop(self.statename(n)) for n in self.state_params] - # Flip dict[list[any]] to list[dict[any]] - logical_shard_states = [{k: v[i] for k, v in state_dict.items()} for i in range(self.n_logicals)] - - # Load values - for i in range(self.n_logicals): - self.data[i].load_state_dict(logical_shard_states[i]) + # Load back shard states (global), filesizes (all) + shard_states = state_dict[self.statename("shard_states")] + file_info = state_dict[self.statename("file_info")] + if shard_states.size(0) == self.worldsize: + self.filesizes = file_info + self.shard_states = shard_states[self.rank] + else: + shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5 + shard_states = torch.cat(shard_states, dim=0) # wn 5 + # Sort shards by epoch count + sorted, indices = torch.sort(shard_states[:,4], descending=True, stable=True) + shard_states = shard_states[indices] + # Strip out dummy shards + n_dummies = sorted.eq(torch.iinfo(torch.int).max).sum() + shard_states = shard_states[n_dummies:] # n_logical 5 + assert len(shard_states) == self.n_logical_shards, f"Number of shards {len(shard_states)} does not match specified {self.n_logical_shards}" + sorted = sorted[n_dummies:] + # Split into max and non-max epochs + n_complete = sorted.eq(sorted[0]).sum() + completed_shards = shard_states[:n_complete] + incomplete_shards = shard_states[n_complete:] + # Allocate completed shards + completed_shards = [ + completed_shards[ + round(i*len(completed_shards)/self.worldsize): + round((i+1)*len(completed_shards)/self.worldsize) + ] for i in range(self.worldsize) + ] + # Sort completed shards by length + completed_shards.sort(key=len) + # Allocate incomplete shards + incomplete_shards = [ + incomplete_shards[ + round(i*len(incomplete_shards)/self.worldsize): + 
round((i+1)*len(incomplete_shards)/self.worldsize) + ] for i in range(self.worldsize) + ] + # Reverse sort incomplete shards by length + incomplete_shards.sort(key=len, reverse=True) + # Pull out shard allocation for this worker + # (sort/reverse-sort ensures allocations are off by no more than 1) + shard_states = torch.cat([ + completed_shards[self.rank], + incomplete_shards[self.rank] + ]) + # Order shards by global ID (for steady file progression) + _, indices = shard_states[:,0].sort() + self.shard_states[:len(shard_states)] = shard_states[indices] + # Pad out with dummy shards if needed + self.shard_states[len(shard_states):,0] = -1 + self.shard_states[len(shard_states):,4] = torch.iinfo(torch.int).max + return None #### ------------------------- CHECKPOINT FUNCTIONS ------------------------- #### -def __pop_dstate(state, device_mesh, placements): +def __pop_dstate(state, device_mesh, placements, create_dtensor=False): """ Removes worker states from the StatefulDataLoader state dict, and assembles them - into a separate list of dicts for distributed checkpointing. + into a separate list of dicts of dtensors for distributed checkpointing. """ dstate = state["_snapshot"]["_worker_snapshots"] dstate = [dstate[f"worker_{i}"].pop("dataset_state") for i in range(len(dstate))] # Flip list[dict[tensor]] to dict[list[tensor]], and concat - dstate = {k: torch.cat([d[k] for d in dstate], 0) for k in dstate[0]} - # Construct dtensors from tensors - dstate = { - k: dtensor.DTensor.from_local( - v, + shardstate = "ScalableReader.shard_states" + fileinfo = "ScalableReader.file_info" + dstate_dict = { + shardstate: torch.cat([d[shardstate] for d in dstate], 0) + } + if create_dtensor == True: + dstate_dict[shardstate] = dtensor.DTensor.from_local( + dstate_dict[shardstate], device_mesh, placements, ) - for k, v in dstate.items() - } - return dstate + dstate_dict[fileinfo] = dstate[0][fileinfo] + return dstate_dict def save_distributed_state_dict( loader: StatefulDataLoader, path: str, - device_mesh=None, + device_mesh: dist.DeviceMesh, ): """ Retrieves dataloader state dict, and separates worker states from loader state. - Loader state is not rescalable, and is saved using normal torch.save. - It is discarded when rescaling. + Loader state is not rescalable, and is discarded when rescaling. Rescalable worker states are compiled into a dtensor across ranks, and saved using pytorch distributed checkpointing. """ state = deepcopy(loader.state_dict()) - dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) + dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)], True) + # Prune empty fileinfos + if dstate["ScalableReader.file_info"] is None: + dstate.pop("ScalableReader.file_info") out = {"state":state, "dstate":dstate} # Write distributed state dict writer = checkpoint.FileSystemWriter(path) @@ -571,18 +593,18 @@ def save_distributed_state_dict( def load_distributed_state_dict( loader: StatefulDataLoader, path: str, - device_mesh=None, + device_mesh: dist.DeviceMesh, ): """ Retrieves dataloader state dict, and separates worker states from loader state. If not rescaling, load saved dataloader state. Rescalable worker states are retrieved using pytorch distributed checkpointing. - States are distributed over workers, and ScalableShardDataset will handle + States are replicated over workers, and ScalableReader will handle partitioning and re-assignment of available states into logical ranks. 
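+    If the save-time world size or worker count differs from the current one, the
+    non-reshardable loader state is discarded and iteration resumes from the
+    redistributed shard states alone.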
""" base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate(0)], True) inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path) @@ -591,19 +613,15 @@ def load_distributed_state_dict( reader, ) dstate = inp["dstate"] - # Check that number of loaders matches - ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) - if ckp_ws == loader.dataset.worldsize: - # Check that number of workers matches - if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: - state = inp["state"] + # Check that number of workers matches + ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.lisdtdir(path) if "loader" in x]) + if ckp_ws == loader.dataset.worldsize and nworkers == state["_snapshot"]["_main_snapshot"]["_num_workers"]: + state = inp["state"] else: # On mismatch, discard saved non-reshardable loader state and start fresh state = base - # Get local tensors from dtensors, and slice over workers - dstate = {k: v.to_local().chunk(nworkers) for k, v in dstate.items()} - # Flip dict[list[tensor]] to list[dict[tensor]] - dstate = [{k: v[i] for k, v in dstate.items()} for i in range(nworkers)] + # Repeat global tensor over all workers + dstate = [inp["dstate"],]*nworkers # Re-insert worker states into loader state for i in range(nworkers): state["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"] = dstate[i] From e633e6003c44d8a15bb2c66a5234af38d14363fc Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 13:29:45 -0500 Subject: [PATCH 28/79] Fix datapath folder creation --- examples/ibm_rescaling/rescaling_demo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index f42b070fc..81eddf654 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -22,6 +22,7 @@ # Example usage: # torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6 +# Do not change the batch size or number of steps between the first and second runs! 
parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints") parser.add_argument("--ckpt_path", type=str, default="./rescale_test") @@ -54,7 +55,7 @@ # Build dataset datapath = os.path.join(args.ckpt_path, "dataset") if not os.path.exists(datapath): - os.mkdir(datapath) + os.makedirs(datapath) schema = pa.schema([pa.field("tokens", pa.uint32())]) with pa.ipc.new_file( os.path.join(datapath, "fileshard_1.arrow"), schema From 1f2e37aa9f9399539fbd6a84396832d7ebfd58b3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 13:35:30 -0500 Subject: [PATCH 29/79] Create datapath subfolder, data only when nonexistent --- examples/ibm_rescaling/rescaling_demo.py | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 81eddf654..5960d64ce 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -56,20 +56,20 @@ datapath = os.path.join(args.ckpt_path, "dataset") if not os.path.exists(datapath): os.makedirs(datapath) -schema = pa.schema([pa.field("tokens", pa.uint32())]) -with pa.ipc.new_file( - os.path.join(datapath, "fileshard_1.arrow"), schema -) as writer: - for i in range(500): - out = list(range(i * 100, i * 100 + 100)) - writer.write(pa.record_batch([out], schema=schema)) - -with pa.ipc.new_file( - os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema -) as writer: - for i in range(500): - out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) - writer.write(pa.record_batch([out], schema=schema)) + schema = pa.schema([pa.field("tokens", pa.uint32())]) + with pa.ipc.new_file( + os.path.join(datapath, "fileshard_1.arrow"), schema + ) as writer: + for i in range(500): + out = list(range(i * 100, i * 100 + 100)) + writer.write(pa.record_batch([out], schema=schema)) + os.makedirs(os.path.join(datapath, "subfolder")) + with pa.ipc.new_file( + os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema + ) as writer: + for i in range(500): + out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) + writer.write(pa.record_batch([out], schema=schema)) # Build dataloader data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards) From 0acdf05c1633fa7fc63362bb415f7a5a583ed7b6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 13:43:53 -0500 Subject: [PATCH 30/79] Build data only rank 0 --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 5960d64ce..88def6f91 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -54,7 +54,7 @@ # Build dataset datapath = os.path.join(args.ckpt_path, "dataset") -if not os.path.exists(datapath): +if rank==0 and not os.path.exists(datapath): os.makedirs(datapath) schema = pa.schema([pa.field("tokens", pa.uint32())]) with pa.ipc.new_file( From d1460171c080996e977fa4f59b815f3059f92570 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 13:50:32 -0500 Subject: [PATCH 31/79] Pad chunks to make batchable --- examples/ibm_rescaling/rescaling_demo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 88def6f91..65013f12d 100644 --- 
a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -73,6 +73,8 @@ # Build dataloader data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards) +# Pad entries to make them batch-able +data = PreprocessDataset(data, lambda x: x + [-1]*(30-len)) # Statelessly convert all outputs to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader From 0fd38e8015897fce42345ed5f6447cf42bdb715e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 13:58:46 -0500 Subject: [PATCH 32/79] give time for data to construct --- examples/ibm_rescaling/rescaling_demo.py | 37 ++++++++++++++---------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 65013f12d..0b9fbf4da 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -2,6 +2,7 @@ import math import os import pyarrow as pa +import time import torch from torch import distributed as dist @@ -54,22 +55,26 @@ # Build dataset datapath = os.path.join(args.ckpt_path, "dataset") -if rank==0 and not os.path.exists(datapath): - os.makedirs(datapath) - schema = pa.schema([pa.field("tokens", pa.uint32())]) - with pa.ipc.new_file( - os.path.join(datapath, "fileshard_1.arrow"), schema - ) as writer: - for i in range(500): - out = list(range(i * 100, i * 100 + 100)) - writer.write(pa.record_batch([out], schema=schema)) - os.makedirs(os.path.join(datapath, "subfolder")) - with pa.ipc.new_file( - os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema - ) as writer: - for i in range(500): - out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) - writer.write(pa.record_batch([out], schema=schema)) +if not os.path.exists(datapath): + if rank == 0: + os.makedirs(datapath) + schema = pa.schema([pa.field("tokens", pa.uint32())]) + with pa.ipc.new_file( + os.path.join(datapath, "fileshard_1.arrow"), schema + ) as writer: + for i in range(500): + out = list(range(i * 100, i * 100 + 100)) + writer.write(pa.record_batch([out], schema=schema)) + os.makedirs(os.path.join(datapath, "subfolder")) + with pa.ipc.new_file( + os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema + ) as writer: + for i in range(500): + out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) + writer.write(pa.record_batch([out], schema=schema)) + else: + # Give other ranks time for worker 0 to finish + time.sleep(5) # Build dataloader data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards) From e000b819939c1e0d0dff16b208382ec4f8402da7 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:02:51 -0500 Subject: [PATCH 33/79] Fix pad fn --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 0b9fbf4da..755d8cb11 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -79,7 +79,7 @@ # Build dataloader data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards) # Pad entries to make them batch-able -data = PreprocessDataset(data, lambda x: x + [-1]*(30-len)) +data = PreprocessDataset(data, lambda x: x + 
[-1]*(30-len(x))) # Statelessly convert all outputs to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader From 5bbd0d1084e3b4222a32d91aa057e1535fe178d3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:06:06 -0500 Subject: [PATCH 34/79] reader yield list not tensor --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 31c8cbd45..7364bc480 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -463,7 +463,7 @@ def __iter__(self): # Update position self.shard_states[i][3] = chunk_pos+1 # Yield chunk - yield torch.tensor(self._construct_chunk(chunk_pos, doc, nchunks)) # TODO: REMOVE TENSOR CALL!!! + yield self._construct_chunk(chunk_pos, doc, nchunks) # Reset chunk_pos after finishing doc self.shard_states[i][3] = 0 # Reset doc_pos after finishing file From 888bc19e92edb03a4b49f28169ba32f0f4c5d694 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:13:33 -0500 Subject: [PATCH 35/79] No arg for repl placement --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 7364bc480..935e6f548 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -604,7 +604,7 @@ def load_distributed_state_dict( """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate(0)], True) + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate()], True) inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path) From 9c1699dcf3a7f74f43155786a8da502c1b2e2585 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:19:46 -0500 Subject: [PATCH 36/79] typo fix --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 935e6f548..dc7814fc1 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -614,7 +614,7 @@ def load_distributed_state_dict( ) dstate = inp["dstate"] # Check that number of workers matches - ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.lisdtdir(path) if "loader" in x]) + ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "loader" in x]) if ckp_ws == loader.dataset.worldsize and nworkers == state["_snapshot"]["_main_snapshot"]["_num_workers"]: state = inp["state"] else: From c551a0757b82b74d46cdf6b43f220879982921ef Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:26:28 -0500 Subject: [PATCH 37/79] De-dtensorfy in load --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index dc7814fc1..4b196a7af 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -613,6 +613,8 @@ def 
load_distributed_state_dict( reader, ) dstate = inp["dstate"] + # De-DTensor-fy the shard states + dstate["ScalableReader.shard_states"] = dstate["ScalableReader.shard_states"].to_local() # Check that number of workers matches ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "loader" in x]) if ckp_ws == loader.dataset.worldsize and nworkers == state["_snapshot"]["_main_snapshot"]["_num_workers"]: From 4675681c5410c11920d0377b4f47978b3b029569 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:30:11 -0500 Subject: [PATCH 38/79] Full tensor (apparently replicated doesn't force on load) --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 4b196a7af..3b1258999 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -614,7 +614,7 @@ def load_distributed_state_dict( ) dstate = inp["dstate"] # De-DTensor-fy the shard states - dstate["ScalableReader.shard_states"] = dstate["ScalableReader.shard_states"].to_local() + dstate["ScalableReader.shard_states"] = dstate["ScalableReader.shard_states"].full_tensor() # Check that number of workers matches ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "loader" in x]) if ckp_ws == loader.dataset.worldsize and nworkers == state["_snapshot"]["_main_snapshot"]["_num_workers"]: From 65744ac6f44eed4a572a0e33d86f03902d3d6a7b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:33:18 -0500 Subject: [PATCH 39/79] Shard load, full tensor sendaround --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 3b1258999..f68204ff5 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -604,7 +604,7 @@ def load_distributed_state_dict( """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate()], True) + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)], True) inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path) From 88ab3c74263274be91741483db981785c9ef169a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:37:39 -0500 Subject: [PATCH 40/79] Chunksize 40 --- examples/ibm_rescaling/rescaling_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 755d8cb11..0e8ab516e 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -77,9 +77,9 @@ time.sleep(5) # Build dataloader -data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=30, n_logical_shards=args.logical_shards) +data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=40, n_logical_shards=args.logical_shards) # Pad entries to make them batch-able -data = PreprocessDataset(data, lambda x: x + [-1]*(30-len(x))) +data = PreprocessDataset(data, lambda x: x + [-1]*(40-len(x))) # Statelessly convert all outputs 
to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader From a34a5fc23f8084827d213135451b6808cc59f79f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:52:15 -0500 Subject: [PATCH 41/79] Intermediate diag mkdir --- examples/ibm_rescaling/rescaling_demo.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 0e8ab516e..4ca7207f9 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -137,6 +137,10 @@ # Get all vals onto each rank vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor() + # Diag save + os.makedirs(os.path.join(args.ckpt_path, "diag")) + torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth")) + # Perform data coverage check on rank 0 only if rank == 0: # Invert avoid to get expected vals From 763f60eff8079514b90393bb73b0d5ef3e689ebb Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:58:53 -0500 Subject: [PATCH 42/79] Time for other ranks to save --- examples/ibm_rescaling/rescaling_demo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 4ca7207f9..b65bc79a3 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -140,6 +140,7 @@ # Diag save os.makedirs(os.path.join(args.ckpt_path, "diag")) torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth")) + time.sleep(10) # Perform data coverage check on rank 0 only if rank == 0: From 476c5a6fbdcc975a26ac1876898d1c66a2619143 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 15:01:17 -0500 Subject: [PATCH 43/79] exist ok diag subf --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index b65bc79a3..38b72d0f5 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -138,7 +138,7 @@ vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor() # Diag save - os.makedirs(os.path.join(args.ckpt_path, "diag")) + os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True) torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth")) time.sleep(10) From ba00c20869856873262eea102c511f8503824c70 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 15:27:00 -0500 Subject: [PATCH 44/79] Corrected step counting --- examples/ibm_rescaling/rescaling_demo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 38b72d0f5..24b512920 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -95,12 +95,12 @@ avoid = [] for i, inp in enumerate(data): - if i == args.n_steps: + avoid.append(inp[:,0]) + if i == args.n_steps-1: if rank == 0: print("Iteration complete!") save_distributed_state_dict(data, ckpt_path, mesh) break - avoid.append(inp[:,0]) avoid = torch.cat(avoid) # Get all vals onto each rank avoid = dist.tensor.DTensor.from_local( @@ -140,6 +140,8 @@ # Diag save os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True) torch.save(data.state_dict(), 
os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth")) + if rank == 0: + torch.save(vals, os.path.join(args.ckpt_path, "diag", "vals.pth")) time.sleep(10) # Perform data coverage check on rank 0 only From 0fd2b154f6a1ff6a228368412cbe715e58f15bad Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 13:56:59 -0500 Subject: [PATCH 45/79] Fix followup nstep scaling --- examples/ibm_rescaling/rescaling_demo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 24b512920..8593c60fd 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -85,6 +85,10 @@ # Wrap in StatefulDataLoader data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) +# TODO: debug: can't change n_workers when reloading - keyerror +# TODO: debug: going from 4/2/2 gpu/bsize/workers to 6/1/2 causes epoch not to finish + + # If checkpoint does not exist, create it ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state") if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0: @@ -126,7 +130,7 @@ # Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times) vals = [] n_steps = ( - math.ceil((3000 - len(avoid)) / (world_size * args.num_workers)) + math.ceil((3000 - len(avoid)) / (world_size * args.b_size)) + 2 * math.ceil(1000/args.logical_shards) ) for i, inp in enumerate(data): From fcfee8945902bf1655349227f037d68aee3ecea9 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 14:55:16 -0500 Subject: [PATCH 46/79] diag print --- torchdata/stateful_dataloader/ibm_rescalable.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index f68204ff5..882ff6132 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -524,6 +524,10 @@ def load_state_dict(self, state_dict): ] # Reverse sort incomplete shards by length incomplete_shards.sort(key=len, reverse=True) + + if self.rank == 0: + print([len(x) for x in completed_shards], [len(x) for x in incomplete_shards]) + # Pull out shard allocation for this worker # (sort/reverse-sort ensures allocations are off by no more than 1) shard_states = torch.cat([ From 57164ca28372400405198374cb29d7c4014cda1f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 15:18:36 -0500 Subject: [PATCH 47/79] diag print2 --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 882ff6132..b681c0d1e 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -525,8 +525,7 @@ def load_state_dict(self, state_dict): # Reverse sort incomplete shards by length incomplete_shards.sort(key=len, reverse=True) - if self.rank == 0: - print([len(x) for x in completed_shards], [len(x) for x in incomplete_shards]) + print("shardlen", [len(x) for x in completed_shards], [len(x) for x in incomplete_shards]) # Pull out shard allocation for this worker # (sort/reverse-sort ensures allocations are off by no more than 1) From 068ab322156f62c07d4683aee4d891b17cd0c44f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 15:32:16 -0500 
Subject: [PATCH 48/79] diag print3 --- examples/ibm_rescaling/rescaling_demo.py | 1 + torchdata/stateful_dataloader/ibm_rescalable.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 8593c60fd..e614c14f4 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -125,6 +125,7 @@ if rank == 0: print("Checkpoint detected!") load_distributed_state_dict(data, ckpt_path, mesh) + time.sleep(10) avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() # Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index b681c0d1e..594f9f9c6 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -484,6 +484,7 @@ def state_dict(self): }) def load_state_dict(self, state_dict): + print("GOTHERE 2") self.setup() # Load back shard states (global), filesizes (all) shard_states = state_dict[self.statename("shard_states")] @@ -631,4 +632,5 @@ def load_distributed_state_dict( for i in range(nworkers): state["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"] = dstate[i] # Load into loader + print("GOTHERE") loader.load_state_dict(state) From dd7d569834e20a48ea4c2e2b3172a866491a4ff4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 15:34:45 -0500 Subject: [PATCH 49/79] diag print4 --- torchdata/stateful_dataloader/ibm_rescalable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 594f9f9c6..e2bb99168 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -493,6 +493,7 @@ def load_state_dict(self, state_dict): self.filesizes = file_info self.shard_states = shard_states[self.rank] else: + print("GOTHERE 3") shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5 shard_states = torch.cat(shard_states, dim=0) # wn 5 # Sort shards by epoch count From 7fa868fc2b9aa9d3a53005fb3c3925f48d67c877 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 15:37:52 -0500 Subject: [PATCH 50/79] diag print5 --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index e2bb99168..95134457a 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -489,6 +489,7 @@ def load_state_dict(self, state_dict): # Load back shard states (global), filesizes (all) shard_states = state_dict[self.statename("shard_states")] file_info = state_dict[self.statename("file_info")] + print(shard_states.size(0), self.worldsize) if shard_states.size(0) == self.worldsize: self.filesizes = file_info self.shard_states = shard_states[self.rank] @@ -541,6 +542,7 @@ def load_state_dict(self, state_dict): # Pad out with dummy shards if needed self.shard_states[len(shard_states):,0] = -1 self.shard_states[len(shard_states):,4] = torch.iinfo(torch.int).max + print("GOTHERE 4") return None From 473e9ff9dee086086ab158045f23bb3f424c0cd1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 15:43:46 -0500 Subject: [PATCH 51/79] diag print6 --- 
torchdata/stateful_dataloader/ibm_rescalable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 95134457a..595d6fc20 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -490,6 +490,8 @@ def load_state_dict(self, state_dict): shard_states = state_dict[self.statename("shard_states")] file_info = state_dict[self.statename("file_info")] print(shard_states.size(0), self.worldsize) + if self.rank == 0: + print(shard_states) if shard_states.size(0) == self.worldsize: self.filesizes = file_info self.shard_states = shard_states[self.rank] From bf22ce9282ba15d3821501a68aec884ad72b4c1a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 15:44:43 -0500 Subject: [PATCH 52/79] diag print7 --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 595d6fc20..745d89268 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -491,7 +491,7 @@ def load_state_dict(self, state_dict): file_info = state_dict[self.statename("file_info")] print(shard_states.size(0), self.worldsize) if self.rank == 0: - print(shard_states) + print(shard_states.shape, shard_states) if shard_states.size(0) == self.worldsize: self.filesizes = file_info self.shard_states = shard_states[self.rank] From 8307e15b80967ceaaebb858183b9f6c7473e6bc1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 17:30:01 -0500 Subject: [PATCH 53/79] Diag save --- torchdata/stateful_dataloader/ibm_rescalable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 745d89268..642769459 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -492,6 +492,7 @@ def load_state_dict(self, state_dict): print(shard_states.size(0), self.worldsize) if self.rank == 0: print(shard_states.shape, shard_states) + torch.save(shard_states, "/gpfs/davis/test.pth") if shard_states.size(0) == self.worldsize: self.filesizes = file_info self.shard_states = shard_states[self.rank] From 444547fe33a62ecaa9003b8feeda6f44e762d50e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 17:32:42 -0500 Subject: [PATCH 54/79] Diag save2 --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index e614c14f4..b722b8b39 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -125,7 +125,7 @@ if rank == 0: print("Checkpoint detected!") load_distributed_state_dict(data, ckpt_path, mesh) - time.sleep(10) + time.sleep(10000) avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() # Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times) From c94b4aec298bc5715007de4d75dd3160c1a93657 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 18:19:04 -0500 Subject: [PATCH 55/79] Flattenang --- torchdata/stateful_dataloader/ibm_rescalable.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 642769459..aa255e75e 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -479,7 +479,7 @@ def state_dict(self): # Values to save: shard states (shard/repl), filesizes (single/repl) # Deepcopy required to prevent in-place modification from later prefetches return deepcopy({ - self.statename("shard_states"): self.shard_states.unsqueeze(0), + self.statename("shard_states"): self.shard_states, self.statename("file_info"): self.filesizes if self.rank == 0 else None }) @@ -498,8 +498,8 @@ def load_state_dict(self, state_dict): self.shard_states = shard_states[self.rank] else: print("GOTHERE 3") - shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5 - shard_states = torch.cat(shard_states, dim=0) # wn 5 + # shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5 + # shard_states = torch.cat(shard_states, dim=0) # wn 5 # Sort shards by epoch count sorted, indices = torch.sort(shard_states[:,4], descending=True, stable=True) shard_states = shard_states[indices] From 53a89b51d857836e65f961ae9ff9517a78c0bb81 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 18:29:13 -0500 Subject: [PATCH 56/79] Flattenang 2 --- examples/ibm_rescaling/rescaling_demo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index b722b8b39..51e043cb3 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -125,6 +125,7 @@ if rank == 0: print("Checkpoint detected!") load_distributed_state_dict(data, ckpt_path, mesh) + print("FINAL") time.sleep(10000) avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() From ad72ca036abe5b1c3d2462b131a62ff9b0a82cc3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 18:34:33 -0500 Subject: [PATCH 57/79] Flattenang 3 --- torchdata/stateful_dataloader/ibm_rescalable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index aa255e75e..4d8d35116 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -633,6 +633,7 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Repeat global tensor over all workers + print(inp["dstate"]) dstate = [inp["dstate"],]*nworkers # Re-insert worker states into loader state for i in range(nworkers): From c26767581d8fcccb14d2e50b8228699ddfdb4806 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 18:39:11 -0500 Subject: [PATCH 58/79] Diag print (sigh) --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 4d8d35116..6ba43383f 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -633,7 +633,7 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Repeat global tensor over all workers - print(inp["dstate"]) + print(inp["dstate"]["ScalableReader.shard_states"][:,0].tolist()) dstate = [inp["dstate"],]*nworkers # Re-insert worker states into loader state for 
i in range(nworkers): From 03b4b3ae9fe8b35668abff1c27824a6166d824a2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Feb 2025 18:42:53 -0500 Subject: [PATCH 59/79] Diag print (sigh)2 --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 6ba43383f..ce4a2b79c 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -633,7 +633,7 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Repeat global tensor over all workers - print(inp["dstate"]["ScalableReader.shard_states"][:,0].tolist()) + print(inp["dstate"]["ScalableReader.shard_states"][:,0]) dstate = [inp["dstate"],]*nworkers # Re-insert worker states into loader state for i in range(nworkers): From da5991b2033e45b287d96790cafb00af0823f53e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 15:23:12 -0500 Subject: [PATCH 60/79] Attempt key-free load impl --- .../stateful_dataloader/ibm_rescalable.py | 86 ++++++++++--------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index ce4a2b79c..efd1ad887 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -97,9 +97,12 @@ def setup(self): self.worldsize = self.worldsize * self.local_worldsize self.rank = self.local_worldsize * self.rank + info.id - def statename(self, x: str): + def statename(self, x: str, rank=None): # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline - return self.__class__.__name__ + "." + x + out = self.__class__.__name__ + "." + x + if rank is not None: + out = "rank" + str(rank) + "." 
+ out + return out def state_dict(self): """ @@ -478,29 +481,28 @@ def state_dict(self): self.setup() # Values to save: shard states (shard/repl), filesizes (single/repl) # Deepcopy required to prevent in-place modification from later prefetches - return deepcopy({ - self.statename("shard_states"): self.shard_states, - self.statename("file_info"): self.filesizes if self.rank == 0 else None - }) + out = {self.statename("shard_states", rank=self.rank): self.shard_states} + if self.rank==0: + out[self.statename("file_info")] = self.filesizes + return deepcopy(out) def load_state_dict(self, state_dict): - print("GOTHERE 2") self.setup() # Load back shard states (global), filesizes (all) shard_states = state_dict[self.statename("shard_states")] file_info = state_dict[self.statename("file_info")] - print(shard_states.size(0), self.worldsize) - if self.rank == 0: - print(shard_states.shape, shard_states) - torch.save(shard_states, "/gpfs/davis/test.pth") - if shard_states.size(0) == self.worldsize: + # print(shard_states.size(0), self.worldsize) + # if self.rank == 0: + # print(shard_states.shape, shard_states) + # torch.save(shard_states, "/gpfs/davis/test.pth") + if len(shard_states) == self.worldsize: self.filesizes = file_info self.shard_states = shard_states[self.rank] else: - print("GOTHERE 3") # shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5 # shard_states = torch.cat(shard_states, dim=0) # wn 5 # Sort shards by epoch count + shard_states = torch.cat(shard_states, dim=0) sorted, indices = torch.sort(shard_states[:,4], descending=True, stable=True) shard_states = shard_states[indices] # Strip out dummy shards @@ -545,7 +547,6 @@ def load_state_dict(self, state_dict): # Pad out with dummy shards if needed self.shard_states[len(shard_states):,0] = -1 self.shard_states[len(shard_states):,4] = torch.iinfo(torch.int).max - print("GOTHERE 4") return None @@ -559,20 +560,22 @@ def __pop_dstate(state, device_mesh, placements, create_dtensor=False): """ dstate = state["_snapshot"]["_worker_snapshots"] dstate = [dstate[f"worker_{i}"].pop("dataset_state") for i in range(len(dstate))] - # Flip list[dict[tensor]] to dict[list[tensor]], and concat - shardstate = "ScalableReader.shard_states" - fileinfo = "ScalableReader.file_info" - dstate_dict = { - shardstate: torch.cat([d[shardstate] for d in dstate], 0) - } - if create_dtensor == True: - dstate_dict[shardstate] = dtensor.DTensor.from_local( - dstate_dict[shardstate], - device_mesh, - placements, - ) - dstate_dict[fileinfo] = dstate[0][fileinfo] - return dstate_dict + # Fuse dstate dicts + return {k:v for d in dstate for k,v in d.items()} + # # Flip list[dict[tensor]] to dict[list[tensor]], and concat + # shardstate = "ScalableReader.shard_states" + # fileinfo = "ScalableReader.file_info" + # dstate_dict = { + # shardstate: torch.cat([d[shardstate] for d in dstate], 0) + # } + # if create_dtensor == True: + # dstate_dict[shardstate] = dtensor.DTensor.from_local( + # dstate_dict[shardstate], + # device_mesh, + # placements, + # ) + # dstate_dict[fileinfo] = dstate[0][fileinfo] + # return dstate_dict def save_distributed_state_dict( @@ -588,9 +591,9 @@ def save_distributed_state_dict( """ state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)], True) - # Prune empty fileinfos - if dstate["ScalableReader.file_info"] is None: - dstate.pop("ScalableReader.file_info") + # # Prune empty fileinfos + # if dstate["ScalableReader.file_info"] is None: + # 
dstate.pop("ScalableReader.file_info") out = {"state":state, "dstate":dstate} # Write distributed state dict writer = checkpoint.FileSystemWriter(path) @@ -618,13 +621,19 @@ def load_distributed_state_dict( inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path) - checkpoint.load_state_dict( - inp, - reader, - ) + inp = checkpoint._load_state_dict_from_keys(storage_reader = reader) # NOTE: assumes inp["state"] is same across all devices + # checkpoint.load_state_dict( + # inp, + # reader, + # ) dstate = inp["dstate"] - # De-DTensor-fy the shard states - dstate["ScalableReader.shard_states"] = dstate["ScalableReader.shard_states"].full_tensor() + # Re-pack the set of rankX args + ranked_state = {k:dstate.pop(k) for k in dstate if "rank" in k} + ranked_keylist = sorted(list(ranked_state.keys())) + compiled_ranked = [ranked_state[k] for k in ranked_keylist] + dstate[ranked_keylist[0][6:]] = compiled_ranked # Drop "rank0." prefix + # # De-DTensor-fy the shard states + # dstate["ScalableReader.shard_states"] = dstate["ScalableReader.shard_states"].full_tensor() # Check that number of workers matches ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "loader" in x]) if ckp_ws == loader.dataset.worldsize and nworkers == state["_snapshot"]["_main_snapshot"]["_num_workers"]: @@ -633,11 +642,10 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Repeat global tensor over all workers - print(inp["dstate"]["ScalableReader.shard_states"][:,0]) + # print(inp["dstate"]["ScalableReader.shard_states"][:,0]) dstate = [inp["dstate"],]*nworkers # Re-insert worker states into loader state for i in range(nworkers): state["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"] = dstate[i] # Load into loader - print("GOTHERE") loader.load_state_dict(state) From 903780088f63e2992bc36e73bd8aeab7a70d64d2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 15:34:31 -0500 Subject: [PATCH 61/79] Allow full run --- examples/ibm_rescaling/rescaling_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 51e043cb3..7ddfda512 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -125,8 +125,8 @@ if rank == 0: print("Checkpoint detected!") load_distributed_state_dict(data, ckpt_path, mesh) - print("FINAL") - time.sleep(10000) + # print("FINAL") + # time.sleep(10000) avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() # Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times) From 5f10ac181b65c62b12799c332f9fa58eec5749c9 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 15:43:47 -0500 Subject: [PATCH 62/79] Direct import --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index efd1ad887..a1979dab9 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -7,6 +7,7 @@ import torch from torch.distributed import checkpoint +from torch.distributed.checkpoint import _load_state_dict_from_keys import torch.distributed.tensor as dtensor import torch.distributed as dist 
import torch.utils.data as data @@ -621,7 +622,7 @@ def load_distributed_state_dict( inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path) - inp = checkpoint._load_state_dict_from_keys(storage_reader = reader) # NOTE: assumes inp["state"] is same across all devices + inp = _load_state_dict_from_keys(storage_reader = reader) # NOTE: assumes inp["state"] is same across all devices # checkpoint.load_state_dict( # inp, # reader, From 89316201074ee59eb329a3f469af7d4d65b24bd2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 16:42:52 -0500 Subject: [PATCH 63/79] Precise import --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index a1979dab9..f726f10cd 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -7,7 +7,7 @@ import torch from torch.distributed import checkpoint -from torch.distributed.checkpoint import _load_state_dict_from_keys +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys import torch.distributed.tensor as dtensor import torch.distributed as dist import torch.utils.data as data From 3a6e2555cf2fc72add03b46609e2d1125bc391b5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 16:53:24 -0500 Subject: [PATCH 64/79] gloo backend --- examples/ibm_rescaling/rescaling_demo.py | 2 +- torchdata/stateful_dataloader/ibm_rescalable.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 7ddfda512..04a52e76f 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -44,7 +44,7 @@ # Setup rank = int(os.getenv("RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) -dist.init_process_group() +dist.init_process_group(backend="gloo") mesh = dist.device_mesh.init_device_mesh("cpu", [world_size]) placement = [dist.tensor.placement_types.Shard(0)] diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index f726f10cd..ebda76d1e 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -622,7 +622,9 @@ def load_distributed_state_dict( inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path) - inp = _load_state_dict_from_keys(storage_reader = reader) # NOTE: assumes inp["state"] is same across all devices + inp = _load_state_dict_from_keys( + storage_reader = reader, + ) # NOTE: assumes inp["state"] is same across all devices # checkpoint.load_state_dict( # inp, # reader, From ba969583d8898dc8a5e3874a2c3515e99550e52f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 16:56:49 -0500 Subject: [PATCH 65/79] Diag print --- torchdata/stateful_dataloader/ibm_rescalable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index ebda76d1e..da272e58a 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -629,6 +629,7 @@ def load_distributed_state_dict( # inp, # reader, # ) + print(inp) dstate = inp["dstate"] # Re-pack the set of rankX 
args ranked_state = {k:dstate.pop(k) for k in dstate if "rank" in k} From 3ffb475f2897eb5ec460fa757f9b0870e50eeb61 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 17:00:16 -0500 Subject: [PATCH 66/79] Specify keys --- torchdata/stateful_dataloader/ibm_rescalable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index da272e58a..78a4538cf 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -623,6 +623,7 @@ def load_distributed_state_dict( # Read distributed state dict reader = checkpoint.FileSystemReader(path) inp = _load_state_dict_from_keys( + keys=set("state", "dstate"), storage_reader = reader, ) # NOTE: assumes inp["state"] is same across all devices # checkpoint.load_state_dict( From 95cf4947f9212b9e668fe48191ab500037a81e72 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 17:04:06 -0500 Subject: [PATCH 67/79] Set constructor --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 78a4538cf..1318f9b91 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -623,7 +623,7 @@ def load_distributed_state_dict( # Read distributed state dict reader = checkpoint.FileSystemReader(path) inp = _load_state_dict_from_keys( - keys=set("state", "dstate"), + keys=set(["state", "dstate"]), storage_reader = reader, ) # NOTE: assumes inp["state"] is same across all devices # checkpoint.load_state_dict( From 4a592b7379d1ef6511fc8e9908710114b7a0b16b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 17:07:46 -0500 Subject: [PATCH 68/79] Avoid popping keys mid iter --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 1318f9b91..06e726719 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -633,7 +633,8 @@ def load_distributed_state_dict( print(inp) dstate = inp["dstate"] # Re-pack the set of rankX args - ranked_state = {k:dstate.pop(k) for k in dstate if "rank" in k} + keys = list(dstate.keys()) + ranked_state = {k:dstate.pop(k) for k in keys if "rank" in k} ranked_keylist = sorted(list(ranked_state.keys())) compiled_ranked = [ranked_state[k] for k in ranked_keylist] dstate[ranked_keylist[0][6:]] = compiled_ranked # Drop "rank0." 
prefix From c37b8badb2bc1aa514d362c2aed9de94c9923054 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 17:31:04 -0500 Subject: [PATCH 69/79] Diag print --- torchdata/stateful_dataloader/ibm_rescalable.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 06e726719..6e7081727 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -538,10 +538,13 @@ def load_state_dict(self, state_dict): # Pull out shard allocation for this worker # (sort/reverse-sort ensures allocations are off by no more than 1) - shard_states = torch.cat([ + shards = [ completed_shards[self.rank], incomplete_shards[self.rank] - ]) + ] + if self.rank == 4: + print(shards) + shard_states = torch.cat(shards) # Order shards by global ID (for steady file progression) _, indices = shard_states[:,0].sort() self.shard_states[:len(shard_states)] = shard_states[indices] @@ -630,7 +633,6 @@ def load_distributed_state_dict( # inp, # reader, # ) - print(inp) dstate = inp["dstate"] # Re-pack the set of rankX args keys = list(dstate.keys()) From 0b09fd406cfc70ad69f7a37be4e65a8a49cf6562 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 17:41:36 -0500 Subject: [PATCH 70/79] diag print off --- examples/ibm_rescaling/rescaling_demo.py | 10 +++++----- torchdata/stateful_dataloader/ibm_rescalable.py | 4 ---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 04a52e76f..df6316215 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -129,11 +129,11 @@ # time.sleep(10000) avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() - # Finish out epoch (extra 2*ceil(ndocs/nshards) steps to account for worst-case uneven finishing times) + # Finish out epoch (extra 2*ceil(n_items/n_shards) steps to account for worst-case uneven finishing times) vals = [] n_steps = ( math.ceil((3000 - len(avoid)) / (world_size * args.b_size)) - + 2 * math.ceil(1000/args.logical_shards) + + 2 * math.ceil(3000/args.logical_shards) ) for i, inp in enumerate(data): if i == n_steps: @@ -146,9 +146,9 @@ # Diag save os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True) torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth")) - if rank == 0: - torch.save(vals, os.path.join(args.ckpt_path, "diag", "vals.pth")) - time.sleep(10) + # if rank == 0: + # torch.save(vals, os.path.join(args.ckpt_path, "diag", "vals.pth")) + # time.sleep(10) # Perform data coverage check on rank 0 only if rank == 0: diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 6e7081727..3e919c336 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -534,16 +534,12 @@ def load_state_dict(self, state_dict): # Reverse sort incomplete shards by length incomplete_shards.sort(key=len, reverse=True) - print("shardlen", [len(x) for x in completed_shards], [len(x) for x in incomplete_shards]) - # Pull out shard allocation for this worker # (sort/reverse-sort ensures allocations are off by no more than 1) shards = [ completed_shards[self.rank], incomplete_shards[self.rank] ] - if self.rank == 4: - print(shards) shard_states = torch.cat(shards) # Order 
shards by global ID (for steady file progression) _, indices = shard_states[:,0].sort() From 71b78dcefc78e76abe460019a43a2d9d6b223643 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Feb 2025 17:10:29 -0500 Subject: [PATCH 71/79] Clean up and comment out --- examples/ibm_rescaling/rescaling_demo.py | 21 +-- .../stateful_dataloader/ibm_rescalable.py | 152 ++++++++++-------- 2 files changed, 91 insertions(+), 82 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index df6316215..6440660e1 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -15,7 +15,7 @@ save_distributed_state_dict, ) -# This example script validates the rescaling behavior of the ibm rescalable distributed datasets. +# This example script validates the rescaling behavior of the ScalableReader. # On first run, creates a dummy dataset and saves a distributed checkpoint at the desired location. # On subsequent runs, loads the checkpoint (possibly on a different world size / num workers) # and verifies that all remaining data is covered by the time the epoch finishes. @@ -23,7 +23,7 @@ # Example usage: # torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6 -# Do not change the batch size or number of steps between the first and second runs! +# Do not change the number of steps between the first and second runs! parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints") parser.add_argument("--ckpt_path", type=str, default="./rescale_test") @@ -85,10 +85,6 @@ # Wrap in StatefulDataLoader data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) -# TODO: debug: can't change n_workers when reloading - keyerror -# TODO: debug: going from 4/2/2 gpu/bsize/workers to 6/1/2 causes epoch not to finish - - # If checkpoint does not exist, create it ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state") if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0: @@ -125,30 +121,25 @@ if rank == 0: print("Checkpoint detected!") load_distributed_state_dict(data, ckpt_path, mesh) - # print("FINAL") - # time.sleep(10000) avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist() - # Finish out epoch (extra 2*ceil(n_items/n_shards) steps to account for worst-case uneven finishing times) + # Finish out epoch (extra 2*ceil(ceil(n_items/n_shards)/bsize) steps to account for worst-case uneven finishing times) vals = [] n_steps = ( math.ceil((3000 - len(avoid)) / (world_size * args.b_size)) - + 2 * math.ceil(3000/args.logical_shards) + + 2 * math.ceil(math.ceil(3000/args.logical_shards)/args.b_size) ) for i, inp in enumerate(data): + vals.append(inp) if i == n_steps: break - vals.append(inp) vals = torch.cat(vals) # Get all vals onto each rank vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor() - # Diag save + # Save final state dicts for diagnostic purposes os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True) torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth")) - # if rank == 0: - # torch.save(vals, os.path.join(args.ckpt_path, "diag", "vals.pth")) - # time.sleep(10) # Perform data coverage check on rank 0 only if rank == 0: diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 3e919c336..71627565e 100644 --- 
a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -15,37 +15,40 @@ from .stateful_dataloader import StatefulDataLoader """ -TODO: UPDATE THIS FOR SCALABLEREADER - -The following distributed dataloaders are designed around 3 main principles: - -1. Efficient, asynchronous operation. Workers on different devices do not communicate. -2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator - loading from disk and additional layers adding levels of post-processing (shuffling, - packing, padding, rescaling, etc.). -3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal - state that can be written/read on disk via implemented recursive `state_dict()` and - `load_state_dict()` calls. Any values that should be saved to state can be designated - 'state_params' and will be automatically included in the state dict. States must be - valid targets of torch.tensor(). -4. Rescalability. Users can save and load checkpoints to/from different numbers of workers - without losing the global state. This is accomplished by splitting the global state over - a predefined large number of small partitions, each of which tracks its own individual - state. Rescaling is accomplished by re-distributing these shards over the physical workers. - -Our loaders obey the following type hierarchy: -torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset. -`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a -single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times, -then applying some sort of post-processing and yielding the result. Users build data processing -pipelines by wrapping a base `_StatefulDataset` in any number of `_WrapperDataset` layers, -which is then passed to the torch DataLoader. - -It is likely that this can be merged into the existing Nodes structure, but we leave this for -future work, for now. +This file borrows the StatefulDataset framework from the IBM fms-fsdp repo to implement rescalable data +loading. This framework is analogous to the existing torchdata nodes framework and will be converted +in the future. + +Rescalability is implemented at the base level - you must use this layer to interface with a collection +of indexable files directly. The ScalableReader then yields data values like an iterator. These values +are not shuffled. + +ScalableReader interfaces with indexable files via custom FileHandlers. These FileHandlers implement basic +file operations such as file type checking, opening, indexing, and slicing. By implementing these basic +operations, users can add support for arbitrary file types. + +Rescalability is implemented by splitting data into a large number of logical shards, which are then +allocated over the set of dataloader workers. We assume that logical shards vastly outnumber workers, +such that when workers do not divide logical shards evenly, the off-by-one allocations don't matter and +workers still finish their epochs at roughly the same time. Files are assigned to logical shards +fractionally and based on file size, such that each shard contains roughly equal amounts of data, and as +few individual files as possible. This minimizes the number of file pulls. + +ScalableReaders step through a single active logical shard at a time, to minimize overhead. This behavior +can be relaxed later. 
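As a concrete illustration of the fractional assignment just described, here is a
simplified paraphrase of the _get_shard_breakdown logic (a sketch under assumed
inputs, not the implementation itself):

    import torch

    def shard_breakdown(filesizes, shard, nshards):
        # Logical shard `shard` owns the data fraction [shard/nshards, (shard+1)/nshards),
        # measured in normalized file size.
        sizes = torch.tensor(filesizes, dtype=torch.float)
        frac = (sizes / sizes.sum()).tolist()
        start, end = shard / nshards, (shard + 1) / nshards
        owned, cum = [], 0.0
        for i, f in enumerate(frac):
            if cum > end:
                break  # no later file can overlap this shard
            if cum + f >= start:
                # fileid, then start/end expressed as fractions of this file
                owned.append([i, min(max((start - cum) / f, 0.0), 1.0),
                                 min(max((end - cum) / f, 0.0), 1.0)])
            cum += f
        return owned

    # Two files of 600 and 400 docs split over 4 logical shards: shard 2 owns the
    # last sixth of file 0 plus the first 3/8 of file 1 (250 of 1000 docs total).
    print(shard_breakdown([600, 400], 2, 4))  # [[0, 0.833..., 1.0], [1, 0.0, 0.375]]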
+ +When rescaling to a different number of workers, the logical shard progress counters are aggregated +globally onto each ScalableReader. Then, completed and incomplete logical shards are re-allocated +separately, to ensure that each worker receives roughly the same ratio of seen to unseen data in the +current epoch. This allows us to scale from any number of workers to any other number. + +State dicts must be saved using DCP in current code, but this can also be relaxed in future for cases when +rescaling is not required. Rescaling will always require DCP. """ +#### ------------------------- BORROWED FROM IBM FMS-FSDP ------------------------- #### + class _StatefulDataset(data.IterableDataset): """ Stub for stateful datasets, extends data.IterableDataset with state_dict methods. @@ -177,7 +180,7 @@ def state_dict(self): return out -#### ------------------------- FILE READERS ------------------------- #### +#### ------------------------- FILE HANDLERS ------------------------- #### class _ShardFileHandler: @@ -296,10 +299,18 @@ def __iter__(self): yield self.aug_fn(out) +#### ------------------------- NEW CODE STARTS HERE ------------------------- #### + + class ScalableReader(_StatefulDataset): """ - Maintains shared logical shards but opens them one at a time. Completely repartitions - unseen shards only when rescaling. + Maintains n x 5 state buffer where n is the number of logical shards owned by this worker, + and 5 is the number of relevant data fields per-shard. Finishes shards with the lowest + visit count before continuing into new epoch. When rescaling, re-allocates visited / unvisited + shards in the current epoch separately, so that each new worker finishes the epoch at around + the same time. + + Currently does not shuffle docs within shards/files, but this can be added later. 
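For illustration (hypothetical values, a sketch rather than actual output), one
worker's buffer with two live shards and one pad row:

    shard_states = tensor([[ 7, 2, 14, 1,          0],   # shard 7: file 2, doc 14, chunk 1, epoch 0
                           [19, 0,  3, 0,          1],   # shard 19: already one epoch ahead
                           [-1, 0,  0, 0, 2147483647]])  # pad row: id -1, visit count int-max

Columns are shard id, file position, doc position, chunk position, and epoch/visit
count; pad rows fill out workers that own one fewer logical shard.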
""" def __init__( @@ -318,17 +329,17 @@ def __init__( verbose: bool = False, ): super().__init__(datapath, rank, worldsize) - self.seed = seed + self.seed = seed # Currently unused self.datapath = datapath self.filehandler = filehandler() - self.min_length = min_length + self.min_length = min_length # Ignore any docs shorter than this assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" - self.chunksize = max_chunksize - self.eos = delimiter_token - self.bos = bos_token - self.drop = strip_tokens + self.chunksize = max_chunksize # Yield chunks at a time if doc is longer than this + self.eos = delimiter_token # Inserted between each doc + self.bos = bos_token # Inserted before each doc (optional) + self.drop = strip_tokens # Tokens to drop from begin/end of doc (replaced by above delimiter/bos) self.n_logical_shards = n_logical_shards - self.verbose = verbose + self.verbose = verbose # Currently unused # Position self.reader = None @@ -336,21 +347,24 @@ def __init__( # Setup flags self.is_setup = False - self.filesizes = None # [[filenames], [filesizes]] CONSTRUCTED PRE ITER IF NOT LOADED - self.shard_states = None # shardid, file pos, doc pos, chunk pos, epoch RESHARD + self.filesizes = None # [[filenames], [filesizes]] (constructed pre-iter if not loaded from ckp) + self.shard_states = None # shardid, file pos, doc pos, chunk pos, epoch (reshardable state buffer) # TODO: add handling to prevent zero-length allocations def _get_shard_breakdown(self, rank, nshards): - # Find highest fileid still smaller than start + """ + Retrieve the set of (fractional) files assigned to a given logical shard + """ + # Find first doc included in the current shard sizelist = torch.tensor(self.filesizes[1]) - sizelist = sizelist/sizelist.float().mean() + sizelist = sizelist/sizelist.float().sum() cum_sizelist = sizelist.cumsum(0) - start_frac = rank/nshards*len(sizelist) + start_frac = rank/nshards start_id = len(sizelist) - cum_sizelist.gt(start_frac).sum().item() # For each doc, assign relevant fractional ownership start = start_frac - end = (rank+1)/nshards*len(sizelist) + end = (rank+1)/nshards my_files = [] # fileid, start%, end% for i, (size, cumsize_incl) in enumerate( zip(sizelist[start_id:].tolist(), cum_sizelist[start_id:].tolist()) @@ -358,6 +372,7 @@ def _get_shard_breakdown(self, rank, nshards): id = start_id + i cumsize = cumsize_incl - size if cumsize > end: + # No more files to include, stop early break elif cumsize <= end and cumsize_incl >= start: my_files.append([ @@ -368,6 +383,10 @@ def _get_shard_breakdown(self, rank, nshards): return my_files def setup(self): + """ + Perform any rank-dependent setup. This operation is deferred from __init__ to support + multiple workers in the dataloader. + """ if not self.is_setup: # Get your adjusted rank and worldsize super().setup() @@ -381,12 +400,16 @@ def setup(self): # Set up logical shard states (may be overwritten later by ckp load) self.shard_states = torch.zeros(math.ceil(self.n_logical_shards / self.worldsize), 5, dtype=torch.int) self.shard_states[:len(my_shards), 0] = torch.tensor(my_shards) + + # Pad shard state if this worker is off by one. Id is -1 and visit count is inf. self.shard_states[len(my_shards):, 0] = -1 self.shard_states[len(my_shards):, 4] = torch.iinfo(torch.int).max def _pre_iter(self): - # Run after loading checkpoint, before iterating - + """ + Construct index of data files and their filesizes. + This is saved/loaded in subsequent checkpoints to avoid re-indexing the entire dataset repeatedly. 
+ """ # Assemble set of available shard files, if nonexistant if self.filesizes is None: # Find all legal files @@ -480,7 +503,7 @@ def __iter__(self): def state_dict(self): self.setup() - # Values to save: shard states (shard/repl), filesizes (single/repl) + # Values to save: shard states, filesizes # Deepcopy required to prevent in-place modification from later prefetches out = {self.statename("shard_states", rank=self.rank): self.shard_states} if self.rank==0: @@ -489,24 +512,18 @@ def state_dict(self): def load_state_dict(self, state_dict): self.setup() - # Load back shard states (global), filesizes (all) - shard_states = state_dict[self.statename("shard_states")] + # Load back shard states and file sizes + shard_states = state_dict[self.statename("shard_states")] # list[tensor] file_info = state_dict[self.statename("file_info")] - # print(shard_states.size(0), self.worldsize) - # if self.rank == 0: - # print(shard_states.shape, shard_states) - # torch.save(shard_states, "/gpfs/davis/test.pth") if len(shard_states) == self.worldsize: self.filesizes = file_info self.shard_states = shard_states[self.rank] else: - # shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5 - # shard_states = torch.cat(shard_states, dim=0) # wn 5 # Sort shards by epoch count shard_states = torch.cat(shard_states, dim=0) sorted, indices = torch.sort(shard_states[:,4], descending=True, stable=True) shard_states = shard_states[indices] - # Strip out dummy shards + # Strip out dummy padding shards n_dummies = sorted.eq(torch.iinfo(torch.int).max).sum() shard_states = shard_states[n_dummies:] # n_logical 5 assert len(shard_states) == self.n_logical_shards, f"Number of shards {len(shard_states)} does not match specified {self.n_logical_shards}" @@ -532,6 +549,7 @@ def load_state_dict(self, state_dict): ] for i in range(self.worldsize) ] # Reverse sort incomplete shards by length + # Minimizes padding by overallocating incomplete shards to underallocated complete shards incomplete_shards.sort(key=len, reverse=True) # Pull out shard allocation for this worker @@ -555,8 +573,10 @@ def load_state_dict(self, state_dict): def __pop_dstate(state, device_mesh, placements, create_dtensor=False): """ - Removes worker states from the StatefulDataLoader state dict, and assembles them - into a separate list of dicts of dtensors for distributed checkpointing. + Removes worker states from the StatefulDataLoader state dict, and fuses them into a single dict + (assuming no key overlap, which we currently guarantee by adding a rank to each worker's shardstate) + Includes old dtensor logic but currently not used (as no state buffers are getting resharded + straightforwardly). This will likely change in the future. """ dstate = state["_snapshot"]["_worker_snapshots"] dstate = [dstate[f"worker_{i}"].pop("dataset_state") for i in range(len(dstate))] @@ -586,8 +606,7 @@ def save_distributed_state_dict( """ Retrieves dataloader state dict, and separates worker states from loader state. Loader state is not rescalable, and is discarded when rescaling. - Rescalable worker states are compiled into a dtensor across ranks, and saved - using pytorch distributed checkpointing. + Saves dict using DCP. """ state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)], True) @@ -609,11 +628,14 @@ def load_distributed_state_dict( device_mesh: dist.DeviceMesh, ): """ - Retrieves dataloader state dict, and separates worker states from loader state. 
+ Retrieves dataloader state dict using DCP, and separates worker states from loader state. If not rescaling, load saved dataloader state. - Rescalable worker states are retrieved using pytorch distributed checkpointing. States are replicated over workers, and ScalableReader will handle partitioning and re-assignment of available states into logical ranks. + + Loading back to the same number of workers results in key overlap for 'state', so I suspect + that any rank-dependent dataloader state is being lost or overwritten in this case. + TODO: verify/fix """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] @@ -625,12 +647,9 @@ def load_distributed_state_dict( keys=set(["state", "dstate"]), storage_reader = reader, ) # NOTE: assumes inp["state"] is same across all devices - # checkpoint.load_state_dict( - # inp, - # reader, - # ) dstate = inp["dstate"] # Re-pack the set of rankX args + # NOTE: this is the step currently breaking the no-DCP path keys = list(dstate.keys()) ranked_state = {k:dstate.pop(k) for k in keys if "rank" in k} ranked_keylist = sorted(list(ranked_state.keys())) @@ -646,7 +665,6 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Repeat global tensor over all workers - # print(inp["dstate"]["ScalableReader.shard_states"][:,0]) dstate = [inp["dstate"],]*nworkers # Re-insert worker states into loader state for i in range(nworkers): From baf9c13537f9c65dcb350b585e2b1c7243f31510 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 5 Mar 2025 13:42:22 -0500 Subject: [PATCH 72/79] Refactor -ibm --- examples/{ibm_rescaling => data_rescaling}/rescaling_demo.py | 2 +- .../{ibm_rescalable.py => scalable_reader.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename examples/{ibm_rescaling => data_rescaling}/rescaling_demo.py (99%) rename torchdata/stateful_dataloader/{ibm_rescalable.py => scalable_reader.py} (100%) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/data_rescaling/rescaling_demo.py similarity index 99% rename from examples/ibm_rescaling/rescaling_demo.py rename to examples/data_rescaling/rescaling_demo.py index 6440660e1..69a6f434b 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/data_rescaling/rescaling_demo.py @@ -7,7 +7,7 @@ from torch import distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader -from torchdata.stateful_dataloader.ibm_rescalable import ( +from torchdata.stateful_dataloader.scalable_reader import ( ArrowHandler, PreprocessDataset, ScalableReader, diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/scalable_reader.py similarity index 100% rename from torchdata/stateful_dataloader/ibm_rescalable.py rename to torchdata/stateful_dataloader/scalable_reader.py From 88d993f3475da2d59ac78ca2506be6494f14afc1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 5 Mar 2025 13:52:09 -0500 Subject: [PATCH 73/79] abc shard handler --- .../stateful_dataloader/scalable_reader.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchdata/stateful_dataloader/scalable_reader.py b/torchdata/stateful_dataloader/scalable_reader.py index 71627565e..5658bd990 100644 --- a/torchdata/stateful_dataloader/scalable_reader.py +++ b/torchdata/stateful_dataloader/scalable_reader.py @@ -2,6 +2,7 @@ import math import os import pyarrow as pa +from abc import ABCMeta, abstractmethod from copy import deepcopy from 
typing import Any, Callable, List, Optional, Set @@ -183,7 +184,7 @@ def state_dict(self): #### ------------------------- FILE HANDLERS ------------------------- #### -class _ShardFileHandler: +class ShardFileHandler(object, metaclass=ABCMeta): """ Stub for shard file readers of different formats. Must implement open, length, indexing, and slicing functions. @@ -196,20 +197,23 @@ def is_legal(self, filepath: str): """ return os.path.isfile(filepath) + @abstractmethod def open(self, path: str): """ Open the file, to be indexed via self.get() method. Avoid reading entire multi-Gb files when possible! """ - raise NotImplementedError + pass + @abstractmethod def length(self, path: str): """ Calculate the number of documents in the given file. Avoid reading entire multi-Gb files when possible! """ - raise NotImplementedError + pass + @abstractmethod def get(self, reader, index: int, drop_tokens: Set): """ Given the output of self.open() and an index, return the document at that index. @@ -218,8 +222,9 @@ def get(self, reader, index: int, drop_tokens: Set): but this is less important than avoiding reading entire files as above. Output must support len() method. """ - raise NotImplementedError + pass + @abstractmethod def slice(self, doc, index: int, n_pull: int) -> List: """ Given a long document, retrieve n_pull consecutive items starting from index. @@ -227,10 +232,10 @@ def slice(self, doc, index: int, n_pull: int) -> List: and self.open() is far more important. Must return a python list. """ - raise NotImplementedError + pass -class ArrowHandler(_ShardFileHandler): +class ArrowHandler(ShardFileHandler): """ Reader for indexable, pre-tokenized PyArrow shard files. Pyarrow shard files are expected to hold multiple RecordBatches, @@ -318,7 +323,7 @@ def __init__( datapath: str, rank: int, worldsize: int, - filehandler: _ShardFileHandler, + filehandler: ShardFileHandler, delimiter_token: Any, bos_token: Optional[Any] = None, strip_tokens: Optional[Set[Any]] = set(), From 21db51620651a25baa1fa35560432b52d4899a38 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 5 Mar 2025 13:53:06 -0500 Subject: [PATCH 74/79] Refactor wrapperdataset --- torchdata/stateful_dataloader/scalable_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchdata/stateful_dataloader/scalable_reader.py b/torchdata/stateful_dataloader/scalable_reader.py index 5658bd990..07f0a10c3 100644 --- a/torchdata/stateful_dataloader/scalable_reader.py +++ b/torchdata/stateful_dataloader/scalable_reader.py @@ -125,7 +125,7 @@ def load_state_dict(self, state_dict): [setattr(self, flag, state_dict[self.statename(flag)]) for flag in self.state_params] -class _WrapperDataset(_StatefulDataset): +class _NestedStatefulDataset(_StatefulDataset): """ Stub for nested wrappers of _StatefulDatasets. Extends state fns with recursion. Requires a single instantiated sub-dataset (which may be replicated during setup fn). @@ -276,7 +276,7 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: #### ------------------------- DATASET LAYERS ------------------------- #### -class PreprocessDataset(_WrapperDataset): +class PreprocessDataset(_NestedStatefulDataset): """ Wrapper for a _StatefulDataset that applies a specified preprocessing or augmentation function to dataset outputs. 
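The ABC in PATCH 73 above makes the extension contract explicit. As a sketch of what a third-party handler might look like, here is a minimal JSON-lines handler (hypothetical format: one JSON list of token ids per line; this is not part of the series, and a production handler should index lazily rather than read whole files, per the docstring guidance):

import json
import os
from typing import List, Set

from torchdata.stateful_dataloader.scalable_reader import ShardFileHandler

class JsonlHandler(ShardFileHandler):
    def is_legal(self, filepath: str):
        return os.path.isfile(filepath) and filepath.endswith(".jsonl")

    def open(self, path: str):
        # Simplification: reads the whole file into memory; acceptable for small shards only
        with open(path, "rb") as f:
            return f.readlines()

    def length(self, path: str):
        with open(path, "rb") as f:
            return sum(1 for _ in f)

    def get(self, reader, index: int, drop_tokens: Set):
        doc = json.loads(reader[index])
        # Strip designated tokens from the ends of the doc only
        while doc and doc[0] in drop_tokens:
            doc = doc[1:]
        while doc and doc[-1] in drop_tokens:
            doc = doc[:-1]
        return doc

    def slice(self, doc: List, index: int, n_pull: int) -> List:
        return doc[index : index + n_pull]

Since ScalableReader instantiates the handler itself, the class rather than an instance is passed in, e.g. ScalableReader(path, rank, worldsize, JsonlHandler, delimiter_token=-1, ...).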
From 99fb2af8a9b7eec06e2bb6ad708e6d21076a3f64 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 14 Mar 2025 13:07:12 -0400 Subject: [PATCH 75/79] First draft unit tests --- test/scalable_reader/test_scalable_reader.py | 142 +++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 test/scalable_reader/test_scalable_reader.py diff --git a/test/scalable_reader/test_scalable_reader.py b/test/scalable_reader/test_scalable_reader.py new file mode 100644 index 000000000..b391e9d0b --- /dev/null +++ b/test/scalable_reader/test_scalable_reader.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import pyarrow as pa +import unittest + +import torch + +from torch.testing._internal.common_utils import TestCase +from .._utils._common_utils_for_test import create_temp_dir + +from torchdata.stateful_dataloader import StatefulDataLoader +from torchdata.stateful_dataloader.scalable_reader import ScalableReader, PreprocessDataset, ArrowHandler + +# A set of draft unit tests for the ScalableReader. +# Note that these have not been locally tested or debugged yet (fighting my local environment), +# and likely fail in horrible ways. Mostly here for discussion/reference at this stage. + +# TODO: test actual save/load distributed functions via multiprocessing + +class TestScalableReader(TestCase): + def setUp(self): + super().setUp() + datapath = create_temp_dir() + schema = pa.schema([pa.field("tokens", pa.uint32())]) + with pa.ipc.new_file( + os.path.join(datapath, "fileshard_1.arrow"), schema + ) as writer: + for i in range(500): + out = list(range(i * 100, i * 100 + 100)) + writer.write(pa.record_batch([out], schema=schema)) + os.makedirs(os.path.join(datapath, "subfolder")) + with pa.ipc.new_file( + os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema + ) as writer: + for i in range(500): + out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) + writer.write(pa.record_batch([out], schema=schema)) + self.datapath = datapath + + def create_scalable( + self, + rank = 0, + worldsize = 1, + delimiter = -1, + bos = None, + seed = 42, + chunk = 1000, + logicals = 10 + ): + # Build dataloader + data = ScalableReader( + self.datapath, + rank, + worldsize, + ArrowHandler, + delimiter, + bos, + seed=seed, + max_chunksize=chunk, + n_logical_shards=logicals, + ) + # Pad entries to make them batch-able + data = PreprocessDataset(data, lambda x: x + [-1]*(chunk-len(x))) + # Statelessly convert all outputs to tensors + data = PreprocessDataset(data, torch.tensor) + return data + + def test_single_epoch(self): + for ws in [2,3,7]: + for nw in [0,2,3]: + loaderset = [iter(StatefulDataLoader(self.create_scalable(i, ws, logicals=555), batch_size=1, num_workers=nw)) for i in range(ws)] + n_steps = math.ceil(1000/ws)+10 + pools = [set() for _ in loaderset] + for _ in range(n_steps): + for i,l in enumerate(loaderset): + pools[i].add(next(l)[0,0].item()) + for i in range(len(pools)): + for j in range(i+1, len(pools)): + print(f"Checking outputs {i} and {j}") + overlap = len(pools[i].intersection(pools[j])) + self.assertEqual(overlap, 0, f"Overlapping data found in workers {i} and {j} (worldsize {ws}/{ws*max(nw,1)}): {overlap}") + alldata = set.union(*pools) + expected = set([x*100 for x in range(1000)]) + missing = len(expected.difference(alldata)) + self.assertEqual(missing, 0, f"Missing 
data from pool: {missing}") + + def test_resumption(self): + for ws in [2,3,7]: + for nw in [0,2,3]: + loaderset = [StatefulDataLoader(self.create_scalable(i, ws, logicals=555), batch_size=1, num_workers=nw) for i in range(ws)] + loaderset2 = [StatefulDataLoader(self.create_scalable(i, ws, logicals=555), batch_size=1, num_workers=nw) for i in range(ws)] + n_steps = 2*math.ceil(1000/ws) # Proceed well into second epoch + iterset = [iter(l) for l in loaderset] + for _ in range(100): + [next(l) for l in iterset] + for i in range(ws): + loaderset2[i].load_state_dict(loaderset[i].state_dict()) + iterset2 = [iter(l) for l in loaderset2] + for s in range(n_steps): + for i in range(ws): + expected = next(iterset[i]) + query = next(iterset2[i]) + self.assertEqual(expected, query, f"Mismatch at step 100+{s} rank {i}, (worldsize {ws}/{ws*max(nw,1)}): original {expected[0,:5]}..., recieved {query[0,:5]}") + + def test_rescale_epoch(self): + nsteps = 30 + for start_ws in [1,2,6]: + for end_ws in [3,4]: + for logicals in [300, 555, 721]: + # Create checkpoint + avoid = [] + data = StatefulDataLoader(self.create_scalable(logicals=logicals, chunk=40), num_workers=start_ws, batch_size=1) + for i, inp in enumerate(data): + avoid.append(inp[0,0].item()) + if i==(nsteps-1)*start_ws: + sd = data.state_dict() + break + # Load checkpoint + # (this step likely fails without using the custom distributed save/load checkpointing fns) + data = StatefulDataLoader(self.create_scalable(logicals=logicals, chunk=40), num_workers=end_ws, batch_size=1) + data.load_state_dict(sd) + vals = [] + nsteps = math.ceil(3000 - len(avoid)) + (2*math.ceil(3000/logicals)*end_ws) + for i, inp in enumerate(data): + vals.append(inp[0,0].item()) + if i == nsteps: + break + # Invert set of seen values + expect = [] + for i in range(1000): + for offset in [0,40,80]: + if i*100+offset not in avoid: + expect.append(i*100+offset) + for x in expect: + self.assertObjectIn(x, vals) + From a879ce1705e3873d7f5cd8f41dcc568f6f402f9b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Mar 2025 16:41:04 -0400 Subject: [PATCH 76/79] No direct import --- test/scalable_reader/test_scalable_reader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/scalable_reader/test_scalable_reader.py b/test/scalable_reader/test_scalable_reader.py index b391e9d0b..309eb359c 100644 --- a/test/scalable_reader/test_scalable_reader.py +++ b/test/scalable_reader/test_scalable_reader.py @@ -7,12 +7,12 @@ import math import os import pyarrow as pa +import tempfile import unittest import torch from torch.testing._internal.common_utils import TestCase -from .._utils._common_utils_for_test import create_temp_dir from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.scalable_reader import ScalableReader, PreprocessDataset, ArrowHandler @@ -23,6 +23,12 @@ # TODO: test actual save/load distributed functions via multiprocessing +def create_temp_dir(dir=None): + # The temp dir and files within it will be released and deleted in tearDown(). + # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function. 
+ temp_dir = tempfile.TemporaryDirectory(dir=dir) # noqa: P201 + return temp_dir + class TestScalableReader(TestCase): def setUp(self): super().setUp() From 31745ad5977858dee4b1632d00012a0d6a616324 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Mar 2025 17:08:13 -0400 Subject: [PATCH 77/79] name --- test/scalable_reader/test_scalable_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/scalable_reader/test_scalable_reader.py b/test/scalable_reader/test_scalable_reader.py index 309eb359c..4012e1565 100644 --- a/test/scalable_reader/test_scalable_reader.py +++ b/test/scalable_reader/test_scalable_reader.py @@ -47,7 +47,7 @@ def setUp(self): for i in range(500): out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) writer.write(pa.record_batch([out], schema=schema)) - self.datapath = datapath + self.datapath = datapath.name def create_scalable( self, From c16a5e0925f7b2d436e915af80b83e25e7ecd3d7 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Mar 2025 17:09:46 -0400 Subject: [PATCH 78/79] Separate name and data --- test/scalable_reader/test_scalable_reader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/scalable_reader/test_scalable_reader.py b/test/scalable_reader/test_scalable_reader.py index 4012e1565..9da380e20 100644 --- a/test/scalable_reader/test_scalable_reader.py +++ b/test/scalable_reader/test_scalable_reader.py @@ -32,7 +32,8 @@ def create_temp_dir(dir=None): class TestScalableReader(TestCase): def setUp(self): super().setUp() - datapath = create_temp_dir() + data = create_temp_dir() + datapath = data.name schema = pa.schema([pa.field("tokens", pa.uint32())]) with pa.ipc.new_file( os.path.join(datapath, "fileshard_1.arrow"), schema @@ -48,6 +49,7 @@ def setUp(self): out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) writer.write(pa.record_batch([out], schema=schema)) self.datapath = datapath.name + self.data = data def create_scalable( self, From 1acb3be19f995222fa5b8c4896c78fbdfd9907c4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Mar 2025 17:10:22 -0400 Subject: [PATCH 79/79] separate name and data p2 --- test/scalable_reader/test_scalable_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/scalable_reader/test_scalable_reader.py b/test/scalable_reader/test_scalable_reader.py index 9da380e20..9b442cacf 100644 --- a/test/scalable_reader/test_scalable_reader.py +++ b/test/scalable_reader/test_scalable_reader.py @@ -48,7 +48,7 @@ def setUp(self): for i in range(500): out = list(range(50000 + i * 100, 50000 + i * 100 + 100)) writer.write(pa.record_batch([out], schema=schema)) - self.datapath = datapath.name + self.datapath = datapath self.data = data def create_scalable(
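One possible shape for the multiprocessing test flagged in the TODO above: spawn a small gloo process group, checkpoint at one world size, then reload at another. This is a sketch under stated assumptions (an Arrow dataset already on disk at `datapath`, a free port 29510, and the function names as defined in scalable_reader); it is not part of the series.

import os

import torch.distributed as dist
import torch.multiprocessing as mp

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.scalable_reader import (
    ArrowHandler,
    ScalableReader,
    load_distributed_state_dict,
    save_distributed_state_dict,
)

def _rescale_worker(rank, world_size, datapath, ckpt_path, do_save):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29510")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    mesh = dist.device_mesh.init_device_mesh("cpu", [world_size])
    reader = ScalableReader(
        datapath, rank, world_size, ArrowHandler,
        delimiter_token=-1, max_chunksize=40, n_logical_shards=120,
    )
    loader = StatefulDataLoader(reader, batch_size=1, num_workers=2)
    if do_save:
        it = iter(loader)
        for _ in range(10):       # make some progress before checkpointing
            next(it)
        save_distributed_state_dict(loader, ckpt_path, mesh)
    else:
        load_distributed_state_dict(loader, ckpt_path, mesh)
        next(iter(loader))        # resumed loader should yield only unseen data
    dist.destroy_process_group()

# Save at world size 2, then rescale-load at world size 3:
# mp.spawn(_rescale_worker, args=(2, datapath, ckpt, True), nprocs=2)
# mp.spawn(_rescale_worker, args=(3, datapath, ckpt, False), nprocs=3)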