Skip to content

Commit cf481d8

Browse files
committed
extract out distributed logic
1 parent dea6cc7 commit cf481d8

File tree

6 files changed

+359
-278
lines changed

6 files changed

+359
-278
lines changed

bergson/__main__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from simple_parsing import ArgumentParser, ConflictResolution
66

7+
from .build import build
78
from .data import IndexConfig, QueryConfig
8-
from .launch import build, query
9+
from .query import query
910

1011

1112
@dataclass

bergson/build.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import json
2+
import os
3+
import shutil
4+
from dataclasses import asdict
5+
from datetime import timedelta
6+
7+
import torch
8+
import torch.distributed as dist
9+
from datasets import Dataset, IterableDataset
10+
from tqdm.auto import tqdm
11+
12+
from bergson.collection import collect_gradients
13+
from bergson.data import IndexConfig, allocate_batches
14+
from bergson.utils import assert_type
15+
from bergson.worker_utils import setup_model_and_peft
16+
17+
from .launch import launch_distributed_run
18+
from .worker_utils import create_processor, setup_data_pipeline
19+
20+
21+
def build_worker(
    rank: int,
    world_size: int,
    cfg: IndexConfig,
    ds: Dataset | IterableDataset,
):
    """Per-GPU worker that collects gradients for the index.

    Pins this process to GPU ``rank`` and, when ``world_size`` > 1, joins a
    NCCL process group (MASTER_ADDR/MASTER_PORT are expected to be set by
    the launching process). Map-style datasets are processed in a single
    pass; iterable (streaming) datasets are buffered into shards of
    ``cfg.stream_shard_size`` examples, each shard materialized as a
    ``Dataset`` before gradient collection. Rank 0 saves the processor
    state to ``cfg.partial_run_path`` at the end.
    """
    torch.cuda.set_device(rank)

    # These should be set by the main process
    if world_size > 1:
        addr = os.environ.get("MASTER_ADDR", "localhost")
        port = os.environ.get("MASTER_PORT", "29500")

        dist.init_process_group(
            "nccl",
            init_method=f"tcp://{addr}:{port}",
            device_id=torch.device(f"cuda:{rank}"),
            rank=rank,
            timeout=timedelta(hours=1),
            world_size=world_size,
        )

    model, target_modules = setup_model_and_peft(cfg, rank)
    processor = create_processor(cfg, rank)

    # Same attention config for every module that needs attention splitting
    attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

    kwargs = {
        "model": model,
        "data": ds,
        "processor": processor,
        "cfg": cfg,
        "target_modules": target_modules,
        "attention_cfgs": attention_cfgs,
    }

    if isinstance(ds, Dataset):
        # Map-style dataset: lengths are known up front, batch once and go.
        batches = allocate_batches(ds["length"], cfg.token_batch_size)
        kwargs["batches"] = batches
        collect_gradients(**kwargs)
    else:
        # Convert each shard to a Dataset then map over its gradients
        buf, shard_id = [], 0

        def flush(kwargs):
            nonlocal buf, shard_id
            if not buf:
                return
            ds_shard = assert_type(Dataset, Dataset.from_list(buf))
            batches = allocate_batches(ds_shard["length"][:], cfg.token_batch_size)
            # BUG FIX: the shard must replace the "data" entry. Previously it
            # was stored under an unused "ds" key, so collect_gradients kept
            # receiving the full stream plus an unexpected `ds` keyword.
            kwargs["data"] = ds_shard
            kwargs["batches"] = batches
            collect_gradients(**kwargs)

            buf.clear()
            shard_id += 1

        for ex in tqdm(ds, desc="Collecting gradients"):
            buf.append(ex)
            if len(buf) == cfg.stream_shard_size:
                flush(kwargs=kwargs)

        flush(kwargs=kwargs)  # Final flush
    if rank == 0:
        processor.save(cfg.partial_run_path)
86+
87+
88+
def build(cfg: IndexConfig):
    """Build a gradient index end to end.

    Persists the index config to the partial run directory, runs the data
    pipeline, fans the work out across workers, and finally promotes the
    partial run directory to its permanent location.
    """
    cfg.partial_run_path.mkdir(parents=True, exist_ok=True)

    config_path = cfg.partial_run_path / "index_config.json"
    with config_path.open("w") as f:
        json.dump(asdict(cfg), f, indent=2)

    dataset = setup_data_pipeline(cfg)

    launch_distributed_run("build", build_worker, [cfg, dataset])

    # Publish the finished index only after all workers have completed
    shutil.move(cfg.partial_run_path, cfg.run_path)

bergson/launch.py

Lines changed: 0 additions & 228 deletions
Original file line numberDiff line numberDiff line change
@@ -1,214 +1,10 @@
1-
import json
2-
import os
3-
import shutil
41
import socket
5-
from dataclasses import asdict
6-
from datetime import timedelta
7-
from pathlib import Path
82
from typing import Any, Callable
93

10-
import pandas as pd
114
import torch
125
import torch.distributed as dist
136
import torch.multiprocessing as mp
14-
from datasets import (
15-
Dataset,
16-
DatasetDict,
17-
IterableDataset,
18-
IterableDatasetDict,
19-
load_dataset,
20-
)
217
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
22-
from tqdm.auto import tqdm
23-
from transformers import (
24-
AutoTokenizer,
25-
)
26-
27-
from bergson.collection import collect_gradients
28-
from bergson.data import (
29-
DataConfig,
30-
IndexConfig,
31-
QueryConfig,
32-
allocate_batches,
33-
tokenize,
34-
)
35-
from bergson.query import get_query_grads
36-
from bergson.score_writer import MemmapScoreWriter
37-
from bergson.scorer import get_scorer
38-
from bergson.utils import assert_type
39-
from bergson.worker_utils import create_processor, setup_model_and_peft
40-
41-
42-
def estimate_advantage(ds: Dataset, cfg: DataConfig):
    """Estimate per-rollout advantages.

    The advantage of a rollout is its reward minus the mean reward of all
    rollouts that share the same prompt. Returns one float per row of ``ds``.
    """
    frame = ds.select_columns([cfg.prompt_column, cfg.reward_column]).to_pandas()
    frame = assert_type(pd.DataFrame, frame)

    rewards = frame[cfg.reward_column]
    prompt_means = frame.groupby(cfg.prompt_column)[cfg.reward_column].transform(
        "mean"
    )

    return (rewards - prompt_means).tolist()
52-
53-
54-
def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset:
    """Handle data loading and preprocessing.

    Loads the dataset named by ``cfg.data.dataset`` (CSV/JSON file, Hub
    dataset, or a save_to_disk directory), tokenizes it, and optionally
    attaches an "advantage" column when a reward column is configured.
    Returns a streaming ``IterableDataset`` when ``cfg.data.streaming`` is
    set, otherwise a map-style ``Dataset``.
    """

    data_str = cfg.data.dataset
    # Local files are dispatched on extension; anything else goes through
    # load_dataset (Hub name or local dataset script/directory).
    if data_str.endswith(".csv"):
        ds = assert_type(Dataset, Dataset.from_csv(data_str))
    elif data_str.endswith(".json") or data_str.endswith(".jsonl"):
        ds = assert_type(Dataset, Dataset.from_json(data_str))
    else:
        try:
            ds = load_dataset(
                data_str, split=cfg.data.split, streaming=cfg.data.streaming
            )

            # Downstream code expects a single split, not a dict of splits.
            if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
                raise NotImplementedError(
                    "DatasetDicts and IterableDatasetDicts are not supported."
                )
        except ValueError as e:
            # Automatically use load_from_disk if appropriate
            # (load_dataset raises a ValueError mentioning load_from_disk
            # when pointed at a save_to_disk directory).
            if "load_from_disk" in str(e):
                ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
            else:
                raise e

    # In many cases the token_batch_size may be smaller than the max length allowed by
    # the model. If cfg.data.truncation is True, we use the tokenizer to truncate
    tokenizer = AutoTokenizer.from_pretrained(cfg.model, revision=cfg.revision)
    tokenizer.model_max_length = min(tokenizer.model_max_length, cfg.token_batch_size)

    # Dropping original columns keeps only the tokenizer's output fields.
    remove_columns = ds.column_names if cfg.drop_columns else None

    ds = ds.map(
        tokenize,
        batched=True,
        fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),
        remove_columns=remove_columns,
    )

    if cfg.data.reward_column:
        # Advantage estimation groups by prompt, which needs the full
        # materialized dataset — streaming datasets are not supported here.
        assert isinstance(ds, Dataset), "Dataset required for advantage estimation"
        ds = ds.add_column(
            "advantage",
            estimate_advantage(ds, cfg.data),
            new_fingerprint="advantage",  # type: ignore
        )

    return ds
102-
103-
104-
def worker(
    rank: int,
    world_size: int,
    cfg: IndexConfig,
    ds: Dataset | IterableDataset,
    query_cfg: QueryConfig | None = None,
):
    """Per-GPU worker that collects gradients and, optionally, scores them.

    Pins this process to GPU ``rank`` and joins a NCCL process group when
    ``world_size`` > 1. When ``query_cfg`` is given, query gradients are
    loaded and a scorer backed by a memmap score writer is attached so
    gradients are scored as they are collected. Map-style datasets are
    processed in one pass; streaming datasets are buffered into shards of
    ``cfg.stream_shard_size`` examples (each shard gets its own score file
    under ``shard-NNNNN``). Rank 0 saves the processor state at the end.
    """
    torch.cuda.set_device(rank)

    # These should be set by the main process
    if world_size > 1:
        addr = os.environ.get("MASTER_ADDR", "localhost")
        port = os.environ.get("MASTER_PORT", "29500")

        dist.init_process_group(
            "nccl",
            init_method=f"tcp://{addr}:{port}",
            device_id=torch.device(f"cuda:{rank}"),
            rank=rank,
            timeout=timedelta(hours=1),
            world_size=world_size,
        )

    model, target_modules = setup_model_and_peft(cfg, rank)
    processor = create_processor(cfg, rank)

    # Same attention config for every module that needs attention splitting
    attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

    if query_cfg is not None:
        query_grads = get_query_grads(
            query_cfg, torch.device(f"cuda:{rank}"), torch.float32
        )
        # One score per query gradient; all modules have the same count.
        num_scores = len(query_grads[query_cfg.modules[0]])

    kwargs = {
        "model": model,
        "data": ds,
        "processor": processor,
        "cfg": cfg,
        "target_modules": target_modules,
        "attention_cfgs": attention_cfgs,
    }

    if isinstance(ds, Dataset):
        batches = allocate_batches(ds["length"], cfg.token_batch_size)
        kwargs["batches"] = batches

        if query_cfg is not None:
            score_writer = MemmapScoreWriter(
                Path(query_cfg.scores_path),
                len(ds),
                num_scores,
                rank=rank,
            )
            scorer = get_scorer(
                query_grads,
                query_cfg,
                score_writer,
                cfg.module_wise,
                torch.device(f"cuda:{rank}"),
                torch.float32,
            )
            kwargs["scorer"] = scorer

        collect_gradients(**kwargs)
    else:
        # Convert each shard to a Dataset then map over its gradients
        buf, shard_id = [], 0

        def flush(kwargs):
            nonlocal buf, shard_id
            if not buf:
                return
            ds_shard = assert_type(Dataset, Dataset.from_list(buf))
            batches = allocate_batches(ds_shard["length"][:], cfg.token_batch_size)
            # BUG FIX: the shard must replace the "data" entry. Previously it
            # was stored under an unused "ds" key, so collect_gradients kept
            # receiving the full stream plus an unexpected `ds` keyword.
            kwargs["data"] = ds_shard
            kwargs["batches"] = batches

            if query_cfg is not None:
                # Each shard writes its scores to its own sub-directory.
                score_writer = MemmapScoreWriter(
                    Path(query_cfg.scores_path) / f"shard-{shard_id:05d}",
                    len(ds_shard),
                    num_scores,
                    rank=rank,
                )
                scorer = get_scorer(
                    query_grads,
                    query_cfg,
                    score_writer,
                    cfg.module_wise,
                    torch.device(f"cuda:{rank}"),
                    torch.float32,
                )
                kwargs["scorer"] = scorer

            collect_gradients(**kwargs)

            buf.clear()
            shard_id += 1

        for ex in tqdm(ds, desc="Collecting gradients"):
            buf.append(ex)
            if len(buf) == cfg.stream_shard_size:
                flush(kwargs=kwargs)

        flush(kwargs=kwargs)  # Final flush
        if rank == 0:
            processor.save(cfg.partial_run_path)
2128

2139

21410
def dist_worker(
@@ -266,27 +62,3 @@ def launch_distributed_run(process_name: str, worker, const_worker_args: list[An
26662
finally:
26763
if ctx is not None:
26864
ctx.close() # Kill any processes that are still running
269-
270-
271-
def build(cfg: IndexConfig):
    """Build a gradient index: persist the config, preprocess the data,
    collect gradients across all workers, then promote the partial run
    directory to its final location."""
    cfg.partial_run_path.mkdir(parents=True, exist_ok=True)
    with (cfg.partial_run_path / "index_config.json").open("w") as f:
        json.dump(asdict(cfg), f, indent=2)

    ds = setup_data_pipeline(cfg)

    # BUG FIX: worker(rank, world_size, cfg, ds, ...) takes cfg before ds,
    # so the constant args must be passed in [cfg, ds] order — the previous
    # [ds, cfg] order swapped them. NOTE(review): assumes
    # launch_distributed_run forwards const args positionally after
    # rank/world_size — confirm against dist_worker.
    launch_distributed_run("build", worker, [cfg, ds])

    shutil.move(cfg.partial_run_path, cfg.run_path)
281-
282-
283-
def query(cfg: IndexConfig, query_cfg: QueryConfig):
    """Score a dataset against query gradients: persist the config,
    preprocess the data, run scoring workers, then promote the partial run
    directory to its final location."""
    cfg.partial_run_path.mkdir(parents=True, exist_ok=True)
    with (cfg.partial_run_path / "index_config.json").open("w") as f:
        json.dump(asdict(cfg), f, indent=2)

    ds = setup_data_pipeline(cfg)

    # BUG FIX: worker(rank, world_size, cfg, ds, query_cfg) takes cfg before
    # ds, so the constant args must be [cfg, ds, query_cfg] — the previous
    # [ds, cfg, query_cfg] order swapped cfg and ds. NOTE(review): assumes
    # launch_distributed_run forwards const args positionally after
    # rank/world_size — confirm against dist_worker.
    launch_distributed_run("query", worker, [cfg, ds, query_cfg])

    shutil.move(cfg.partial_run_path, cfg.run_path)

0 commit comments

Comments
 (0)