Skip to content

Commit 8277be6

Browse files
fix: sample different validation subset per causal seed for true statistical variance
1 parent 5211a48 commit 8277be6

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/hprobes/cli.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -800,11 +800,11 @@ def cmd_run(args: argparse.Namespace) -> None:
800800
print("\n Causal validation: no H-Neurons found")
801801
probe.cv_results_ = None
802802
else:
803+
import math
803804
import random as _random
805+
import torch
804806

805807
val_n = min(args.causal_samples, len(samples))
806-
val_rng = _random.Random(args.seed + 42)
807-
val_subset = val_rng.sample(samples, val_n)
808808
alphas = [0.0, 1.0, 2.0]
809809
n_seeds = args.causal_seeds
810810

@@ -818,6 +818,9 @@ def cmd_run(args: argparse.Namespace) -> None:
818818
all_per_sample = {a: [] for a in alphas}
819819

820820
for s in range(n_seeds):
821+
subset_rng = _random.Random(args.seed + 42 + s)
822+
val_subset = subset_rng.sample(samples, val_n)
823+
torch.manual_seed(args.seed + 1000 + s)
821824
h_rates, h_per_sample = _causal_validation_run(
822825
model,
823826
tokenizer,
@@ -834,8 +837,6 @@ def cmd_run(args: argparse.Namespace) -> None:
834837
all_per_sample[a].append(h_per_sample[a])
835838

836839
# Report mean +/- SE
837-
import math
838-
839840
causal_info = {"h_neurons": {}, "n_samples": val_n, "n_seeds": n_seeds}
840841
for alpha in alphas:
841842
rates = all_h_rates[alpha]

0 commit comments

Comments
 (0)