File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -800,11 +800,11 @@ def cmd_run(args: argparse.Namespace) -> None:
800800 print ("\n Causal validation: no H-Neurons found" )
801801 probe .cv_results_ = None
802802 else :
803+ import math
803804 import random as _random
805+ import torch
804806
805807 val_n = min (args .causal_samples , len (samples ))
806- val_rng = _random .Random (args .seed + 42 )
807- val_subset = val_rng .sample (samples , val_n )
808808 alphas = [0.0 , 1.0 , 2.0 ]
809809 n_seeds = args .causal_seeds
810810
@@ -818,6 +818,9 @@ def cmd_run(args: argparse.Namespace) -> None:
818818 all_per_sample = {a : [] for a in alphas }
819819
820820 for s in range (n_seeds ):
821+ subset_rng = _random .Random (args .seed + 42 + s )
822+ val_subset = subset_rng .sample (samples , val_n )
823+ torch .manual_seed (args .seed + 1000 + s )
821824 h_rates , h_per_sample = _causal_validation_run (
822825 model ,
823826 tokenizer ,
@@ -834,8 +837,6 @@ def cmd_run(args: argparse.Namespace) -> None:
834837 all_per_sample [a ].append (h_per_sample [a ])
835838
836839 # Report mean +/- SE
837- import math
838-
839840 causal_info = {"h_neurons" : {}, "n_samples" : val_n , "n_seeds" : n_seeds }
840841 for alpha in alphas :
841842 rates = all_h_rates [alpha ]
You can’t perform that action at this time.
0 commit comments