Skip to content

Commit cde783b

Browse files
committed
save
1 parent bdc9319 commit cde783b

File tree

2 files changed

+34
-27
lines changed

2 files changed

+34
-27
lines changed

bergson/trackstar.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
IndexConfig,
77
PreprocessConfig,
88
ScoreConfig,
9-
TrackstarConfig,
9+
TrackStarConfig,
1010
)
1111
from .process_grads import mix_preconditioners
1212
from .score.score import score_dataset
@@ -48,7 +48,7 @@ def trackstar(
4848
index_cfg: IndexConfig,
4949
score_cfg: ScoreConfig,
5050
preprocess_cfg: PreprocessConfig,
51-
trackstar_cfg: TrackstarConfig,
51+
trackstar_cfg: TrackStarConfig,
5252
):
5353
"""Run the full trackstar pipeline: preconditioners -> mix -> build -> score."""
5454
run_path = index_cfg.run_path
@@ -57,7 +57,6 @@ def trackstar(
5757
mixed_preconditioner_path = f"{run_path}/mixed_preconditioner"
5858
query_path = f"{run_path}/query"
5959
scores_path = f"{run_path}/scores"
60-
resume = trackstar_cfg.resume
6160

6261
# Steps 1-2 only compute preconditioners, so don't preprocess grads.
6362
precond_preprocess_cfg = PreprocessConfig()

examples/filter_data.py

Lines changed: 32 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,13 @@ class FilterConfig:
115115
projection_dim: int = 16
116116
"""Projection dimension for gradient index."""
117117

118+
test_size: float = 0.05
119+
120+
tag: str = ""
121+
122+
pdbs: int = 8
123+
"Per-device batch size"
124+
118125

119126
def run_sft(
120127
cfg: FilterConfig,
@@ -145,10 +152,13 @@ def run_sft(
145152
bias="none",
146153
task_type="CAUSAL_LM",
147154
)
148-
model = get_peft_model(model, lora_config)
149-
model.print_trainable_parameters()
155+
model = get_peft_model(model, lora_config) # type: ignore
156+
model.print_trainable_parameters() # type: ignore
150157

151-
num_train_steps = (len(train) // 32) * cfg.num_epochs
158+
effective_batch_size = 32
159+
world_size = int(os.environ.get("WORLD_SIZE", 1))
160+
grad_acc_steps = effective_batch_size / world_size / cfg.pdbs
161+
num_train_steps = (len(train) // effective_batch_size) * cfg.num_epochs
152162
eval_steps = max(1, num_train_steps // 10)
153163

154164
trainer = SFTTrainer(
@@ -158,9 +168,9 @@ def run_sft(
158168
args=SFTConfig(
159169
max_length=2048,
160170
output_dir=output_dir,
161-
per_device_train_batch_size=1,
162-
per_device_eval_batch_size=1,
163-
gradient_accumulation_steps=32,
171+
per_device_train_batch_size=cfg.pdbs,
172+
per_device_eval_batch_size=cfg.pdbs,
173+
gradient_accumulation_steps=grad_acc_steps,
164174
gradient_checkpointing=True,
165175
learning_rate=3e-4,
166176
num_train_epochs=cfg.num_epochs,
@@ -291,6 +301,8 @@ def run_trackstar(
291301
"--nproc_per_node",
292302
str(num_gpus),
293303
"--overwrite",
304+
"--index_cfg.precision",
305+
"bf16",
294306
]
295307
# PEFT models need explicit tokenizer since adapter dir has no tokenizer config
296308
if args.use_lora:
@@ -332,7 +344,7 @@ def sft_full(args: FilterConfig, output_dir: str) -> str:
332344
if args.max_samples:
333345
dataset = dataset.select(range(min(args.max_samples, len(dataset))))
334346

335-
split = dataset.train_test_split(test_size=0.05, seed=args.seed)
347+
split = dataset.train_test_split(test_size=args.test_size, seed=args.seed)
336348
train_ds, eval_ds = split["train"], split["test"]
337349

338350
tokenizer = AutoTokenizer.from_pretrained(args.model, max_length=8192)
@@ -506,18 +518,16 @@ def main(
506518

507519
# Always load the original text dataset for training.
508520
# Don't shuffle here — order must match the gradient index built by bergson.
509-
orig_dataset = assert_type(Dataset, load_dataset(args.dataset, split=args.split))
521+
ds = assert_type(Dataset, load_dataset(args.dataset, split=args.split))
510522
if args.max_samples:
511-
orig_dataset = orig_dataset.select(
512-
range(min(args.max_samples, len(orig_dataset)))
513-
)
523+
ds = ds.select(range(min(args.max_samples, len(ds))))
514524

515525
# Add original index column so we can map back after train_test_split shuffles
516-
orig_dataset = orig_dataset.add_column("_orig_idx", list(range(len(orig_dataset))))
526+
ds = ds.add_column("_orig_idx", list(range(len(ds))))
517527

518528
# Split original dataset (same seed ensures consistent eval set)
519529
print("Splitting...")
520-
orig_split = orig_dataset.train_test_split(test_size=0.05, seed=args.seed)
530+
orig_split = ds.train_test_split(test_size=args.test_size, seed=args.seed)
521531
orig_train, orig_eval = orig_split["train"], orig_split["test"]
522532

523533
model_name = args.model.split("/")[-1]
@@ -526,12 +536,13 @@ def main(
526536
lora_suffix = "_lora" if args.use_lora else ""
527537
proj_suffix = f"_p{args.projection_dim}" if args.projection_dim != 16 else ""
528538

529-
if args.filter in ("attribution", "loss"):
530-
# Step 1: SFT on the full dataset so gradients are meaningful
539+
if args.filter in ("attribution", "loss", "trackstar"):
540+
# SFT on the full dataset so training statistics can be collected
531541
sft_dir = f"examples/runs/{model_name}_{dataset_name}_sft{lora_suffix}"
532542
sft_model_path = sft_full(args, sft_dir)
533543

534-
# Step 2: Build gradient index using the finetuned checkpoint
544+
if args.filter in ("attribution", "loss"):
545+
# Collect gradients and losses using the fine-tuned model
535546
if not args.index_dataset:
536547
args.index_dataset = (
537548
f"examples/runs/{model_name}_{dataset_name}"
@@ -541,20 +552,17 @@ def main(
541552
build_index(args, args.index_dataset, model=sft_model_path)
542553
grad_dataset = load_gradient_dataset(Path(args.index_dataset), structured=False)
543554

544-
# Split gradient dataset the same way
545-
grad_split = grad_dataset.train_test_split(test_size=0.05, seed=args.seed)
555+
# Split resulting data to match the original train/test split
556+
grad_split = grad_dataset.train_test_split(
557+
test_size=args.test_size, seed=args.seed
558+
)
546559
grad_train = grad_split["train"]
547560
grad_train.set_format("torch")
548-
549561
elif args.filter == "trackstar":
550-
# Step 1: SFT on the full dataset so gradients are meaningful
551-
sft_dir = f"examples/runs/{model_name}_{dataset_name}_sft{lora_suffix}"
552-
sft_model_path = sft_full(args, sft_dir)
553-
554562
# Step 2: Run trackstar pipeline for scoring
555563
trackstar_path = (
556564
f"examples/runs/{model_name}_{dataset_name}"
557-
f"_trackstar{lora_suffix}{proj_suffix}"
565+
f"_trackstar{lora_suffix}{proj_suffix}{args.tag}"
558566
)
559567
run_trackstar(
560568
args,

0 commit comments

Comments (0)