|
1 | | -import json |
2 | | -import os |
3 | | -import shutil |
4 | 1 | import socket |
5 | | -from dataclasses import asdict |
6 | | -from datetime import timedelta |
7 | | -from pathlib import Path |
8 | 2 | from typing import Any, Callable |
9 | 3 |
|
10 | | -import pandas as pd |
11 | 4 | import torch |
12 | 5 | import torch.distributed as dist |
13 | 6 | import torch.multiprocessing as mp |
14 | | -from datasets import ( |
15 | | - Dataset, |
16 | | - DatasetDict, |
17 | | - IterableDataset, |
18 | | - IterableDatasetDict, |
19 | | - load_dataset, |
20 | | -) |
21 | 7 | from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes |
22 | | -from tqdm.auto import tqdm |
23 | | -from transformers import ( |
24 | | - AutoTokenizer, |
25 | | -) |
26 | | - |
27 | | -from bergson.collection import collect_gradients |
28 | | -from bergson.data import ( |
29 | | - DataConfig, |
30 | | - IndexConfig, |
31 | | - QueryConfig, |
32 | | - allocate_batches, |
33 | | - tokenize, |
34 | | -) |
35 | | -from bergson.query import get_query_grads |
36 | | -from bergson.score_writer import MemmapScoreWriter |
37 | | -from bergson.scorer import get_scorer |
38 | | -from bergson.utils import assert_type |
39 | | -from bergson.worker_utils import create_processor, setup_model_and_peft |
40 | | - |
41 | | - |
42 | | -def estimate_advantage(ds: Dataset, cfg: DataConfig): |
43 | | - """Group rollouts by prompt and estimate advantages.""" |
44 | | - df = ds.select_columns([cfg.prompt_column, cfg.reward_column]).to_pandas() |
45 | | - df = assert_type(pd.DataFrame, df) |
46 | | - |
47 | | - advantages = df[cfg.reward_column] - df.groupby(cfg.prompt_column)[ |
48 | | - cfg.reward_column |
49 | | - ].transform("mean") |
50 | | - |
51 | | - return advantages.tolist() |
52 | | - |
53 | | - |
54 | | -def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset: |
55 | | - """Handle data loading and preprocessing""" |
56 | | - |
57 | | - data_str = cfg.data.dataset |
58 | | - if data_str.endswith(".csv"): |
59 | | - ds = assert_type(Dataset, Dataset.from_csv(data_str)) |
60 | | - elif data_str.endswith(".json") or data_str.endswith(".jsonl"): |
61 | | - ds = assert_type(Dataset, Dataset.from_json(data_str)) |
62 | | - else: |
63 | | - try: |
64 | | - ds = load_dataset( |
65 | | - data_str, split=cfg.data.split, streaming=cfg.data.streaming |
66 | | - ) |
67 | | - |
68 | | - if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): |
69 | | - raise NotImplementedError( |
70 | | - "DatasetDicts and IterableDatasetDicts are not supported." |
71 | | - ) |
72 | | - except ValueError as e: |
73 | | - # Automatically use load_from_disk if appropriate |
74 | | - if "load_from_disk" in str(e): |
75 | | - ds = Dataset.load_from_disk(data_str, keep_in_memory=False) |
76 | | - else: |
77 | | - raise e |
78 | | - |
79 | | - # In many cases the token_batch_size may be smaller than the max length allowed by |
80 | | - # the model. If cfg.data.truncation is True, we use the tokenizer to truncate |
81 | | - tokenizer = AutoTokenizer.from_pretrained(cfg.model, revision=cfg.revision) |
82 | | - tokenizer.model_max_length = min(tokenizer.model_max_length, cfg.token_batch_size) |
83 | | - |
84 | | - remove_columns = ds.column_names if cfg.drop_columns else None |
85 | | - |
86 | | - ds = ds.map( |
87 | | - tokenize, |
88 | | - batched=True, |
89 | | - fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer), |
90 | | - remove_columns=remove_columns, |
91 | | - ) |
92 | | - |
93 | | - if cfg.data.reward_column: |
94 | | - assert isinstance(ds, Dataset), "Dataset required for advantage estimation" |
95 | | - ds = ds.add_column( |
96 | | - "advantage", |
97 | | - estimate_advantage(ds, cfg.data), |
98 | | - new_fingerprint="advantage", # type: ignore |
99 | | - ) |
100 | | - |
101 | | - return ds |
102 | | - |
103 | | - |
104 | | -def worker( |
105 | | - rank: int, |
106 | | - world_size: int, |
107 | | - cfg: IndexConfig, |
108 | | - ds: Dataset | IterableDataset, |
109 | | - query_cfg: QueryConfig | None = None, |
110 | | -): |
111 | | - torch.cuda.set_device(rank) |
112 | | - |
113 | | - # These should be set by the main process |
114 | | - if world_size > 1: |
115 | | - addr = os.environ.get("MASTER_ADDR", "localhost") |
116 | | - port = os.environ.get("MASTER_PORT", "29500") |
117 | | - |
118 | | - dist.init_process_group( |
119 | | - "nccl", |
120 | | - init_method=f"tcp://{addr}:{port}", |
121 | | - device_id=torch.device(f"cuda:{rank}"), |
122 | | - rank=rank, |
123 | | - timeout=timedelta(hours=1), |
124 | | - world_size=world_size, |
125 | | - ) |
126 | | - |
127 | | - model, target_modules = setup_model_and_peft(cfg, rank) |
128 | | - processor = create_processor(cfg, rank) |
129 | | - |
130 | | - attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules} |
131 | | - |
132 | | - if query_cfg is not None: |
133 | | - query_grads = get_query_grads( |
134 | | - query_cfg, torch.device(f"cuda:{rank}"), torch.float32 |
135 | | - ) |
136 | | - num_scores = len(query_grads[query_cfg.modules[0]]) |
137 | | - |
138 | | - kwargs = { |
139 | | - "model": model, |
140 | | - "data": ds, |
141 | | - "processor": processor, |
142 | | - "cfg": cfg, |
143 | | - "target_modules": target_modules, |
144 | | - "attention_cfgs": attention_cfgs, |
145 | | - } |
146 | | - |
147 | | - if isinstance(ds, Dataset): |
148 | | - batches = allocate_batches(ds["length"], cfg.token_batch_size) |
149 | | - kwargs["batches"] = batches |
150 | | - |
151 | | - if query_cfg is not None: |
152 | | - score_writer = MemmapScoreWriter( |
153 | | - Path(query_cfg.scores_path), |
154 | | - len(ds), |
155 | | - num_scores, |
156 | | - rank=rank, |
157 | | - ) |
158 | | - scorer = get_scorer( |
159 | | - query_grads, |
160 | | - query_cfg, |
161 | | - score_writer, |
162 | | - cfg.module_wise, |
163 | | - torch.device(f"cuda:{rank}"), |
164 | | - torch.float32, |
165 | | - ) |
166 | | - kwargs["scorer"] = scorer |
167 | | - |
168 | | - collect_gradients(**kwargs) |
169 | | - else: |
170 | | - # Convert each shard to a Dataset then map over its gradients |
171 | | - buf, shard_id = [], 0 |
172 | | - |
173 | | - def flush(kwargs): |
174 | | - nonlocal buf, shard_id |
175 | | - if not buf: |
176 | | - return |
177 | | - ds_shard = assert_type(Dataset, Dataset.from_list(buf)) |
178 | | - batches = allocate_batches(ds_shard["length"][:], cfg.token_batch_size) |
179 | | - kwargs["ds"] = ds_shard |
180 | | - kwargs["batches"] = batches |
181 | | - |
182 | | - if query_cfg is not None: |
183 | | - score_writer = MemmapScoreWriter( |
184 | | - Path(query_cfg.scores_path) / f"shard-{shard_id:05d}", |
185 | | - len(ds_shard), |
186 | | - num_scores, |
187 | | - rank=rank, |
188 | | - ) |
189 | | - scorer = get_scorer( |
190 | | - query_grads, |
191 | | - query_cfg, |
192 | | - score_writer, |
193 | | - cfg.module_wise, |
194 | | - torch.device(f"cuda:{rank}"), |
195 | | - torch.float32, |
196 | | - ) |
197 | | - kwargs["scorer"] = scorer |
198 | | - |
199 | | - collect_gradients(**kwargs) |
200 | | - |
201 | | - buf.clear() |
202 | | - shard_id += 1 |
203 | | - |
204 | | - for ex in tqdm(ds, desc="Collecting gradients"): |
205 | | - buf.append(ex) |
206 | | - if len(buf) == cfg.stream_shard_size: |
207 | | - flush(kwargs=kwargs) |
208 | | - |
209 | | - flush(kwargs=kwargs) # Final flush |
210 | | - if rank == 0: |
211 | | - processor.save(cfg.partial_run_path) |
212 | 8 |
|
213 | 9 |
|
214 | 10 | def dist_worker( |
@@ -266,27 +62,3 @@ def launch_distributed_run(process_name: str, worker, const_worker_args: list[An |
266 | 62 | finally: |
267 | 63 | if ctx is not None: |
268 | 64 | ctx.close() # Kill any processes that are still running |
269 | | - |
270 | | - |
271 | | -def build(cfg: IndexConfig): |
272 | | - cfg.partial_run_path.mkdir(parents=True, exist_ok=True) |
273 | | - with (cfg.partial_run_path / "index_config.json").open("w") as f: |
274 | | - json.dump(asdict(cfg), f, indent=2) |
275 | | - |
276 | | - ds = setup_data_pipeline(cfg) |
277 | | - |
278 | | - launch_distributed_run("build", worker, [ds, cfg]) |
279 | | - |
280 | | - shutil.move(cfg.partial_run_path, cfg.run_path) |
281 | | - |
282 | | - |
283 | | -def query(cfg: IndexConfig, query_cfg: QueryConfig): |
284 | | - cfg.partial_run_path.mkdir(parents=True, exist_ok=True) |
285 | | - with (cfg.partial_run_path / "index_config.json").open("w") as f: |
286 | | - json.dump(asdict(cfg), f, indent=2) |
287 | | - |
288 | | - ds = setup_data_pipeline(cfg) |
289 | | - |
290 | | - launch_distributed_run("query", worker, [ds, cfg, query_cfg]) |
291 | | - |
292 | | - shutil.move(cfg.partial_run_path, cfg.run_path) |
0 commit comments