Rescalability layer #1455

Draft · wants to merge 80 commits into main

Commits (80)

45b0bce
Add distributed datasets
daviswer Nov 22, 2024
e486614
Formatting, commenting
daviswer Nov 22, 2024
10e45b9
Add demo script
daviswer Nov 23, 2024
10a6f66
Datapath None
daviswer Nov 23, 2024
0281897
Shift dummydata seeding to setup, dummy path handling
daviswer Nov 23, 2024
a175c3c
Actually create dummy data folders
daviswer Nov 23, 2024
957a5bf
Remove cfg ref
daviswer Nov 23, 2024
2e9bdf0
Remove double () call
daviswer Nov 23, 2024
e475eec
Fix dist checkpoint import
daviswer Nov 23, 2024
eac8ef6
Check ckp subfolder existence, not working folder
daviswer Nov 23, 2024
afd0169
Save vals for checking
daviswer Nov 23, 2024
031d67c
Load dummy gen state always
daviswer Nov 23, 2024
d9a575b
Setup calls in dummy
daviswer Nov 23, 2024
157f90b
Diag print
daviswer Nov 23, 2024
91f1b14
Remove sampling
daviswer Nov 23, 2024
b3569e3
Path in dummy build
daviswer Nov 23, 2024
0faea8c
Path in dummy build
daviswer Nov 23, 2024
0be44e4
Scalable off
daviswer Nov 23, 2024
c54aed2
Build data folder early
daviswer Nov 23, 2024
a16ffb1
Avoid resetting gen each state dict call
daviswer Nov 23, 2024
b645aea
Diag print off, all datasets on
daviswer Nov 23, 2024
ceffd24
Stop saving vals
daviswer Nov 23, 2024
d2eb12e
Attempt single blob save
daviswer Jan 14, 2025
ada91ec
Attempt single blob load
daviswer Jan 14, 2025
9bf8f3d
Prevent loading in place
daviswer Jan 14, 2025
934d37b
Cleanup
daviswer Jan 14, 2025
8d0cfd8
ScalableReader changes
daviswer Feb 6, 2025
e633e60
Fix datapath folder creation
daviswer Feb 6, 2025
1f2e37a
Create datapath subfolder, data only when nonexistent
daviswer Feb 6, 2025
0acdf05
Build data only rank 0
daviswer Feb 6, 2025
d146017
Pad chunks to make batchable
daviswer Feb 6, 2025
0fd38e8
give time for data to construct
daviswer Feb 6, 2025
e000b81
Fix pad fn
daviswer Feb 6, 2025
5bbd0d1
reader yield list not tensor
daviswer Feb 6, 2025
888bc19
No arg for repl placement
daviswer Feb 6, 2025
9c1699d
typo fix
daviswer Feb 6, 2025
c551a07
De-dtensorfy in load
daviswer Feb 6, 2025
4675681
Full tensor (apparently replicated doesn't force on load)
daviswer Feb 6, 2025
65744ac
Shard load, full tensor sendaround
daviswer Feb 6, 2025
88ab3c7
Chunksize 40
daviswer Feb 6, 2025
a34a5fc
Intermediate diag mkdir
daviswer Feb 6, 2025
763f60e
Time for other ranks to save
daviswer Feb 6, 2025
476c5a6
exist ok diag subf
daviswer Feb 6, 2025
ba00c20
Corrected step counting
daviswer Feb 6, 2025
0fd2b15
Fix followup nstep scaling
daviswer Feb 10, 2025
fcfee89
diag print
daviswer Feb 10, 2025
57164ca
diag print2
daviswer Feb 10, 2025
068ab32
diag print3
daviswer Feb 10, 2025
dd7d569
diag print4
daviswer Feb 10, 2025
7fa868f
diag print5
daviswer Feb 10, 2025
473e9ff
diag print6
daviswer Feb 10, 2025
bf22ce9
diag print7
daviswer Feb 10, 2025
8307e15
Diag save
daviswer Feb 10, 2025
444547f
Diag save2
daviswer Feb 10, 2025
c94b4ae
Flattenang
daviswer Feb 10, 2025
53a89b5
Flattenang 2
daviswer Feb 10, 2025
ad72ca0
Flattenang 3
daviswer Feb 10, 2025
c267675
Diag print (sigh)
daviswer Feb 10, 2025
03b4b3a
Diag print (sigh)2
daviswer Feb 10, 2025
da5991b
Attempt key-free load impl
daviswer Feb 19, 2025
9037800
Allow full run
daviswer Feb 19, 2025
5f10ac1
Direct import
daviswer Feb 19, 2025
8931620
Precise import
daviswer Feb 19, 2025
3a6e255
gloo backend
daviswer Feb 19, 2025
ba96958
Diag print
daviswer Feb 19, 2025
3ffb475
Specify keys
daviswer Feb 19, 2025
95cf494
Set constructor
daviswer Feb 19, 2025
4a592b7
Avoid popping keys mid iter
daviswer Feb 19, 2025
c37b8ba
Diag print
daviswer Feb 19, 2025
0b09fd4
diag print off
daviswer Feb 19, 2025
71b78dc
Clean up and comment out
daviswer Feb 25, 2025
e2b35aa
Merge branch 'main' into loader-dcp
daviswer Feb 25, 2025
baf9c13
Refactor -ibm
daviswer Mar 5, 2025
88d993f
abc shard handler
daviswer Mar 5, 2025
21db516
Refactor wrapperdataset
daviswer Mar 5, 2025
99fb2af
First draft unit tests
daviswer Mar 14, 2025
a879ce1
No direct import
daviswer Mar 24, 2025
31745ad
name
daviswer Mar 24, 2025
c16a5e0
Separate name and data
daviswer Mar 24, 2025
1acb3be
separate name and data p2
daviswer Mar 24, 2025
158 changes: 158 additions & 0 deletions examples/data_rescaling/rescaling_demo.py
@@ -0,0 +1,158 @@
import argparse
import math
import os
import pyarrow as pa
import time
import torch
from torch import distributed as dist

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.scalable_reader import (
    ArrowHandler,
    PreprocessDataset,
    ScalableReader,
    load_distributed_state_dict,
    save_distributed_state_dict,
)

# This example script validates the rescaling behavior of the ScalableReader.
# On the first run, it creates a dummy dataset and saves a distributed checkpoint at the desired location.
# On subsequent runs, it loads the checkpoint (possibly with a different world size / number of workers)
# and verifies that all remaining data is covered by the time the epoch finishes.

# Example usage:
# torchrun [torchrun args] examples/data_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6

# Do not change the number of steps between the first and second runs!
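# For example, a first run with `torchrun --nproc-per-node=2 ...` followed by a second run with
# `--nproc-per-node=3` on the same --ckpt_path exercises the rescaling path (illustrative values;
# keep --ckpt_path, --logical_shards, and --n_steps identical across the two runs).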

parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints")
parser.add_argument("--ckpt_path", type=str, default="./rescale_test")
parser.add_argument(
    "--logical_shards",
    type=int,
    default=350,
    help="Total number of data partitions. Must be at least (worldsize * n_workers) and at most n_docs (1000).",
)
parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers per device")
parser.add_argument("--b_size", type=int, default=2, help="Number of data points per step per device")
parser.add_argument("--n_steps", type=int, default=50, help="Number of steps to take before saving. (n_steps * b_size * worldsize) must be less than the number of document chunks in one epoch (3000)")
parser.add_argument("--seed", type=int, default=42)

args = parser.parse_args()


# Setup
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
dist.init_process_group(backend="gloo")
mesh = dist.device_mesh.init_device_mesh("cpu", [world_size])
placement = [dist.tensor.placement_types.Shard(0)]

# Check input args
assert args.logical_shards >= world_size*args.num_workers, f"Logical shards {args.logical_shards} cannot be less than total workers {world_size*args.num_workers}"
assert args.logical_shards <= 1000, f"Logical shards {args.logical_shards} cannot exceed number of documents 1000"
assert args.n_steps*args.b_size*world_size < 3000, f"Number of items drawn before saving {args.n_steps*args.b_size*world_size} cannot exceed number of document chunks 3000."
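# For example, with the defaults (logical_shards=350, num_workers=1, b_size=2, n_steps=50) on a
# 2-rank run: 350 >= 2*1, 350 <= 1000, and 50*2*2 = 200 < 3000, so all checks pass.
# (Illustrative arithmetic only; the limits themselves come from the dummy dataset built below.)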

# Build dataset
datapath = os.path.join(args.ckpt_path, "dataset")
if not os.path.exists(datapath):
    if rank == 0:
        os.makedirs(datapath)
        schema = pa.schema([pa.field("tokens", pa.uint32())])
        with pa.ipc.new_file(
            os.path.join(datapath, "fileshard_1.arrow"), schema
        ) as writer:
            for i in range(500):
                out = list(range(i * 100, i * 100 + 100))
                writer.write(pa.record_batch([out], schema=schema))
        os.makedirs(os.path.join(datapath, "subfolder"))
        with pa.ipc.new_file(
            os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema
        ) as writer:
            for i in range(500):
                out = list(range(50000 + i * 100, 50000 + i * 100 + 100))
                writer.write(pa.record_batch([out], schema=schema))
    else:
        # Give other ranks time for rank 0 to finish writing the dataset
        time.sleep(5)
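
# The dummy corpus above is 2 Arrow fileshards x 500 documents x 100 tokens = 1000 documents.
# With max_chunksize=40 in the ScalableReader below, each document splits into 3 chunks
# (starting at token offsets 0, 40, and 80), giving the 3000 document chunks referenced in the
# argument checks above.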

# Build dataloader
data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=40, n_logical_shards=args.logical_shards)
# Pad entries to make them batch-able
data = PreprocessDataset(data, lambda x: x + [-1]*(40-len(x)))
# Statelessly convert all outputs to tensors
data = PreprocessDataset(data, torch.tensor)
# Wrap in StatefulDataLoader
data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers)
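# At this point each batch drawn from `data` is a (b_size, 40) tensor of token ids, with -1
# padding filling out any chunk shorter than 40 tokens.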

# If checkpoint does not exist, create it
ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state")
if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0:
    os.makedirs(ckpt_path, exist_ok=True)
    # Iterate, assemble values to exclude
    if rank == 0:
        print(f"No existing checkpoint. Processing {args.n_steps} steps.")

    avoid = []
    for i, inp in enumerate(data):
        avoid.append(inp[:,0])
        if i == args.n_steps-1:
            if rank == 0:
                print("Iteration complete!")
            save_distributed_state_dict(data, ckpt_path, mesh)
            break
    avoid = torch.cat(avoid)
    # Get all vals onto each rank
    avoid = dist.tensor.DTensor.from_local(
        avoid,
        mesh,
        placement,
    ).full_tensor()

    if rank == 0:
        torch.save(avoid, os.path.join(args.ckpt_path, "avoid.pth"))
        print(
            "Generation complete! Please rerun (with different world size / workers if desired) to complete the check."
        )

# If checkpoint does exist, load and finish epoch.
# Ensure all expected values are covered once epoch concludes.
else:
    if rank == 0:
        print("Checkpoint detected!")
    load_distributed_state_dict(data, ckpt_path, mesh)
    avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist()

    # Finish out epoch (extra 2*ceil(ceil(n_items/n_shards)/bsize) steps to account for worst-case uneven finishing times)
    vals = []
    n_steps = (
        math.ceil((3000 - len(avoid)) / (world_size * args.b_size))
        + 2 * math.ceil(math.ceil(3000/args.logical_shards)/args.b_size)
    )
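    # For example, with the default --logical_shards=350 and --b_size=2, the slack term is
    # 2 * ceil(ceil(3000 / 350) / 2) = 2 * ceil(9 / 2) = 10 extra steps per rank (illustrative
    # arithmetic; the actual slack depends on the arguments passed).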
    for i, inp in enumerate(data):
        vals.append(inp)
        if i == n_steps:
            break
    vals = torch.cat(vals)
    # Get all vals onto each rank
    vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor()

    # Save final state dicts for diagnostic purposes
    os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True)
    torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth"))

    # Perform data coverage check on rank 0 only
    if rank == 0:
        # Invert avoid to get expected vals
        expect = []
        for i in range(1000):
            for offset in [0,40,80]:
                if i*100+offset not in avoid:
                    expect.append(i*100+offset)

        for x in expect:
            assert x in vals, x
        print("Check passed: upcoming data is covered as expected!")

dist.barrier()
dist.destroy_process_group()
150 changes: 150 additions & 0 deletions test/scalable_reader/test_scalable_reader.py
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
import pyarrow as pa
import tempfile
import unittest

import torch

from torch.testing._internal.common_utils import TestCase

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.scalable_reader import ScalableReader, PreprocessDataset, ArrowHandler

# A set of draft unit tests for the ScalableReader.
# Note that these have not been locally tested or debugged yet (fighting my local environment),
# and likely fail in horrible ways. Mostly here for discussion/reference at this stage.

# TODO: test actual save/load distributed functions via multiprocessing
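# One possible shape for that test, sketched as comments only (untested; the helper names below are
# hypothetical, and it assumes a gloo process group plus a CPU device mesh as in the demo script,
# with MASTER_ADDR/MASTER_PORT env vars omitted for brevity):
#
#   def _worker(rank, world_size, data_dir, ckpt_dir):
#       dist.init_process_group("gloo", rank=rank, world_size=world_size)
#       mesh = dist.device_mesh.init_device_mesh("cpu", [world_size])
#       loader = StatefulDataLoader(...)  # ScalableReader built as in create_scalable() below
#       # ...take a few steps, then:
#       save_distributed_state_dict(loader, ckpt_dir, mesh)
#       dist.destroy_process_group()
#
#   torch.multiprocessing.spawn(_worker, args=(world_size, data_dir, ckpt_dir), nprocs=world_size)
#   # ...then spawn again with a different world_size and call load_distributed_state_dict.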

def create_temp_dir(dir=None):
    # The temp dir and files within it will be released and deleted in tearDown().
    # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
    temp_dir = tempfile.TemporaryDirectory(dir=dir)  # noqa: P201
    return temp_dir

class TestScalableReader(TestCase):
    def setUp(self):
        super().setUp()
        data = create_temp_dir()
        datapath = data.name
        schema = pa.schema([pa.field("tokens", pa.uint32())])
        with pa.ipc.new_file(
            os.path.join(datapath, "fileshard_1.arrow"), schema
        ) as writer:
            for i in range(500):
                out = list(range(i * 100, i * 100 + 100))
                writer.write(pa.record_batch([out], schema=schema))
        os.makedirs(os.path.join(datapath, "subfolder"))
        with pa.ipc.new_file(
            os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema
        ) as writer:
            for i in range(500):
                out = list(range(50000 + i * 100, 50000 + i * 100 + 100))
                writer.write(pa.record_batch([out], schema=schema))
        self.datapath = datapath
        self.data = data

    def create_scalable(
        self,
        rank=0,
        worldsize=1,
        delimiter=-1,
        bos=None,
        seed=42,
        chunk=1000,
        logicals=10,
    ):
        # Build dataloader
        data = ScalableReader(
            self.datapath,
            rank,
            worldsize,
            ArrowHandler,
            delimiter,
            bos,
            seed=seed,
            max_chunksize=chunk,
            n_logical_shards=logicals,
        )
        # Pad entries to make them batch-able
        data = PreprocessDataset(data, lambda x: x + [-1]*(chunk-len(x)))
        # Statelessly convert all outputs to tensors
        data = PreprocessDataset(data, torch.tensor)
        return data

    def test_single_epoch(self):
        for ws in [2,3,7]:
            for nw in [0,2,3]:
                loaderset = [iter(StatefulDataLoader(self.create_scalable(i, ws, logicals=555), batch_size=1, num_workers=nw)) for i in range(ws)]
                n_steps = math.ceil(1000/ws)+10
                pools = [set() for _ in loaderset]
                for _ in range(n_steps):
                    for i,l in enumerate(loaderset):
                        pools[i].add(next(l)[0,0].item())
                for i in range(len(pools)):
                    for j in range(i+1, len(pools)):
                        print(f"Checking outputs {i} and {j}")
                        overlap = len(pools[i].intersection(pools[j]))
                        self.assertEqual(overlap, 0, f"Overlapping data found in workers {i} and {j} (worldsize {ws}/{ws*max(nw,1)}): {overlap}")
                alldata = set.union(*pools)
                expected = set([x*100 for x in range(1000)])
                missing = len(expected.difference(alldata))
                self.assertEqual(missing, 0, f"Missing data from pool: {missing}")

    def test_resumption(self):
        for ws in [2,3,7]:
            for nw in [0,2,3]:
                loaderset = [StatefulDataLoader(self.create_scalable(i, ws, logicals=555), batch_size=1, num_workers=nw) for i in range(ws)]
                loaderset2 = [StatefulDataLoader(self.create_scalable(i, ws, logicals=555), batch_size=1, num_workers=nw) for i in range(ws)]
                n_steps = 2*math.ceil(1000/ws)  # Proceed well into second epoch
                iterset = [iter(l) for l in loaderset]
                for _ in range(100):
                    [next(l) for l in iterset]
                for i in range(ws):
                    loaderset2[i].load_state_dict(loaderset[i].state_dict())
                iterset2 = [iter(l) for l in loaderset2]
                for s in range(n_steps):
                    for i in range(ws):
                        expected = next(iterset[i])
                        query = next(iterset2[i])
                        self.assertEqual(expected, query, f"Mismatch at step 100+{s} rank {i}, (worldsize {ws}/{ws*max(nw,1)}): original {expected[0,:5]}..., received {query[0,:5]}")

    def test_rescale_epoch(self):
        nsteps = 30
        for start_ws in [1,2,6]:
            for end_ws in [3,4]:
                for logicals in [300, 555, 721]:
                    # Create checkpoint
                    avoid = []
                    data = StatefulDataLoader(self.create_scalable(logicals=logicals, chunk=40), num_workers=start_ws, batch_size=1)
                    for i, inp in enumerate(data):
                        avoid.append(inp[0,0].item())
                        if i==(nsteps-1)*start_ws:
                            sd = data.state_dict()
                            break
                    # Load checkpoint
                    # (this step likely fails without using the custom distributed save/load checkpointing fns)
                    data = StatefulDataLoader(self.create_scalable(logicals=logicals, chunk=40), num_workers=end_ws, batch_size=1)
                    data.load_state_dict(sd)
                    vals = []
                    finish_steps = math.ceil(3000 - len(avoid)) + (2*math.ceil(3000/logicals)*end_ws)
                    for i, inp in enumerate(data):
                        vals.append(inp[0,0].item())
                        if i == finish_steps:
                            break
                    # Invert set of seen values
                    expect = []
                    for i in range(1000):
                        for offset in [0,40,80]:
                            if i*100+offset not in avoid:
                                expect.append(i*100+offset)
                    for x in expect:
                        self.assertObjectIn(x, vals)
