@@ -45,6 +45,9 @@ class AsymmetricConfig:
4545 seed : int = 42
4646 # HuggingFace dataset repo. If set, skips local generation and downloads from HF.
4747 hf_dataset : str | None = None
48+ # Template split for train/test segregation (only used for local generation)
49+ # Train uses templates < train_template_cutoff; both eval sets (minority queries and majority control) use templates >= train_template_cutoff
50+ train_template_cutoff : int = 5
4851
4952
5053def create_asymmetric_dataset (
@@ -57,6 +60,11 @@ def create_asymmetric_dataset(
5760 - Exclusive facts: only appear in dominant style (for testing semantic matching)
5861 - Shared facts: appear in both styles (for style ratio control)
5962
63+ For train/test segregation:
64+ - Training uses templates < train_template_cutoff (default: 0-4)
65+ - Eval majority style control uses templates >= cutoff (default: 5+)
66+ This ensures no exact text overlap between train and eval majority control.
67+
6068 Args:
6169 config: Experiment configuration.
6270 output_dir: Directory to save datasets.
@@ -103,29 +111,41 @@ def create_asymmetric_dataset(
103111 dominant_ds = style_datasets [config .dominant_style ]
104112 minority_ds = style_datasets [config .minority_style ]
105113
106- # Get unique facts (by fact text)
107- all_facts = list (set (original ["fact" ]))
108- n_facts = len (all_facts )
114+ # Get unique (identifier, field) pairs - these represent underlying semantic facts
115+ # Each pair has multiple templates (different surface forms of the same fact)
116+ semantic_facts = list (
117+ {(row ["identifier" ], row ["field" ]) for row in original }
118+ )
119+ n_semantic_facts = len (semantic_facts )
109120
110- # Split into exclusive (dominant-only) and shared
121+ # Split into exclusive (dominant-only) and shared by semantic fact
111122 rng = np .random .default_rng (config .seed )
112- rng .shuffle (all_facts )
113-
114- n_exclusive = int (n_facts * config .exclusive_ratio )
115- exclusive_facts = set (all_facts [:n_exclusive ])
116- shared_facts = set (all_facts [n_exclusive :])
117-
118- print (f"Total unique facts: { n_facts } " )
119- print (f"Exclusive to { config .dominant_style } : { len (exclusive_facts )} " )
120- print (f"Shared between styles: { len (shared_facts )} " )
121-
122- # Build training set
123- # 1. All dominant style facts
124- train_dominant = dominant_ds
123+ rng .shuffle (semantic_facts )
124+
125+ n_exclusive = int (n_semantic_facts * config .exclusive_ratio )
126+ exclusive_semantic_facts = set (semantic_facts [:n_exclusive ])
127+ shared_semantic_facts = set (semantic_facts [n_exclusive :])
128+
129+ print (f"Total unique semantic facts (identifier, field pairs): { n_semantic_facts } " )
130+ print (f"Exclusive to { config .dominant_style } : { len (exclusive_semantic_facts )} " )
131+ print (f"Shared between styles: { len (shared_semantic_facts )} " )
132+ print (f"Template cutoff for train/eval split: { config .train_template_cutoff } " )
133+
134+ # Build training set with template filtering
135+ # 1. Dominant style: only templates < cutoff (to reserve rest for eval majority control)
136+ train_dominant_indices = [
137+ i
138+ for i , row in enumerate (dominant_ds )
139+ if row ["template" ] < config .train_template_cutoff
140+ ]
141+ train_dominant = dominant_ds .select (train_dominant_indices )
125142
126- # 2. Minority style only for shared facts
143+ # 2. Minority style only for shared facts, restricted to templates < cutoff (same template filter as dominant style)
127144 minority_shared_indices = [
128- i for i , row in enumerate (minority_ds ) if row ["fact" ] in shared_facts
145+ i
146+ for i , row in enumerate (minority_ds )
147+ if (row ["identifier" ], row ["field" ]) in shared_semantic_facts
148+ and row ["template" ] < config .train_template_cutoff
129149 ]
130150 train_minority = minority_ds .select (minority_shared_indices )
131151
@@ -148,10 +168,14 @@ def create_asymmetric_dataset(
148168 print (f" Dominant ratio: { len (train_dominant ) / len (train_ds ):.2%} " )
149169
150170 # Build eval set: query exclusive facts in minority style
171+ # Use templates >= cutoff to ensure no overlap with train
151172 # These facts don't exist in minority style in training, so the model
152173 # must use semantic matching (not style matching) to find them
153174 eval_minority_indices = [
154- i for i , row in enumerate (minority_ds ) if row ["fact" ] in exclusive_facts
175+ i
176+ for i , row in enumerate (minority_ds )
177+ if (row ["identifier" ], row ["field" ]) in exclusive_semantic_facts
178+ and row ["template" ] >= config .train_template_cutoff
155179 ]
156180 eval_ds = minority_ds .select (eval_minority_indices )
157181 eval_ds = eval_ds .add_column ("style" , [config .minority_style ] * len (eval_ds ))
@@ -164,6 +188,7 @@ def create_asymmetric_dataset(
164188 print ("\n Eval set:" )
165189 print (f" Queries in { config .minority_style } style: { len (eval_ds )} " )
166190 print (f" Ground truth only in { config .dominant_style } style" )
191+ print (f" Using templates >= { config .train_template_cutoff } (no overlap with train)" )
167192
168193 # Save datasets
169194 train_ds .save_to_disk (str (train_path ))
@@ -1048,40 +1073,81 @@ def compute_asymmetric_metrics_with_pca(
10481073def create_majority_style_eval (
10491074 config : AsymmetricConfig ,
10501075 base_path : Path | str ,
1051- ) -> Path :
1076+ force_regenerate : bool = False ,
1077+ ) -> tuple [Path , bool ]:
10521078 """Create eval set using majority style (control for style mismatch).
10531079
10541080 Instead of using minority style queries, uses dominant style queries
10551081 for the exclusive facts. This shows baseline performance without style mismatch.
10561082
1083+ IMPORTANT: Uses templates >= train_template_cutoff to ensure NO overlap with
1084+ training data. This provides a proper train/test split where eval majority
1085+ style items test semantic matching (same fact, different surface form) rather
1086+ than exact text matching.
1087+
10571088 Args:
10581089 config: Experiment configuration.
10591090 base_path: Base path for experiment outputs.
1091+ force_regenerate: If True, regenerate even if cached version exists.
10601092
10611093 Returns:
1062- Path to the majority style eval dataset.
1094+ Tuple of (path to the majority style eval dataset, has_leakage flag).
1095+ has_leakage is True if there's train/test overlap (e.g., from HF data).
10631096 """
10641097 base_path = Path (base_path )
10651098 data_path = base_path / "data"
10661099 majority_eval_path = data_path / "eval_majority_style.hf"
10671100
1068- if majority_eval_path .exists ():
1101+ # Check for existing cached version
1102+ if majority_eval_path .exists () and not force_regenerate :
10691103 print (f"Loading cached majority style eval from { majority_eval_path } " )
1070- return majority_eval_path
1104+
1105+ # Check for train/test leakage by comparing reworded texts
1106+ train_ds = load_from_disk (str (data_path / "train.hf" ))
1107+ majority_eval_ds = load_from_disk (str (majority_eval_path ))
1108+ if isinstance (train_ds , DatasetDict ):
1109+ train_ds = train_ds ["train" ]
1110+ if isinstance (majority_eval_ds , DatasetDict ):
1111+ majority_eval_ds = majority_eval_ds ["train" ]
1112+
1113+ train_reworded = set (train_ds ["reworded" ])
1114+ eval_reworded = set (majority_eval_ds ["reworded" ])
1115+ overlap = train_reworded & eval_reworded
1116+ has_leakage = len (overlap ) > 0
1117+
1118+ if has_leakage :
1119+ print (
1120+ f" WARNING: { len (overlap )} /{ len (eval_reworded )} eval items have "
1121+ "exact text match in train (train/test leakage)"
1122+ )
1123+ print (" Use force_regenerate=True with local data to fix" )
1124+
1125+ return majority_eval_path , has_leakage
10711126
10721127 print ("Creating majority style eval set (control)..." )
10731128
1074- # Load the minority style eval to get the facts
1129+ # Check if local styled datasets exist for proper template segregation
1130+ local_styled_path = Path (
1131+ f"data/facts_dataset_{ config .dominant_style } -Qwen3-8B-Base.hf"
1132+ )
1133+ if not local_styled_path .exists ():
1134+ print (
1135+ f" WARNING: Local styled dataset not found at { local_styled_path } "
1136+ )
1137+ print (" Cannot create properly segregated majority eval" )
1138+ print (" Using HF eval_majority_style (may have train/test leakage)" )
1139+ return majority_eval_path , True # Return existing HF version with leakage flag
1140+
1141+ # Load the minority style eval to get the semantic facts (identifier, field pairs)
10751142 eval_ds = load_from_disk (str (data_path / "eval.hf" ))
10761143 if isinstance (eval_ds , DatasetDict ):
10771144 eval_ds = eval_ds ["train" ]
10781145
1079- eval_facts = set (eval_ds ["fact" ])
1146+ # Get semantic facts from eval (identifier, field pairs)
1147+ eval_semantic_facts = {(row ["identifier" ], row ["field" ]) for row in eval_ds }
10801148
10811149 # Load dominant style dataset
1082- dominant_ds = load_from_disk (
1083- f"data/facts_dataset_{ config .dominant_style } -Qwen3-8B-Base.hf"
1084- )
1150+ dominant_ds = load_from_disk (str (local_styled_path ))
10851151 if isinstance (dominant_ds , DatasetDict ):
10861152 dominant_ds = dominant_ds ["train" ]
10871153
@@ -1096,24 +1162,33 @@ def create_majority_style_eval(
10961162 restored_col = [fact_to_meta [row ["fact" ]][col ] for row in dominant_ds ]
10971163 dominant_ds = dominant_ds .add_column (col , restored_col )
10981164
1099- # Select only the exclusive facts (same facts as in minority eval)
1165+ # Select dominant style versions of eval semantic facts
1166+ # Use templates >= cutoff to ensure NO overlap with training data
11001167 dominant_eval_indices = [
1101- i for i , row in enumerate (dominant_ds ) if row ["fact" ] in eval_facts
1168+ i
1169+ for i , row in enumerate (dominant_ds )
1170+ if (row ["identifier" ], row ["field" ]) in eval_semantic_facts
1171+ and row ["template" ] >= config .train_template_cutoff
11021172 ]
11031173 majority_eval_ds = dominant_ds .select (dominant_eval_indices )
11041174
1105- # Add style columns
1106- majority_eval_ds = majority_eval_ds .add_column (
1107- "style" , [config .dominant_style ] * len (majority_eval_ds )
1108- )
1109- majority_eval_ds = majority_eval_ds .add_column (
1110- "expected_match_style" , [config .dominant_style ] * len (majority_eval_ds )
1111- )
1175+ print (f" Using templates >= { config .train_template_cutoff } (no overlap with train)" )
1176+ print (f" Found { len (majority_eval_ds )} majority style eval samples" )
1177+
1178+ # Add style columns if not present
1179+ if "style" not in majority_eval_ds .column_names :
1180+ majority_eval_ds = majority_eval_ds .add_column (
1181+ "style" , [config .dominant_style ] * len (majority_eval_ds )
1182+ )
1183+ if "expected_match_style" not in majority_eval_ds .column_names :
1184+ majority_eval_ds = majority_eval_ds .add_column (
1185+ "expected_match_style" , [config .dominant_style ] * len (majority_eval_ds )
1186+ )
11121187
11131188 majority_eval_ds .save_to_disk (str (majority_eval_path ))
11141189 print (f"Saved majority style eval to { majority_eval_path } " )
11151190
1116- return majority_eval_path
1191+ return majority_eval_path , False # No leakage with proper segregation
11171192
11181193
11191194def score_majority_style_eval (
@@ -1146,7 +1221,9 @@ def score_majority_style_eval(
11461221 data_path = base_path / "data"
11471222
11481223 # Create majority style eval if needed
1149- create_majority_style_eval (config , base_path )
1224+ _ , has_leakage = create_majority_style_eval (config , base_path )
1225+ if has_leakage :
1226+ print (" Note: Majority control may show inflated accuracy due to train/test leakage" )
11501227
11511228 # Determine output path
11521229 if preconditioner_name :
@@ -1291,7 +1368,7 @@ def compute_majority_style_metrics(
12911368 data_path = base_path / "data"
12921369
12931370 # Create majority style eval if needed
1294- create_majority_style_eval (config , base_path )
1371+ _ , _ = create_majority_style_eval (config , base_path )
12951372
12961373 # Load datasets
12971374 train_ds = load_from_disk (str (data_path / "train.hf" ))
@@ -1445,7 +1522,7 @@ def score_summed_eval(
14451522 eval_minority_ds = eval_minority_ds ["train" ]
14461523
14471524 # Create majority style eval if needed
1448- create_majority_style_eval (config , base_path )
1525+ _ , _ = create_majority_style_eval (config , base_path )
14491526 eval_majority_ds = load_from_disk (str (data_path / "eval_majority_style.hf" ))
14501527 if isinstance (eval_majority_ds , DatasetDict ):
14511528 eval_majority_ds = eval_majority_ds ["train" ]
@@ -1455,17 +1532,21 @@ def score_summed_eval(
14551532 f"Scoring { n_eval } summed eval queries (minority + majority) against { n_train } train samples"
14561533 )
14571534
1458- # Build fact-to-index mapping for alignment
1459- minority_facts = eval_minority_ds ["fact" ]
1460- majority_facts = eval_majority_ds ["fact" ]
1461- majority_fact_to_idx = {f : i for i , f in enumerate (majority_facts )}
1535+ # Build semantic fact mapping for alignment (identifier, field pairs)
1536+ # This works even when templates differ between minority and majority eval
1537+ minority_semantic_facts = [
1538+ (row ["identifier" ], row ["field" ]) for row in eval_minority_ds
1539+ ]
1540+ majority_semantic_to_idx = {
1541+ (row ["identifier" ], row ["field" ]): i for i , row in enumerate (eval_majority_ds )
1542+ }
14621543
1463- # Verify alignment
1464- assert len (minority_facts ) == len (
1465- majority_facts
1544+ # Verify alignment by semantic fact
1545+ assert len (eval_minority_ds ) == len (
1546+ eval_majority_ds
14661547 ), "Eval datasets must have same size"
1467- for f in minority_facts :
1468- assert f in majority_fact_to_idx , f"Fact { f } not found in majority eval"
1548+ for sf in minority_semantic_facts :
1549+ assert sf in majority_semantic_to_idx , f"Semantic fact { sf } not found in majority eval"
14691550
14701551 # Load train gradients
14711552 print ("Loading train gradients..." )
@@ -1575,14 +1656,16 @@ def score_summed_eval(
15751656 majority_grads = load_gradients (eval_majority_grads_path , structured = True )
15761657
15771658 # Sum gradients: for each eval fact, sum minority + majority style gradients
1578- # Need to align by fact since ordering might differ
1659+ # Align by semantic fact (identifier, field) since templates may differ
15791660 summed_grad_list = []
15801661 for name in tqdm (module_names , desc = "Summing eval grads" ):
15811662 g_minority = torch .from_numpy (_load_gradients_as_float (minority_grads , name ))
15821663 g_majority = torch .from_numpy (_load_gradients_as_float (majority_grads , name ))
15831664
1584- # Align majority grads to minority fact order
1585- aligned_majority_indices = [majority_fact_to_idx [f ] for f in minority_facts ]
1665+ # Align majority grads to minority semantic fact order
1666+ aligned_majority_indices = [
1667+ majority_semantic_to_idx [sf ] for sf in minority_semantic_facts
1668+ ]
15861669 g_majority_aligned = g_majority [aligned_majority_indices ]
15871670
15881671 # Sum the gradients
@@ -2019,6 +2102,21 @@ def run_asymmetric_experiment(
20192102 print ("\n " + "-" * 60 )
20202103 print ("SEMANTIC-ONLY EVAL (gradients only from answer tokens)" )
20212104 print ("-" * 60 )
2105+
2106+ # Standard influence function approach: semantic mask + H_train preconditioner
2107+ # This is the "correct" way to compute influence functions
2108+ print ("\n --- Strategy: semantic_index (standard IF with H_train) ---" )
2109+ metrics = compute_asymmetric_metrics (
2110+ config ,
2111+ base_path ,
2112+ "index" , # H_train - the standard IF preconditioner
2113+ damping_factor = damping_factor ,
2114+ eval_prompt_column = "question" ,
2115+ eval_completion_column = "answer" ,
2116+ )
2117+ print_metrics (metrics , "semantic_index" )
2118+ all_metrics ["semantic_index" ] = metrics
2119+
20222120 print ("\n --- Strategy: semantic_no_precond ---" )
20232121 metrics = compute_asymmetric_metrics (
20242122 config ,
@@ -2193,8 +2291,14 @@ def score_with_inner_product(
21932291 if isinstance (eval_majority_ds , DatasetDict ):
21942292 eval_majority_ds = eval_majority_ds ["train" ]
21952293
2196- minority_facts = eval_minority_ds ["fact" ]
2197- majority_fact_to_idx = {f : i for i , f in enumerate (eval_majority_ds ["fact" ])}
2294+ # Use semantic fact alignment (identifier, field) since templates may differ
2295+ minority_semantic_facts = [
2296+ (row ["identifier" ], row ["field" ]) for row in eval_minority_ds
2297+ ]
2298+ majority_semantic_to_idx = {
2299+ (row ["identifier" ], row ["field" ]): i
2300+ for i , row in enumerate (eval_majority_ds )
2301+ }
21982302
21992303 summed_grad_list = []
22002304 for name in tqdm (module_names , desc = "Summing eval grads" ):
@@ -2205,7 +2309,9 @@ def score_with_inner_product(
22052309 _load_gradients_as_float (majority_grads , name )
22062310 )
22072311
2208- aligned_majority_indices = [majority_fact_to_idx [f ] for f in minority_facts ]
2312+ aligned_majority_indices = [
2313+ majority_semantic_to_idx [sf ] for sf in minority_semantic_facts
2314+ ]
22092315 g_majority_aligned = g_majority [aligned_majority_indices ]
22102316
22112317 g_summed = g_minority + g_majority_aligned
0 commit comments