Rescalability via IBM dataset layers #1372

Closed
daviswer wants to merge 71 commits
Commits
45b0bce  Add distributed datasets (daviswer, Nov 22, 2024)
e486614  Formatting, commenting (daviswer, Nov 22, 2024)
10e45b9  Add demo script (daviswer, Nov 23, 2024)
10a6f66  Datapath None (daviswer, Nov 23, 2024)
0281897  Shift dummydata seeding to setup, dummy path handling (daviswer, Nov 23, 2024)
a175c3c  Actually create dummy data folders (daviswer, Nov 23, 2024)
957a5bf  Remove cfg ref (daviswer, Nov 23, 2024)
2e9bdf0  Remove double () call (daviswer, Nov 23, 2024)
e475eec  Fix dist checkpoint import (daviswer, Nov 23, 2024)
eac8ef6  Check ckp subfolder existence, not working folder (daviswer, Nov 23, 2024)
afd0169  Save vals for checking (daviswer, Nov 23, 2024)
031d67c  Load dummy gen state always (daviswer, Nov 23, 2024)
d9a575b  Setup calls in dummy (daviswer, Nov 23, 2024)
157f90b  Diag print (daviswer, Nov 23, 2024)
91f1b14  Remove sampling (daviswer, Nov 23, 2024)
b3569e3  Path in dummy build (daviswer, Nov 23, 2024)
0faea8c  Path in dummy build (daviswer, Nov 23, 2024)
0be44e4  Scalable off (daviswer, Nov 23, 2024)
c54aed2  Build data folder early (daviswer, Nov 23, 2024)
a16ffb1  Avoid resetting gen each state dict call (daviswer, Nov 23, 2024)
b645aea  Diag print off, all datasets on (daviswer, Nov 23, 2024)
ceffd24  Stop saving vals (daviswer, Nov 23, 2024)
d2eb12e  Attempt single blob save (daviswer, Jan 14, 2025)
ada91ec  Attempt single blob load (daviswer, Jan 14, 2025)
9bf8f3d  Prevent loading in place (daviswer, Jan 14, 2025)
934d37b  Cleanup (daviswer, Jan 14, 2025)
8d0cfd8  ScalableReader changes (daviswer, Feb 6, 2025)
e633e60  Fix datapath folder creation (daviswer, Feb 6, 2025)
1f2e37a  Create datapath subfolder, data only when nonexistent (daviswer, Feb 6, 2025)
0acdf05  Build data only rank 0 (daviswer, Feb 6, 2025)
d146017  Pad chunks to make batchable (daviswer, Feb 6, 2025)
0fd38e8  give time for data to construct (daviswer, Feb 6, 2025)
e000b81  Fix pad fn (daviswer, Feb 6, 2025)
5bbd0d1  reader yield list not tensor (daviswer, Feb 6, 2025)
888bc19  No arg for repl placement (daviswer, Feb 6, 2025)
9c1699d  typo fix (daviswer, Feb 6, 2025)
c551a07  De-dtensorfy in load (daviswer, Feb 6, 2025)
4675681  Full tensor (apparently replicated doesn't force on load) (daviswer, Feb 6, 2025)
65744ac  Shard load, full tensor sendaround (daviswer, Feb 6, 2025)
88ab3c7  Chunksize 40 (daviswer, Feb 6, 2025)
a34a5fc  Intermediate diag mkdir (daviswer, Feb 6, 2025)
763f60e  Time for other ranks to save (daviswer, Feb 6, 2025)
476c5a6  exist ok diag subf (daviswer, Feb 6, 2025)
ba00c20  Corrected step counting (daviswer, Feb 6, 2025)
0fd2b15  Fix followup nstep scaling (daviswer, Feb 10, 2025)
fcfee89  diag print (daviswer, Feb 10, 2025)
57164ca  diag print2 (daviswer, Feb 10, 2025)
068ab32  diag print3 (daviswer, Feb 10, 2025)
dd7d569  diag print4 (daviswer, Feb 10, 2025)
7fa868f  diag print5 (daviswer, Feb 10, 2025)
473e9ff  diag print6 (daviswer, Feb 10, 2025)
bf22ce9  diag print7 (daviswer, Feb 10, 2025)
8307e15  Diag save (daviswer, Feb 10, 2025)
444547f  Diag save2 (daviswer, Feb 10, 2025)
c94b4ae  Flattenang (daviswer, Feb 10, 2025)
53a89b5  Flattenang 2 (daviswer, Feb 10, 2025)
ad72ca0  Flattenang 3 (daviswer, Feb 10, 2025)
c267675  Diag print (sigh) (daviswer, Feb 10, 2025)
03b4b3a  Diag print (sigh)2 (daviswer, Feb 10, 2025)
da5991b  Attempt key-free load impl (daviswer, Feb 19, 2025)
9037800  Allow full run (daviswer, Feb 19, 2025)
5f10ac1  Direct import (daviswer, Feb 19, 2025)
8931620  Precise import (daviswer, Feb 19, 2025)
3a6e255  gloo backend (daviswer, Feb 19, 2025)
ba96958  Diag print (daviswer, Feb 19, 2025)
3ffb475  Specify keys (daviswer, Feb 19, 2025)
95cf494  Set constructor (daviswer, Feb 19, 2025)
4a592b7  Avoid popping keys mid iter (daviswer, Feb 19, 2025)
c37b8ba  Diag print (daviswer, Feb 19, 2025)
0b09fd4  diag print off (daviswer, Feb 19, 2025)
71b78dc  Clean up and comment out (daviswer, Feb 25, 2025)
158 changes: 158 additions & 0 deletions examples/ibm_rescaling/rescaling_demo.py
@@ -0,0 +1,158 @@
import argparse
import math
import os
import pyarrow as pa
import time
import torch
from torch import distributed as dist

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.ibm_rescalable import (
    ArrowHandler,
    PreprocessDataset,
    ScalableReader,
    load_distributed_state_dict,
    save_distributed_state_dict,
)

# This example script validates the rescaling behavior of the ScalableReader.
# On first run, creates a dummy dataset and saves a distributed checkpoint at the desired location.
# On subsequent runs, loads the checkpoint (possibly on a different world size / num workers)
# and verifies that all remaining data is covered by the time the epoch finishes.

# Example usage:
# torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6
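# To exercise rescaling, rerun the same command after the first checkpoint is saved,
# changing the number of ranks and/or workers (the nproc value here is illustrative), e.g.:
# torchrun --nproc-per-node=2 examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=3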

# Do not change the number of steps between the first and second runs!

parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints")
parser.add_argument("--ckpt_path", type=str, default="./rescale_test")
parser.add_argument(
    "--logical_shards",
    type=int,
    default=350,
    help="Total number of data partitions. Must exceed (worldsize * n_workers) but not n_docs (1000).",
)
parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers per device")
parser.add_argument("--b_size", type=int, default=2, help="Number of data points per step per device")
parser.add_argument("--n_steps", type=int, default=50, help="Number of steps to take before saving. (n_steps * b_size * worldsize) cannot exceed number of items in epoch (3000)")
parser.add_argument("--seed", type=int, default=42)

args = parser.parse_args()


# Setup
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
dist.init_process_group(backend="gloo")
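# 1-D device mesh over all ranks; the Shard(0) placement is used below to treat each
# rank's local values as one shard of a DTensor, so full_tensor() can gather them everywhere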
mesh = dist.device_mesh.init_device_mesh("cpu", [world_size])
placement = [dist.tensor.placement_types.Shard(0)]

# Check input args
assert args.logical_shards >= world_size*args.num_workers, f"Logical shards {args.logical_shards} cannot be less than total workers {world_size*args.num_workers}"
assert args.logical_shards <= 1000, f"Logical shards {args.logical_shards} cannot exceed number of documents 1000"
assert args.n_steps*args.b_size*world_size < 3000, f"Number of items drawn before saving {args.n_steps*args.b_size*world_size} cannot exceed number of document chunks 3000."
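# (3000 items = 1000 documents x 3 chunks per 100-token document at max_chunksize=40, set below)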

# Build dataset
datapath = os.path.join(args.ckpt_path, "dataset")
if not os.path.exists(datapath):
    if rank == 0:
        os.makedirs(datapath)
        schema = pa.schema([pa.field("tokens", pa.uint32())])
        with pa.ipc.new_file(
            os.path.join(datapath, "fileshard_1.arrow"), schema
        ) as writer:
            for i in range(500):
                out = list(range(i * 100, i * 100 + 100))
                writer.write(pa.record_batch([out], schema=schema))
        os.makedirs(os.path.join(datapath, "subfolder"))
        with pa.ipc.new_file(
            os.path.join(datapath, "subfolder/fileshard_2.arrow"), schema
        ) as writer:
            for i in range(500):
                out = list(range(50000 + i * 100, 50000 + i * 100 + 100))
                writer.write(pa.record_batch([out], schema=schema))
    else:
        # Give other ranks time for rank 0 to finish
        time.sleep(5)

# Build dataloader
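# ScalableReader splits each 100-token document into chunks of up to max_chunksize=40
# tokens (starting at offsets 0, 40, 80) and stripes them over n_logical_shards
# partitions, which are re-apportioned over whatever ranks/workers exist after reload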
data = ScalableReader(datapath, rank, world_size, ArrowHandler, -1, seed=args.seed, max_chunksize=40, n_logical_shards=args.logical_shards)
# Pad entries to make them batch-able
data = PreprocessDataset(data, lambda x: x + [-1]*(40-len(x)))
# Statelessly convert all outputs to tensors
data = PreprocessDataset(data, torch.tensor)
# Wrap in StatefulDataLoader
data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers)
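# StatefulDataLoader exposes state_dict()/load_state_dict(); the save/load helpers
# above take the wrapped loader, and the diagnostic save below calls state_dict() directly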

# If checkpoint does not exist, create it
ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state")
if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0:
    os.makedirs(ckpt_path, exist_ok=True)
    # Iterate, assemble values to exclude
    if rank == 0:
        print(f"No existing checkpoint. Processing {args.n_steps} steps.")

    avoid = []
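    # inp has shape (b_size, 40); chunk-start tokens are unique across the dataset,
    # so inp[:, 0] records an ID for every chunk drawn before the checkpoint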
    for i, inp in enumerate(data):
        avoid.append(inp[:, 0])
        if i == args.n_steps - 1:
            if rank == 0:
                print("Iteration complete!")
            save_distributed_state_dict(data, ckpt_path, mesh)
            break
    avoid = torch.cat(avoid)
    # Get all vals onto each rank
    avoid = dist.tensor.DTensor.from_local(
        avoid,
        mesh,
        placement,
    ).full_tensor()

    if rank == 0:
        torch.save(avoid, os.path.join(args.ckpt_path, "avoid.pth"))
        print(
            "Generation complete! Please rerun (with different world size / workers if desired) to complete the check."
        )

# If checkpoint does exist, load and finish epoch.
# Ensure all expected values are covered once epoch concludes.
else:
    if rank == 0:
        print("Checkpoint detected!")
    load_distributed_state_dict(data, ckpt_path, mesh)
    avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist()

    # Finish out epoch (extra 2*ceil(ceil(n_items/n_shards)/bsize) steps to account for worst-case uneven finishing times)
    vals = []
    n_steps = (
        math.ceil((3000 - len(avoid)) / (world_size * args.b_size))
        + 2 * math.ceil(math.ceil(3000 / args.logical_shards) / args.b_size)
    )
    for i, inp in enumerate(data):
        vals.append(inp)
        if i == n_steps:
            break
    vals = torch.cat(vals)
    # Get all vals onto each rank
    vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor()

    # Save final state dicts for diagnostic purposes
    os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True)
    torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth"))

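    # Each document contributes chunk-start IDs at token offsets 0, 40, and 80;
    # every ID not recorded in `avoid` should appear in the post-load epoch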
    # Perform data coverage check on rank 0 only
    if rank == 0:
        # Invert avoid to get expected vals
        expect = []
        for i in range(1000):
            for offset in [0, 40, 80]:
                if i * 100 + offset not in avoid:
                    expect.append(i * 100 + offset)

        for x in expect:
            assert x in vals, x
        print("Check passed: upcoming data is covered as expected!")

dist.barrier()
dist.destroy_process_group()