Commit 3351b3a

Add optional weight canonicalization before Phase 2 similarity signals
1 parent 32409f4 commit 3351b3a

6 files changed

Lines changed: 1776 additions & 6 deletions

docs/canonicalization.md

Lines changed: 107 additions & 0 deletions
# Canonicalization (Comparison-Space Hardening)

ProvenanceKit's weight-level signals (EAS, NLF, LEP, END, WVC) compute
cosine / correlation scores in raw weight space. Raw weight space is
*basis-sensitive*: two functionally identical models can produce very
different scores after a cheap, function-preserving transformation such as

* attention-head permutation,
* MLP/neuron permutation,
* adjacent-layer (channel-wise) rescaling,
* layer-norm gamma absorption.

Canonicalization is an optional pre-processing pass that aligns model B
into model A's basis (heads + channels) and normalizes per-channel
scales before similarity scoring. It is opt-in via the `--canonicalize`
flag and disabled by default.
## What it does

| Step | Behaviour |
|------|-----------|
| Permutation alignment (attention) | Builds per-head signatures from Q/K/V (and the corresponding columns of the attention-output projection), solves a Hungarian assignment between A and B's heads, and applies the resulting permutation to B's Q/K/V output rows and to O's input columns. |
| Permutation alignment (MLP) | Builds per-channel signatures from `up_proj` (and `gate_proj` when present) rows together with `down_proj` columns, solves a Hungarian assignment, and applies the permutation to B's `up`, `gate`, and `down` projections. |
| Scale normalization (`comparison`, default) | Divides every per-channel slice by its L2 norm. Applied independently to A and B. **Non-invertible.** |
| Scale normalization (`function_preserving`) | Divides W_in by the per-channel norm α and multiplies W_out by α, preserving forward-pass equivalence. Stricter and slower; offered for callers who want to reuse the canonicalized weights. |
| LayerNorm gamma | Each LayerNorm/RMSNorm gamma vector is unit-normed independently so it cannot dominate cosine similarity once concatenated with other signals. |
| Stability check | Per-layer cosine before vs after alignment is compared; large jumps surface in `stability_warnings` to flag partially aligned layers, architecture mismatches, or bad tensor mapping. |
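To make the two scale modes concrete, here is a minimal numeric sketch, assuming a `(channels, in_features)` weight layout; the function names are illustrative and not ProvenanceKit's actual API:

```python
import numpy as np

def normalize_comparison(w_in: np.ndarray) -> np.ndarray:
    """'comparison' mode: unit L2 norm per channel row.

    Non-invertible: the original channel scales are discarded, so the
    result is only valid for similarity scoring, never inference.
    """
    alpha = np.linalg.norm(w_in, axis=1, keepdims=True)
    return w_in / np.maximum(alpha, 1e-12)  # guard against zero channels

def normalize_function_preserving(w_in: np.ndarray, w_out: np.ndarray):
    """'function_preserving' mode: divide W_in rows by alpha and scale
    the matching W_out columns by alpha. Exact for positively
    homogeneous activations such as ReLU; approximate otherwise.
    """
    alpha = np.maximum(np.linalg.norm(w_in, axis=1, keepdims=True), 1e-12)
    return w_in / alpha, w_out * alpha.T
```

In the second mode `W_out @ relu(W_in @ x)` is unchanged because `relu(z / a) = relu(z) / a` for `a > 0`; the first mode deliberately drops that guarantee in exchange for scale invariance.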
When SciPy is unavailable the assignment falls back to a greedy
max-matching solver. Pass `--canonicalize-method greedy` to force it.
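The assignment step can be sketched as follows. This is an illustrative outline, not ProvenanceKit's implementation; it assumes head signatures have already been reduced to a pairwise similarity matrix:

```python
import numpy as np

def match_heads(sim: np.ndarray, method: str = "hungarian") -> np.ndarray:
    """Match B's heads to A's heads given a similarity matrix.

    sim[i, j] scores A-head i against B-head j; returns perm with
    perm[i] = index of the B head assigned to A head i.
    """
    if method == "hungarian":
        try:
            from scipy.optimize import linear_sum_assignment
        except ImportError:
            method = "greedy"  # mirror the documented SciPy fallback
        else:
            # Negate: linear_sum_assignment minimizes cost, we want
            # maximum total similarity.
            _, cols = linear_sum_assignment(-sim)
            return cols
    # Greedy max-matching: repeatedly take the best remaining pair.
    perm = np.full(sim.shape[0], -1)
    work = sim.astype(float).copy()
    for _ in range(sim.shape[0]):
        i, j = np.unravel_index(np.argmax(work), work.shape)
        perm[i] = j
        work[i, :] = -np.inf  # head i of A is now taken
        work[:, j] = -np.inf  # head j of B is now taken
    return perm
```

Greedy matching is not optimal in general, which is why Hungarian is the default when SciPy is installed.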
## Important: comparison-only output

> Scale normalization operates in a comparison space and is not
> function-preserving. The resulting representation is non-invertible
> and must not be used for inference or model reconstruction.

> This design intentionally trades invertibility for invariance to
> common evasion strategies (channel rescaling and layer-norm
> absorption).

The canonicalizer returns `ComparisonView` objects tagged with
`is_comparison_only=True`. Inference, serialization, or model-export
code paths should call
`provenancekit.core.canonicalization.assert_not_comparison_view` on any
state-dict-shaped input as a runtime guard.
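As an illustration of what such a guard could look like (a sketch only; ProvenanceKit's real `assert_not_comparison_view` may differ), the check only needs to inspect the tag:

```python
class ComparisonOnlyError(RuntimeError):
    """Raised when comparison-space weights reach an export path."""

def assert_not_comparison_view(state_dict: object) -> None:
    """Reject inputs tagged as comparison-only (sketch of the guard).

    Accepts anything state-dict-shaped; only the tag is inspected.
    """
    if getattr(state_dict, "is_comparison_only", False):
        raise ComparisonOnlyError(
            "Canonicalized comparison-space weights are non-invertible "
            "and must not be used for inference, serialization, or export."
        )

class ComparisonView(dict):
    """Hypothetical stand-in for the canonicalizer's tagged output."""
    is_comparison_only = True

assert_not_comparison_view({"w": [1.0]})   # plain state dict: passes
try:
    assert_not_comparison_view(ComparisonView())
except ComparisonOnlyError:
    print("blocked")  # comparison view: rejected
```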
## CLI

```
provenancekit compare base-model suspect-model --canonicalize --json
provenancekit compare base-model suspect-model --canonicalize --canonicalize-method greedy
provenancekit compare base-model suspect-model --canonicalize --no-scale-normalize
provenancekit compare base-model suspect-model --canonicalize --canonicalize-scale-mode function_preserving
```

| Flag | Meaning |
|------|---------|
| `--canonicalize` | Enable the pass. Off by default. |
| `--canonicalize-method {hungarian,greedy}` | Assignment solver. `hungarian` requires SciPy. |
| `--canonicalize-scale-mode {comparison,function_preserving}` | Scale handling. Default is `comparison` (non-invertible). |
| `--no-scale-normalize` | Skip per-channel scale normalization. |
| `--no-permutation-align` | Skip head / channel permutation alignment. |
## JSON output additions

When `--canonicalize` is set, `compare`'s JSON output gains a
`canonicalization` section:

```json
"canonicalization": {
  "enabled": true,
  "method": "hungarian",
  "scale_mode": "comparison",
  "non_invertible": true,
  "layers_aligned": 32,
  "attention_heads_aligned": 1024,
  "mlp_channels_aligned": 11008,
  "scale_normalized": true,
  "unsupported_layers": [],
  "stability_warnings": [],
  "skipped_reason": null
}
```

`non_invertible: true` is intentional and not cosmetic. When
`scale_mode` is `comparison`, downstream consumers must treat the
canonicalized representation as a comparison artifact only.
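Downstream tooling can enforce that rule with a small check. This is a sketch; the field names follow the JSON above, and the helper name is illustrative:

```python
import json

def is_safe_to_reuse(report: dict) -> bool:
    """Return True when canonicalized weights may be reused outside
    comparison; False means they are a comparison-only artifact."""
    canon = report.get("canonicalization") or {}
    if not canon.get("enabled"):
        return True  # canonicalization never ran
    return not canon.get("non_invertible", False)

report = json.loads(
    '{"canonicalization": {"enabled": true, "method": "hungarian",'
    ' "scale_mode": "comparison", "non_invertible": true}}'
)
if not is_safe_to_reuse(report):
    print("comparison-only: do not export or run inference")
```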
## Limitations

Canonicalization reduces false negatives caused by permutation and
scale symmetries in functionally equivalent weight representations. It
**does not** prove lineage under distillation, retraining, model
merging, or behavioral imitation. Treat it as one additional defence
against trivial obfuscation, not as a behavioral fingerprint.

* Architecture metadata for both models must be compatible (same head
  count, same head dimension, same intermediate size). When metadata is
  missing or mismatched, alignment is skipped for that layer and
  recorded in `unsupported_layers`.
* Streaming-only loads (very large models) cannot be canonicalized
  in-place at comparison time; the report is returned with
  `skipped_reason="state_dict_unavailable"`.
* Pairwise canonicalization is performed at comparison time. Cached
  feature bundles remain unchanged so existing cache keys stay valid.

src/provenancekit/cli.py

Lines changed: 64 additions & 0 deletions
```diff
@@ -38,6 +38,9 @@
 from rich.console import Console  # noqa: E402

 from provenancekit.config.settings import Settings  # noqa: E402
+from provenancekit.core.canonicalization import (  # noqa: E402
+    CanonicalizationConfig,
+)
 from provenancekit.core.results.formatters import (  # noqa: E402
     format_json,
     format_plain,
@@ -76,6 +79,63 @@ def _unit_float(value: str) -> float:
     return fvalue


+def _add_canonicalize_flags(p: argparse.ArgumentParser) -> None:
+    """Attach the shared canonicalization flag group to *p*."""
+    p.add_argument(
+        "--canonicalize",
+        dest="canonicalize",
+        action="store_true",
+        help=(
+            "Apply optional comparison-space weight canonicalization "
+            "(attention-head + MLP-channel alignment, per-channel scale "
+            "normalization) before computing weight-level signals. "
+            "Reduces false negatives from permutation/scale evasions. "
+            "Default scale_mode is 'comparison' and is non-invertible."
+        ),
+    )
+    p.add_argument(
+        "--canonicalize-method",
+        dest="canonicalize_method",
+        choices=["hungarian", "greedy"],
+        default="hungarian",
+        help="Assignment solver for permutation alignment.",
+    )
+    p.add_argument(
+        "--canonicalize-scale-mode",
+        dest="canonicalize_scale_mode",
+        choices=["comparison", "function_preserving"],
+        default="comparison",
+        help=(
+            "Per-channel scale handling. 'comparison' (default) is "
+            "non-invertible. 'function_preserving' preserves the forward "
+            "pass and is stricter / slower."
+        ),
+    )
+    p.add_argument(
+        "--no-scale-normalize",
+        dest="canonicalize_no_scale",
+        action="store_true",
+        help="Disable scale normalization within canonicalization.",
+    )
+    p.add_argument(
+        "--no-permutation-align",
+        dest="canonicalize_no_perm",
+        action="store_true",
+        help="Disable head/MLP permutation alignment within canonicalization.",
+    )
+
+
+def _build_canonicalization(args: argparse.Namespace) -> CanonicalizationConfig:
+    """Build a :class:`CanonicalizationConfig` from CLI flags."""
+    return CanonicalizationConfig(
+        enabled=getattr(args, "canonicalize", False),
+        align_permutations=not getattr(args, "canonicalize_no_perm", False),
+        normalize_scales=not getattr(args, "canonicalize_no_scale", False),
+        method=getattr(args, "canonicalize_method", "hungarian"),
+        scale_mode=getattr(args, "canonicalize_scale_mode", "comparison"),
+    )
+
+
 def _build_parser() -> argparse.ArgumentParser:
     """Construct the top-level argument parser."""
     parser = argparse.ArgumentParser(
@@ -137,6 +197,7 @@ def _build_parser() -> argparse.ArgumentParser:
         help="Allow execution of model-hosted Python code (config/tokenizer). "
         "Use only with models you trust.",
     )
+    _add_canonicalize_flags(cmp)

     scn = sub.add_parser("scan", help="Scan a model against known models")
     scn.add_argument(
@@ -199,6 +260,7 @@ def _build_parser() -> argparse.ArgumentParser:
         help="Allow execution of model-hosted Python code (config/tokenizer). "
         "Use only with models you trust.",
     )
+    _add_canonicalize_flags(scn)

     dl = sub.add_parser(
         "download-deepsignals-fingerprint",
@@ -279,12 +341,14 @@ def _run_compare(args: argparse.Namespace) -> int:
     scanner = ModelProvenanceScanner(settings=settings, cache=cache)

     use_json = getattr(args, "output_json", False)
+    canon_cfg = _build_canonicalization(args)
     result = _safe_run(
         _run_with_spinner,
         "Comparing models…",
         scanner.compare,
         args.model_a,
         args.model_b,
+        canonicalization=canon_cfg,
         use_json=use_json,
     )
     if result is None:
```
