@@ -115,6 +115,13 @@ class FilterConfig:
115115 projection_dim : int = 16
116116 """Projection dimension for gradient index."""
117117
118+ test_size : float = 0.05
119+
120+ tag : str = ""
121+
122+ pdbs : int = 8
123+ """Per-device batch size."""
124+
118125
119126def run_sft (
120127 cfg : FilterConfig ,
@@ -145,10 +152,13 @@ def run_sft(
145152 bias = "none" ,
146153 task_type = "CAUSAL_LM" ,
147154 )
148- model = get_peft_model (model , lora_config )
149- model .print_trainable_parameters ()
155+ model = get_peft_model (model , lora_config ) # type: ignore
156+ model .print_trainable_parameters () # type: ignore
150157
151- num_train_steps = (len (train ) // 32 ) * cfg .num_epochs
158+ effective_batch_size = 32
159+ world_size = int (os .environ .get ("WORLD_SIZE" , 1 ))
160+ grad_acc_steps = effective_batch_size // world_size // cfg .pdbs
161+ num_train_steps = (len (train ) // effective_batch_size ) * cfg .num_epochs
152162 eval_steps = max (1 , num_train_steps // 10 )
153163
154164 trainer = SFTTrainer (
@@ -158,9 +168,9 @@ def run_sft(
158168 args = SFTConfig (
159169 max_length = 2048 ,
160170 output_dir = output_dir ,
161- per_device_train_batch_size = 1 ,
162- per_device_eval_batch_size = 1 ,
163- gradient_accumulation_steps = 32 ,
171+ per_device_train_batch_size = cfg . pdbs ,
172+ per_device_eval_batch_size = cfg . pdbs ,
173+ gradient_accumulation_steps = grad_acc_steps ,
164174 gradient_checkpointing = True ,
165175 learning_rate = 3e-4 ,
166176 num_train_epochs = cfg .num_epochs ,
@@ -291,6 +301,8 @@ def run_trackstar(
291301 "--nproc_per_node" ,
292302 str (num_gpus ),
293303 "--overwrite" ,
304+ "--index_cfg.precision" ,
305+ "bf16" ,
294306 ]
295307 # PEFT models need explicit tokenizer since adapter dir has no tokenizer config
296308 if args .use_lora :
@@ -332,7 +344,7 @@ def sft_full(args: FilterConfig, output_dir: str) -> str:
332344 if args .max_samples :
333345 dataset = dataset .select (range (min (args .max_samples , len (dataset ))))
334346
335- split = dataset .train_test_split (test_size = 0.05 , seed = args .seed )
347+ split = dataset .train_test_split (test_size = args . test_size , seed = args .seed )
336348 train_ds , eval_ds = split ["train" ], split ["test" ]
337349
338350 tokenizer = AutoTokenizer .from_pretrained (args .model , max_length = 8192 )
@@ -506,18 +518,16 @@ def main(
506518
507519 # Always load the original text dataset for training.
508520 # Don't shuffle here — order must match the gradient index built by bergson.
509- orig_dataset = assert_type (Dataset , load_dataset (args .dataset , split = args .split ))
521+ ds = assert_type (Dataset , load_dataset (args .dataset , split = args .split ))
510522 if args .max_samples :
511- orig_dataset = orig_dataset .select (
512- range (min (args .max_samples , len (orig_dataset )))
513- )
523+ ds = ds .select (range (min (args .max_samples , len (ds ))))
514524
515525 # Add original index column so we can map back after train_test_split shuffles
516- orig_dataset = orig_dataset .add_column ("_orig_idx" , list (range (len (orig_dataset ))))
526+ ds = ds .add_column ("_orig_idx" , list (range (len (ds ))))
517527
518528 # Split original dataset (same seed ensures consistent eval set)
519529 print ("Splitting..." )
520- orig_split = orig_dataset .train_test_split (test_size = 0.05 , seed = args .seed )
530+ orig_split = ds .train_test_split (test_size = args . test_size , seed = args .seed )
521531 orig_train , orig_eval = orig_split ["train" ], orig_split ["test" ]
522532
523533 model_name = args .model .split ("/" )[- 1 ]
@@ -526,12 +536,13 @@ def main(
526536 lora_suffix = "_lora" if args .use_lora else ""
527537 proj_suffix = f"_p{ args .projection_dim } " if args .projection_dim != 16 else ""
528538
529- if args .filter in ("attribution" , "loss" ):
530- # Step 1: SFT on the full dataset so gradients are meaningful
539+ if args .filter in ("attribution" , "loss" , "trackstar" ):
540+ # SFT on the full dataset so training statistics can be collected
531541 sft_dir = f"examples/runs/{ model_name } _{ dataset_name } _sft{ lora_suffix } "
532542 sft_model_path = sft_full (args , sft_dir )
533543
534- # Step 2: Build gradient index using the finetuned checkpoint
544+ if args .filter in ("attribution" , "loss" ):
545+ # Collect gradients and losses using the fine-tuned model
535546 if not args .index_dataset :
536547 args .index_dataset = (
537548 f"examples/runs/{ model_name } _{ dataset_name } "
@@ -541,20 +552,17 @@ def main(
541552 build_index (args , args .index_dataset , model = sft_model_path )
542553 grad_dataset = load_gradient_dataset (Path (args .index_dataset ), structured = False )
543554
544- # Split gradient dataset the same way
545- grad_split = grad_dataset .train_test_split (test_size = 0.05 , seed = args .seed )
555+ # Split resulting data to match the original train/test split
556+ grad_split = grad_dataset .train_test_split (
557+ test_size = args .test_size , seed = args .seed
558+ )
546559 grad_train = grad_split ["train" ]
547560 grad_train .set_format ("torch" )
548-
549561 elif args .filter == "trackstar" :
550- # Step 1: SFT on the full dataset so gradients are meaningful
551- sft_dir = f"examples/runs/{ model_name } _{ dataset_name } _sft{ lora_suffix } "
552- sft_model_path = sft_full (args , sft_dir )
553-
554562 # Step 2: Run trackstar pipeline for scoring
555563 trackstar_path = (
556564 f"examples/runs/{ model_name } _{ dataset_name } "
557- f"_trackstar{ lora_suffix } { proj_suffix } "
565+ f"_trackstar{ lora_suffix } { proj_suffix } { args . tag } "
558566 )
559567 run_trackstar (
560568 args ,
0 commit comments