Skip to content

Commit 94a1308

Browse files
author
David Johnston
committed
consistent train/eval template split
1 parent 5a9bd65 commit 94a1308

File tree

1 file changed

+163
-57
lines changed

1 file changed

+163
-57
lines changed

examples/semantic/asymmetric.py

Lines changed: 163 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ class AsymmetricConfig:
4545
seed: int = 42
4646
# HuggingFace dataset repo. If set, skips local generation and downloads from HF.
4747
hf_dataset: str | None = None
48+
# Template split for train/test segregation (only used for local generation)
49+
# Train uses templates < train_template_cutoff; eval sets use templates >= cutoff
50+
train_template_cutoff: int = 5
4851

4952

5053
def create_asymmetric_dataset(
@@ -57,6 +60,11 @@ def create_asymmetric_dataset(
5760
- Exclusive facts: only appear in dominant style (for testing semantic matching)
5861
- Shared facts: appear in both styles (for style ratio control)
5962
63+
For train/test segregation:
64+
- Training uses templates < train_template_cutoff (default: 0-4)
65+
- Eval majority style control uses templates >= cutoff (default: 5+)
66+
This ensures no exact text overlap between train and eval majority control.
67+
6068
Args:
6169
config: Experiment configuration.
6270
output_dir: Directory to save datasets.
@@ -103,29 +111,41 @@ def create_asymmetric_dataset(
103111
dominant_ds = style_datasets[config.dominant_style]
104112
minority_ds = style_datasets[config.minority_style]
105113

106-
# Get unique facts (by fact text)
107-
all_facts = list(set(original["fact"]))
108-
n_facts = len(all_facts)
114+
# Get unique (identifier, field) pairs - these represent underlying semantic facts
115+
# Each pair has multiple templates (different surface forms of the same fact)
116+
semantic_facts = list(
117+
{(row["identifier"], row["field"]) for row in original}
118+
)
119+
n_semantic_facts = len(semantic_facts)
109120

110-
# Split into exclusive (dominant-only) and shared
121+
# Split into exclusive (dominant-only) and shared by semantic fact
111122
rng = np.random.default_rng(config.seed)
112-
rng.shuffle(all_facts)
113-
114-
n_exclusive = int(n_facts * config.exclusive_ratio)
115-
exclusive_facts = set(all_facts[:n_exclusive])
116-
shared_facts = set(all_facts[n_exclusive:])
117-
118-
print(f"Total unique facts: {n_facts}")
119-
print(f"Exclusive to {config.dominant_style}: {len(exclusive_facts)}")
120-
print(f"Shared between styles: {len(shared_facts)}")
121-
122-
# Build training set
123-
# 1. All dominant style facts
124-
train_dominant = dominant_ds
123+
rng.shuffle(semantic_facts)
124+
125+
n_exclusive = int(n_semantic_facts * config.exclusive_ratio)
126+
exclusive_semantic_facts = set(semantic_facts[:n_exclusive])
127+
shared_semantic_facts = set(semantic_facts[n_exclusive:])
128+
129+
print(f"Total unique semantic facts (identifier, field pairs): {n_semantic_facts}")
130+
print(f"Exclusive to {config.dominant_style}: {len(exclusive_semantic_facts)}")
131+
print(f"Shared between styles: {len(shared_semantic_facts)}")
132+
print(f"Template cutoff for train/eval split: {config.train_template_cutoff}")
133+
134+
# Build training set with template filtering
135+
# 1. Dominant style: only templates < cutoff (to reserve rest for eval majority control)
136+
train_dominant_indices = [
137+
i
138+
for i, row in enumerate(dominant_ds)
139+
if row["template"] < config.train_template_cutoff
140+
]
141+
train_dominant = dominant_ds.select(train_dominant_indices)
125142

126-
# 2. Minority style only for shared facts
143+
# 2. Minority style only for shared facts (templates < cutoff, consistent with train)
127144
minority_shared_indices = [
128-
i for i, row in enumerate(minority_ds) if row["fact"] in shared_facts
145+
i
146+
for i, row in enumerate(minority_ds)
147+
if (row["identifier"], row["field"]) in shared_semantic_facts
148+
and row["template"] < config.train_template_cutoff
129149
]
130150
train_minority = minority_ds.select(minority_shared_indices)
131151

@@ -148,10 +168,14 @@ def create_asymmetric_dataset(
148168
print(f" Dominant ratio: {len(train_dominant) / len(train_ds):.2%}")
149169

150170
# Build eval set: query exclusive facts in minority style
171+
# Use templates >= cutoff to ensure no overlap with train
151172
# These facts don't exist in minority style in training, so the model
152173
# must use semantic matching (not style matching) to find them
153174
eval_minority_indices = [
154-
i for i, row in enumerate(minority_ds) if row["fact"] in exclusive_facts
175+
i
176+
for i, row in enumerate(minority_ds)
177+
if (row["identifier"], row["field"]) in exclusive_semantic_facts
178+
and row["template"] >= config.train_template_cutoff
155179
]
156180
eval_ds = minority_ds.select(eval_minority_indices)
157181
eval_ds = eval_ds.add_column("style", [config.minority_style] * len(eval_ds))
@@ -164,6 +188,7 @@ def create_asymmetric_dataset(
164188
print("\nEval set:")
165189
print(f" Queries in {config.minority_style} style: {len(eval_ds)}")
166190
print(f" Ground truth only in {config.dominant_style} style")
191+
print(f" Using templates >= {config.train_template_cutoff} (no overlap with train)")
167192

168193
# Save datasets
169194
train_ds.save_to_disk(str(train_path))
@@ -1048,40 +1073,81 @@ def compute_asymmetric_metrics_with_pca(
10481073
def create_majority_style_eval(
10491074
config: AsymmetricConfig,
10501075
base_path: Path | str,
1051-
) -> Path:
1076+
force_regenerate: bool = False,
1077+
) -> tuple[Path, bool]:
10521078
"""Create eval set using majority style (control for style mismatch).
10531079
10541080
Instead of using minority style queries, uses dominant style queries
10551081
for the exclusive facts. This shows baseline performance without style mismatch.
10561082
1083+
IMPORTANT: Uses templates >= train_template_cutoff to ensure NO overlap with
1084+
training data. This provides a proper train/test split where eval majority
1085+
style items test semantic matching (same fact, different surface form) rather
1086+
than exact text matching.
1087+
10571088
Args:
10581089
config: Experiment configuration.
10591090
base_path: Base path for experiment outputs.
1091+
force_regenerate: If True, regenerate even if cached version exists.
10601092
10611093
Returns:
1062-
Path to the majority style eval dataset.
1094+
Tuple of (path to the majority style eval dataset, has_leakage flag).
1095+
has_leakage is True if there's train/test overlap (e.g., from HF data).
10631096
"""
10641097
base_path = Path(base_path)
10651098
data_path = base_path / "data"
10661099
majority_eval_path = data_path / "eval_majority_style.hf"
10671100

1068-
if majority_eval_path.exists():
1101+
# Check for existing cached version
1102+
if majority_eval_path.exists() and not force_regenerate:
10691103
print(f"Loading cached majority style eval from {majority_eval_path}")
1070-
return majority_eval_path
1104+
1105+
# Check for train/test leakage by comparing reworded texts
1106+
train_ds = load_from_disk(str(data_path / "train.hf"))
1107+
majority_eval_ds = load_from_disk(str(majority_eval_path))
1108+
if isinstance(train_ds, DatasetDict):
1109+
train_ds = train_ds["train"]
1110+
if isinstance(majority_eval_ds, DatasetDict):
1111+
majority_eval_ds = majority_eval_ds["train"]
1112+
1113+
train_reworded = set(train_ds["reworded"])
1114+
eval_reworded = set(majority_eval_ds["reworded"])
1115+
overlap = train_reworded & eval_reworded
1116+
has_leakage = len(overlap) > 0
1117+
1118+
if has_leakage:
1119+
print(
1120+
f" WARNING: {len(overlap)}/{len(eval_reworded)} eval items have "
1121+
"exact text match in train (train/test leakage)"
1122+
)
1123+
print(" Use force_regenerate=True with local data to fix")
1124+
1125+
return majority_eval_path, has_leakage
10711126

10721127
print("Creating majority style eval set (control)...")
10731128

1074-
# Load the minority style eval to get the facts
1129+
# Check if local styled datasets exist for proper template segregation
1130+
local_styled_path = Path(
1131+
f"data/facts_dataset_{config.dominant_style}-Qwen3-8B-Base.hf"
1132+
)
1133+
if not local_styled_path.exists():
1134+
print(
1135+
f" WARNING: Local styled dataset not found at {local_styled_path}"
1136+
)
1137+
print(" Cannot create properly segregated majority eval")
1138+
print(" Using HF eval_majority_style (may have train/test leakage)")
1139+
return majority_eval_path, True # Return existing HF version with leakage flag
1140+
1141+
# Load the minority style eval to get the semantic facts (identifier, field pairs)
10751142
eval_ds = load_from_disk(str(data_path / "eval.hf"))
10761143
if isinstance(eval_ds, DatasetDict):
10771144
eval_ds = eval_ds["train"]
10781145

1079-
eval_facts = set(eval_ds["fact"])
1146+
# Get semantic facts from eval (identifier, field pairs)
1147+
eval_semantic_facts = {(row["identifier"], row["field"]) for row in eval_ds}
10801148

10811149
# Load dominant style dataset
1082-
dominant_ds = load_from_disk(
1083-
f"data/facts_dataset_{config.dominant_style}-Qwen3-8B-Base.hf"
1084-
)
1150+
dominant_ds = load_from_disk(str(local_styled_path))
10851151
if isinstance(dominant_ds, DatasetDict):
10861152
dominant_ds = dominant_ds["train"]
10871153

@@ -1096,24 +1162,33 @@ def create_majority_style_eval(
10961162
restored_col = [fact_to_meta[row["fact"]][col] for row in dominant_ds]
10971163
dominant_ds = dominant_ds.add_column(col, restored_col)
10981164

1099-
# Select only the exclusive facts (same facts as in minority eval)
1165+
# Select dominant style versions of eval semantic facts
1166+
# Use templates >= cutoff to ensure NO overlap with training data
11001167
dominant_eval_indices = [
1101-
i for i, row in enumerate(dominant_ds) if row["fact"] in eval_facts
1168+
i
1169+
for i, row in enumerate(dominant_ds)
1170+
if (row["identifier"], row["field"]) in eval_semantic_facts
1171+
and row["template"] >= config.train_template_cutoff
11021172
]
11031173
majority_eval_ds = dominant_ds.select(dominant_eval_indices)
11041174

1105-
# Add style columns
1106-
majority_eval_ds = majority_eval_ds.add_column(
1107-
"style", [config.dominant_style] * len(majority_eval_ds)
1108-
)
1109-
majority_eval_ds = majority_eval_ds.add_column(
1110-
"expected_match_style", [config.dominant_style] * len(majority_eval_ds)
1111-
)
1175+
print(f" Using templates >= {config.train_template_cutoff} (no overlap with train)")
1176+
print(f" Found {len(majority_eval_ds)} majority style eval samples")
1177+
1178+
# Add style columns if not present
1179+
if "style" not in majority_eval_ds.column_names:
1180+
majority_eval_ds = majority_eval_ds.add_column(
1181+
"style", [config.dominant_style] * len(majority_eval_ds)
1182+
)
1183+
if "expected_match_style" not in majority_eval_ds.column_names:
1184+
majority_eval_ds = majority_eval_ds.add_column(
1185+
"expected_match_style", [config.dominant_style] * len(majority_eval_ds)
1186+
)
11121187

11131188
majority_eval_ds.save_to_disk(str(majority_eval_path))
11141189
print(f"Saved majority style eval to {majority_eval_path}")
11151190

1116-
return majority_eval_path
1191+
return majority_eval_path, False # No leakage with proper segregation
11171192

11181193

11191194
def score_majority_style_eval(
@@ -1146,7 +1221,9 @@ def score_majority_style_eval(
11461221
data_path = base_path / "data"
11471222

11481223
# Create majority style eval if needed
1149-
create_majority_style_eval(config, base_path)
1224+
_, has_leakage = create_majority_style_eval(config, base_path)
1225+
if has_leakage:
1226+
print(" Note: Majority control may show inflated accuracy due to train/test leakage")
11501227

11511228
# Determine output path
11521229
if preconditioner_name:
@@ -1291,7 +1368,7 @@ def compute_majority_style_metrics(
12911368
data_path = base_path / "data"
12921369

12931370
# Create majority style eval if needed
1294-
create_majority_style_eval(config, base_path)
1371+
_, _ = create_majority_style_eval(config, base_path)
12951372

12961373
# Load datasets
12971374
train_ds = load_from_disk(str(data_path / "train.hf"))
@@ -1445,7 +1522,7 @@ def score_summed_eval(
14451522
eval_minority_ds = eval_minority_ds["train"]
14461523

14471524
# Create majority style eval if needed
1448-
create_majority_style_eval(config, base_path)
1525+
_, _ = create_majority_style_eval(config, base_path)
14491526
eval_majority_ds = load_from_disk(str(data_path / "eval_majority_style.hf"))
14501527
if isinstance(eval_majority_ds, DatasetDict):
14511528
eval_majority_ds = eval_majority_ds["train"]
@@ -1455,17 +1532,21 @@ def score_summed_eval(
14551532
f"Scoring {n_eval} summed eval queries (minority + majority) against {n_train} train samples"
14561533
)
14571534

1458-
# Build fact-to-index mapping for alignment
1459-
minority_facts = eval_minority_ds["fact"]
1460-
majority_facts = eval_majority_ds["fact"]
1461-
majority_fact_to_idx = {f: i for i, f in enumerate(majority_facts)}
1535+
# Build semantic fact mapping for alignment (identifier, field pairs)
1536+
# This works even when templates differ between minority and majority eval
1537+
minority_semantic_facts = [
1538+
(row["identifier"], row["field"]) for row in eval_minority_ds
1539+
]
1540+
majority_semantic_to_idx = {
1541+
(row["identifier"], row["field"]): i for i, row in enumerate(eval_majority_ds)
1542+
}
14621543

1463-
# Verify alignment
1464-
assert len(minority_facts) == len(
1465-
majority_facts
1544+
# Verify alignment by semantic fact
1545+
assert len(eval_minority_ds) == len(
1546+
eval_majority_ds
14661547
), "Eval datasets must have same size"
1467-
for f in minority_facts:
1468-
assert f in majority_fact_to_idx, f"Fact {f} not found in majority eval"
1548+
for sf in minority_semantic_facts:
1549+
assert sf in majority_semantic_to_idx, f"Semantic fact {sf} not found in majority eval"
14691550

14701551
# Load train gradients
14711552
print("Loading train gradients...")
@@ -1575,14 +1656,16 @@ def score_summed_eval(
15751656
majority_grads = load_gradients(eval_majority_grads_path, structured=True)
15761657

15771658
# Sum gradients: for each eval fact, sum minority + majority style gradients
1578-
# Need to align by fact since ordering might differ
1659+
# Align by semantic fact (identifier, field) since templates may differ
15791660
summed_grad_list = []
15801661
for name in tqdm(module_names, desc="Summing eval grads"):
15811662
g_minority = torch.from_numpy(_load_gradients_as_float(minority_grads, name))
15821663
g_majority = torch.from_numpy(_load_gradients_as_float(majority_grads, name))
15831664

1584-
# Align majority grads to minority fact order
1585-
aligned_majority_indices = [majority_fact_to_idx[f] for f in minority_facts]
1665+
# Align majority grads to minority semantic fact order
1666+
aligned_majority_indices = [
1667+
majority_semantic_to_idx[sf] for sf in minority_semantic_facts
1668+
]
15861669
g_majority_aligned = g_majority[aligned_majority_indices]
15871670

15881671
# Sum the gradients
@@ -2019,6 +2102,21 @@ def run_asymmetric_experiment(
20192102
print("\n" + "-" * 60)
20202103
print("SEMANTIC-ONLY EVAL (gradients only from answer tokens)")
20212104
print("-" * 60)
2105+
2106+
# Standard influence function approach: semantic mask + H_train preconditioner
2107+
# This is the "correct" way to compute influence functions
2108+
print("\n--- Strategy: semantic_index (standard IF with H_train) ---")
2109+
metrics = compute_asymmetric_metrics(
2110+
config,
2111+
base_path,
2112+
"index", # H_train - the standard IF preconditioner
2113+
damping_factor=damping_factor,
2114+
eval_prompt_column="question",
2115+
eval_completion_column="answer",
2116+
)
2117+
print_metrics(metrics, "semantic_index")
2118+
all_metrics["semantic_index"] = metrics
2119+
20222120
print("\n--- Strategy: semantic_no_precond ---")
20232121
metrics = compute_asymmetric_metrics(
20242122
config,
@@ -2193,8 +2291,14 @@ def score_with_inner_product(
21932291
if isinstance(eval_majority_ds, DatasetDict):
21942292
eval_majority_ds = eval_majority_ds["train"]
21952293

2196-
minority_facts = eval_minority_ds["fact"]
2197-
majority_fact_to_idx = {f: i for i, f in enumerate(eval_majority_ds["fact"])}
2294+
# Use semantic fact alignment (identifier, field) since templates may differ
2295+
minority_semantic_facts = [
2296+
(row["identifier"], row["field"]) for row in eval_minority_ds
2297+
]
2298+
majority_semantic_to_idx = {
2299+
(row["identifier"], row["field"]): i
2300+
for i, row in enumerate(eval_majority_ds)
2301+
}
21982302

21992303
summed_grad_list = []
22002304
for name in tqdm(module_names, desc="Summing eval grads"):
@@ -2205,7 +2309,9 @@ def score_with_inner_product(
22052309
_load_gradients_as_float(majority_grads, name)
22062310
)
22072311

2208-
aligned_majority_indices = [majority_fact_to_idx[f] for f in minority_facts]
2312+
aligned_majority_indices = [
2313+
majority_semantic_to_idx[sf] for sf in minority_semantic_facts
2314+
]
22092315
g_majority_aligned = g_majority[aligned_majority_indices]
22102316

22112317
g_summed = g_minority + g_majority_aligned

0 commit comments

Comments
 (0)