Skip to content

Commit 82351ff

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent bb3eac0 commit 82351ff

18 files changed

+299
-172
lines changed

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ Mark tests requiring GPUs with `@pytest.mark.skipif(not torch.cuda.is_available(
2424

2525
### Environment Setup
2626

27-
If you use need to use a venv, create and/or activate it with `python3 -m venv .venv && source .venv/bin/activate && pip install pytest`.
27+
If you use need to use a venv, create and/or activate it with `python3 -m venv .venv && source .venv/bin/activate && pip install pytest`.

benchmarks/benchmark_bergson.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import json
66
import os
77
import sys
8-
import traceback
98
import time
9+
import traceback
1010
from dataclasses import asdict, dataclass
1111
from pathlib import Path
1212
from typing import Optional, Union
@@ -18,6 +18,13 @@
1818
from torch.distributed.fsdp import fully_shard
1919
from transformers import AutoModelForCausalLM, AutoTokenizer
2020

21+
from benchmarks.benchmark_utils import (
22+
DEFAULT_DATASET,
23+
MODEL_SPECS,
24+
get_run_path,
25+
parse_tokens,
26+
timestamp,
27+
)
2128
from bergson.build import build
2229
from bergson.collector.collector import CollectorComputer
2330
from bergson.collector.in_memory_collector import InMemoryCollector
@@ -32,13 +39,6 @@
3239
get_optimal_batch_size,
3340
)
3441
from bergson.utils.utils import assert_type, get_layer_list
35-
from benchmarks.benchmark_utils import (
36-
DEFAULT_DATASET,
37-
MODEL_SPECS,
38-
parse_tokens,
39-
timestamp,
40-
get_run_path,
41-
)
4242

4343
SCHEMA_VERSION = 1
4444
DEFAULT_TRAIN_SPLIT = "train"
@@ -70,7 +70,9 @@ class RunRecord:
7070
notes: str | None
7171
error: str | None
7272
num_gpus: int = 1 # Default for backwards compatibility
73-
token_batch_size: int | None = None # Auto-determined or configured token batch size
73+
token_batch_size: int | None = (
74+
None # Auto-determined or configured token batch size
75+
)
7476

7577

7678
@dataclass
@@ -181,7 +183,9 @@ def execute(self) -> None:
181183
run_path = (
182184
Path(self.cfg.run_path).resolve()
183185
if self.cfg.run_path
184-
else get_run_path(run_root, spec, train_tokens, eval_tokens, self.cfg.tag, num_gpus)
186+
else get_run_path(
187+
run_root, spec, train_tokens, eval_tokens, self.cfg.tag, num_gpus
188+
)
185189
)
186190

187191
start_wall = timestamp()
@@ -218,7 +222,7 @@ def execute(self) -> None:
218222
spec.hf_id, torch_dtype=torch.bfloat16, device_map=device_map
219223
)
220224

221-
model = model.cuda() # type: ignore
225+
model = model.cuda() # type: ignore
222226

223227
# Wrap model with FSDP
224228
embed = model.get_input_embeddings()
@@ -275,10 +279,12 @@ def tokenize(batch):
275279
eval_dataset = eval_dataset.map(tokenize, batched=True)
276280

277281
train_dataset.set_format(
278-
type="torch", columns=["input_ids", "attention_mask", "labels", "length"]
282+
type="torch",
283+
columns=["input_ids", "attention_mask", "labels", "length"],
279284
)
280285
eval_dataset.set_format(
281-
type="torch", columns=["input_ids", "attention_mask", "labels", "length"]
286+
type="torch",
287+
columns=["input_ids", "attention_mask", "labels", "length"],
282288
)
283289

284290
# Determine optimal token_batch_size if requested
@@ -331,8 +337,7 @@ def tokenize(batch):
331337

332338
# Create batches for CollectorComputer
333339
batches = allocate_batches(
334-
train_dataset["length"], # type: ignore
335-
optimal_token_batch_size
340+
train_dataset["length"], optimal_token_batch_size # type: ignore
336341
)
337342

338343
# Use CollectorComputer to process training data
@@ -347,7 +352,8 @@ def tokenize(batch):
347352

348353
# Concatenate all training gradients
349354
train_grads_flat = {
350-
name: torch.cat(grads, dim=0) for name, grads in train_collector.gradients.items()
355+
name: torch.cat(grads, dim=0)
356+
for name, grads in train_collector.gradients.items()
351357
}
352358

353359
reduce_time = time.perf_counter() - reduce_start
@@ -363,7 +369,9 @@ def tokenize(batch):
363369
all_scores = []
364370

365371
# Limit eval examples
366-
eval_subset = eval_dataset.select(range(min(self.cfg.max_eval_examples, len(eval_dataset))))
372+
eval_subset = eval_dataset.select(
373+
range(min(self.cfg.max_eval_examples, len(eval_dataset)))
374+
)
367375

368376
for i in range(len(eval_subset)):
369377
# Create single-example dataset
@@ -387,16 +395,17 @@ def tokenize(batch):
387395

388396
# Concatenate test gradients
389397
test_grads = {
390-
name: torch.cat(grads, dim=0) for name, grads in test_collector.gradients.items()
398+
name: torch.cat(grads, dim=0)
399+
for name, grads in test_collector.gradients.items()
391400
}
392401

393402
# Compute inner products (no normalization, no preconditioning)
394403
scores = torch.zeros(len(train_dataset), device="cpu")
395404
for name in test_grads:
396405
if name in train_grads_flat:
397-
scores += (
398-
test_grads[name] @ train_grads_flat[name].T
399-
).squeeze(0)
406+
scores += (test_grads[name] @ train_grads_flat[name].T).squeeze(
407+
0
408+
)
400409

401410
all_scores.append(scores)
402411

@@ -474,7 +483,9 @@ def execute(self) -> None:
474483
run_path = (
475484
Path(self.cfg.run_path).resolve()
476485
if self.cfg.run_path
477-
else get_run_path(run_root, spec, train_tokens, eval_tokens, self.cfg.tag, 1)
486+
else get_run_path(
487+
run_root, spec, train_tokens, eval_tokens, self.cfg.tag, 1
488+
)
478489
)
479490

480491
start_wall = timestamp()

benchmarks/benchmark_bergson_cli.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from __future__ import annotations
44

5-
import shutil
65
import json
6+
import platform
7+
import shutil
78
import subprocess
89
import sys
910
import time
@@ -12,15 +13,17 @@
1213
from typing import Optional
1314

1415
from simple_parsing import ArgumentParser, ConflictResolution, field
15-
import platform
1616

17-
from bergson.utils.auto_batch_size import determine_batch_size_cli, get_optimal_batch_size
1817
from benchmarks.benchmark_utils import (
1918
DEFAULT_DATASET,
2019
MODEL_SPECS,
20+
get_run_path,
2121
parse_tokens,
2222
timestamp,
23-
get_run_path,
23+
)
24+
from bergson.utils.auto_batch_size import (
25+
determine_batch_size_cli,
26+
get_optimal_batch_size,
2427
)
2528

2629
SCHEMA_VERSION = 1
@@ -102,6 +105,7 @@ def get_hardware_info() -> str:
102105
"""Get hardware information string."""
103106
try:
104107
import torch
108+
105109
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
106110
gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
107111
return f"{platform.node()} ({gpu_count}x {gpu_name})"
@@ -128,7 +132,11 @@ def run_cli_command(cmd: list[str], description: str) -> tuple[bool, float, str]
128132
)
129133
elapsed = time.perf_counter() - start
130134
if result.returncode != 0:
131-
return False, elapsed, f"{description} failed with return code {result.returncode}"
135+
return (
136+
False,
137+
elapsed,
138+
f"{description} failed with return code {result.returncode}",
139+
)
132140
print(f"{description} completed in {elapsed:.2f}s")
133141
return True, elapsed, ""
134142
except Exception as e:
@@ -219,14 +227,19 @@ def execute(self) -> None:
219227
f" Completed at {existing_run.end_time} "
220228
f"(runtime: {existing_run.total_runtime_seconds:.1f}s)"
221229
)
222-
print(
223-
f" Use --skip_existing=False to force re-run"
224-
)
230+
print(" Use --skip_existing=False to force re-run")
225231
return
226232
benchmark_path = (
227233
Path(self.run_cfg.run_path).resolve()
228234
if self.run_cfg.run_path
229-
else get_run_path(run_root, spec, train_tokens, eval_seqs, self.run_cfg.tag, self.run_cfg.num_gpus)
235+
else get_run_path(
236+
run_root,
237+
spec,
238+
train_tokens,
239+
eval_seqs,
240+
self.run_cfg.tag,
241+
self.run_cfg.num_gpus,
242+
)
230243
)
231244

232245
# Create directories for bergson artifacts
@@ -351,9 +364,7 @@ def execute(self) -> None:
351364
end_wall = timestamp()
352365

353366
token_batch_size = (
354-
optimal_token_batch_size
355-
if self.run_cfg.auto_batch_size
356-
else None
367+
optimal_token_batch_size if self.run_cfg.auto_batch_size else None
357368
)
358369

359370
record = CLIRunRecord(

benchmarks/benchmark_dattri.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from __future__ import annotations
44

5-
import os
65
import argparse
76
import json
7+
import os
88
import sys
99
import textwrap
1010
import time
@@ -18,12 +18,15 @@
1818
from dattri.task import AttributionTask
1919
from transformers import AutoModelForCausalLM, AutoTokenizer
2020

21-
from bergson.utils.utils import assert_type
22-
2321
# Import from same directory
2422
from benchmarks.benchmark_utils import (
25-
MODEL_SPECS, DEFAULT_DATASET, parse_tokens, timestamp, get_run_path
23+
DEFAULT_DATASET,
24+
MODEL_SPECS,
25+
get_run_path,
26+
parse_tokens,
27+
timestamp,
2628
)
29+
from bergson.utils.utils import assert_type
2730

2831
SCHEMA_VERSION = 1
2932
DEFAULT_TRAIN_SPLIT = "train"
@@ -123,7 +126,9 @@ def tokenize(batch):
123126

124127
# Select enough examples
125128
total_needed = train_examples_needed + eval_examples_needed
126-
train_dataset = train_dataset.select(range(min(total_needed, len(train_dataset))))
129+
train_dataset = train_dataset.select(
130+
range(min(total_needed, len(train_dataset)))
131+
)
127132

128133
eval_dataset = train_dataset.select(
129134
range(train_examples_needed, train_examples_needed + eval_examples_needed)
@@ -140,7 +145,9 @@ def collate_fn(batch):
140145
# Dattri expects tuples of (input_ids, labels) where labels = input_ids for language modeling
141146
# Keep on CPU - dattri will handle device placement
142147
input_ids = torch.stack([item["input_ids"] for item in batch])
143-
labels = input_ids.clone() # For language modeling, labels are the same as input_ids
148+
labels = (
149+
input_ids.clone()
150+
) # For language modeling, labels are the same as input_ids
144151
return (input_ids, labels)
145152

146153
train_loader = torch.utils.data.DataLoader(
@@ -156,7 +163,7 @@ def collate_fn(batch):
156163

157164
# Get model device
158165
model_device = next(model.parameters()).device
159-
166+
160167
def loss_func(params, data_target_pair):
161168
x, y = data_target_pair
162169
# Ensure data is on the same device as model
@@ -169,7 +176,7 @@ def loss_func(params, data_target_pair):
169176
if isinstance(output, tuple):
170177
logits = output[0] # First element is logits
171178
else:
172-
logits = output.logits if hasattr(output, 'logits') else output
179+
logits = output.logits if hasattr(output, "logits") else output
173180
shift_logits = logits[:, :-1].contiguous()
174181
shift_labels = y[:, 1:].contiguous()
175182
loss = nn.CrossEntropyLoss()(
@@ -180,8 +187,8 @@ def loss_func(params, data_target_pair):
180187
# Create task
181188
task = AttributionTask(
182189
loss_func=loss_func,
183-
model=model,
184-
checkpoints=model.state_dict(),
190+
model=model,
191+
checkpoints=model.state_dict(),
185192
)
186193

187194
# Create attributor and cache
@@ -203,6 +210,7 @@ def loss_func(params, data_target_pair):
203210
status = "error"
204211
error_message = repr(exc)
205212
import traceback
213+
206214
traceback.print_exc()
207215

208216
runtime = time.perf_counter() - start
@@ -273,7 +281,9 @@ def main(argv: list[str] | None = None) -> None:
273281
)
274282
run_parser.add_argument("--batch-size", type=int, default=4)
275283
run_parser.add_argument("--max-length", type=int, default=512)
276-
run_parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs to use")
284+
run_parser.add_argument(
285+
"--num-gpus", type=int, default=1, help="Number of GPUs to use"
286+
)
277287
run_parser.add_argument("--dataset", default=DEFAULT_DATASET)
278288
run_parser.add_argument("--train-split", default=DEFAULT_TRAIN_SPLIT)
279289
run_parser.add_argument("--eval-split", default=DEFAULT_EVAL_SPLIT)

benchmarks/benchmark_utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from dataclasses import dataclass
2-
from datetime import datetime
3-
from datetime import timezone
2+
from datetime import datetime, timezone
43
from pathlib import Path
54

65
from datasets import Dataset, load_from_disk
@@ -9,6 +8,7 @@
98
TOKENIZED_DATASET_PATH = "data/EleutherAI/SmolLM2-135M-10B-tokenized"
109
MAX_BENCHMARK_LENGTH = 1024
1110

11+
1212
@dataclass(frozen=True)
1313
class ModelSpec:
1414
key: str
@@ -44,7 +44,12 @@ def get_run_path(
4444

4545

4646
def timestamp() -> str:
47-
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
47+
return (
48+
datetime.now(timezone.utc)
49+
.replace(microsecond=0)
50+
.isoformat()
51+
.replace("+00:00", "Z")
52+
)
4853

4954

5055
def format_tokens(tokens: int) -> str:
@@ -111,7 +116,9 @@ def load_benchmark_dataset(
111116
total_tokens_before = sum(len(tokens) for tokens in ds["input_ids"])
112117
num_examples_before = len(ds)
113118

114-
print(f"Dataset loaded: {num_examples_before:,} examples, {total_tokens_before:,} tokens")
119+
print(
120+
f"Dataset loaded: {num_examples_before:,} examples, {total_tokens_before:,} tokens"
121+
)
115122

116123
# Filter to only sequences >= min_length
117124
print(f"Filtering sequences to length >= {min_length}...")
@@ -124,9 +131,11 @@ def load_benchmark_dataset(
124131
num_examples_removed = num_examples_before - num_examples_after
125132
tokens_removed = total_tokens_before - total_tokens_after
126133

127-
print(f"\nFiltered dataset:")
134+
print("\nFiltered dataset:")
128135
print(f" Examples: {num_examples_after:,} (removed {num_examples_removed:,})")
129136
print(f" Tokens: {total_tokens_after:,} (removed {tokens_removed:,})")
130-
print(f" Average length: {total_tokens_after / num_examples_after:.1f} tokens/example")
137+
print(
138+
f" Average length: {total_tokens_after / num_examples_after:.1f} tokens/example"
139+
)
131140

132141
return ds

0 commit comments

Comments
 (0)