
Commit f1a686f (parent: 81e148b)

feat(tests): add unit tests for ModelAnalysisManager and ParameterSnapshot functionality

2 files changed: +326, -0 lines
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
"""Tests for ModelAnalysisManager."""

import torch
import torch.nn as nn

from grail.trainer.analysis import (
    AnalysisConfig,
    ModelAnalysisManager,
    ParameterChangeMetrics,
)


class SimpleModel(nn.Module):
    """Simple model for testing."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)


def test_manager_creation():
    """Test creating a manager with the factory method."""
    config = AnalysisConfig(interval=10)
    manager = ModelAnalysisManager.create(config)

    assert len(manager) == 2  # ParameterChange + SparseQuality (both enabled by default)
    assert manager.step_count == 0


def test_manager_builder_pattern():
    """Test building a manager with custom metrics."""
    config = AnalysisConfig(
        interval=10,
        param_change_enabled=False,
        sparse_quality_enabled=False,
    )

    manager = ModelAnalysisManager(config).add_metric(ParameterChangeMetrics(thresholds=[1e-6]))

    assert len(manager) == 1


def test_manager_interval():
    """Test that metrics are only computed at intervals."""
    config = AnalysisConfig(interval=5)
    manager = ModelAnalysisManager.create(config)
    model = SimpleModel()

    # Steps 1-4: no metrics
    for i in range(1, 5):
        metrics = manager.on_optimizer_step(model)
        assert metrics == {}
        assert manager.step_count == i

    # Step 5: first snapshot (no metrics yet)
    metrics = manager.on_optimizer_step(model)
    assert metrics == {}
    assert manager.step_count == 5

    # Modify the model
    with torch.no_grad():
        model.linear.weight.data += 0.1

    # Steps 6-9: no metrics
    for _i in range(6, 10):
        metrics = manager.on_optimizer_step(model)
        assert metrics == {}

    # Step 10: metrics computed
    metrics = manager.on_optimizer_step(model)
    assert len(metrics) > 0  # Should have metrics now
    assert "param_change/norm_l2" in metrics


def test_manager_reset():
    """Test resetting manager state."""
    config = AnalysisConfig(interval=5)
    manager = ModelAnalysisManager.create(config)
    model = SimpleModel()

    # Advance to step 10
    for _ in range(10):
        manager.on_optimizer_step(model)

    assert manager.step_count == 10

    # Reset
    manager.reset()

    assert manager.step_count == 0
    assert manager.old_snapshot is None


def test_manager_state_dict():
    """Test saving and loading state."""
    config = AnalysisConfig(interval=5)
    manager = ModelAnalysisManager.create(config)
    model = SimpleModel()

    # Advance to step 10
    for _ in range(10):
        manager.on_optimizer_step(model)

    # Save state
    state = manager.state_dict()
    assert state["step_count"] == 10

    # Create a new manager and load the state
    new_manager = ModelAnalysisManager.create(config)
    new_manager.load_state_dict(state)

    assert new_manager.step_count == 10


def test_manager_minimal_config():
    """Test the minimal configuration."""
    config = AnalysisConfig.minimal()
    manager = ModelAnalysisManager.create(config)

    assert config.interval == 500
    assert config.param_change_enabled is True
    assert config.sparse_quality_enabled is False
    assert len(manager) == 1  # Only param change


def test_manager_comprehensive_config():
    """Test the comprehensive configuration."""
    config = AnalysisConfig.comprehensive()

    assert config.interval == 50
    assert config.param_change_enabled is True
    assert config.sparse_quality_enabled is True
    assert config.param_change_per_layer is True
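For orientation, the lifecycle these tests exercise maps onto a training loop in which on_optimizer_step is called once per optimizer update. A minimal sketch under that assumption, using only the API the tests above rely on (the model, loss, and print logging are illustrative, not part of the package):

import torch
import torch.nn as nn

from grail.trainer.analysis import AnalysisConfig, ModelAnalysisManager

model = nn.Linear(10, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
manager = ModelAnalysisManager.create(AnalysisConfig(interval=5))

for step in range(20):
    x, y = torch.randn(8, 10), torch.randn(8, 5)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Hook in after every optimizer step; per test_manager_interval this
    # returns {} except when the interval elapses and a snapshot already exists.
    metrics = manager.on_optimizer_step(model)
    if metrics:
        print(step, metrics["param_change/norm_l2"])

Since state_dict()/load_state_dict() round-trip step_count (see test_manager_state_dict), the manager can presumably be checkpointed and restored alongside the optimizer.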
Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
"""Tests for analysis primitives (ParameterSnapshot, ParameterDelta)."""

import torch
import torch.nn as nn

from grail.trainer.analysis.primitives import ParameterSnapshot


class SimpleModel(nn.Module):
    """Simple model for testing."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))


def test_parameter_snapshot_creation():
    """Test creating a parameter snapshot."""
    model = SimpleModel()

    snapshot = ParameterSnapshot(model)

    assert len(snapshot) == 4  # 2 weights + 2 biases
    assert "linear1.weight" in snapshot
    assert "linear1.bias" in snapshot
    assert "linear2.weight" in snapshot
    assert "linear2.bias" in snapshot

    # Check device and dtype
    assert snapshot.device == "cpu"
    assert snapshot.dtype == torch.float32


def test_parameter_snapshot_immutable():
    """Test that snapshot data is read-only."""
    model = SimpleModel()
    snapshot = ParameterSnapshot(model)

    # Snapshot data should not be modifiable directly
    # (this is enforced by returning a dict view, not a settable attribute)
    original_weight = snapshot.data["linear1.weight"].clone()

    # Modifying the model should not affect the snapshot
    model.linear1.weight.data.fill_(42.0)

    assert not torch.allclose(snapshot.data["linear1.weight"], model.linear1.weight.data)
    assert torch.allclose(snapshot.data["linear1.weight"], original_weight)


def test_parameter_delta_computation():
    """Test computing the delta between two snapshots."""
    model = SimpleModel()

    # Take an initial snapshot
    snapshot1 = ParameterSnapshot(model)

    # Modify the model
    with torch.no_grad():
        model.linear1.weight.data += 0.5
        model.linear1.bias.data -= 0.1

    # Take a new snapshot
    snapshot2 = ParameterSnapshot(model)

    # Compute the delta
    delta = snapshot1.compute_delta(snapshot2)

    assert len(delta) == 4

    # Check that the deltas are correct
    assert torch.allclose(
        delta.deltas["linear1.weight"], torch.full_like(delta.deltas["linear1.weight"], 0.5)
    )
    assert torch.allclose(
        delta.deltas["linear1.bias"], torch.full_like(delta.deltas["linear1.bias"], -0.1)
    )
    assert torch.allclose(
        delta.deltas["linear2.weight"], torch.zeros_like(delta.deltas["linear2.weight"])
    )


def test_parameter_delta_statistics():
    """Test delta statistics computation."""
    model = SimpleModel()

    snapshot1 = ParameterSnapshot(model)

    # Make known changes
    with torch.no_grad():
        model.linear1.weight.data += 1.0  # 10x20 = 200 params, each +1.0
        model.linear1.bias.data += 2.0  # 20 params, each +2.0

    snapshot2 = ParameterSnapshot(model)
    delta = snapshot1.compute_delta(snapshot2)

    stats = delta.statistics()

    assert "norm_l2" in stats
    assert "norm_l1" in stats
    assert "norm_max" in stats
    assert "mean" in stats
    assert "std" in stats

    # L1 norm should be 200*1.0 + 20*2.0 = 240 (linear2 is unchanged)
    expected_l1 = 200 * 1.0 + 20 * 2.0 + 100 * 0.0 + 5 * 0.0
    assert abs(stats["norm_l1"] - expected_l1) < 1e-5

    # Max should be 2.0
    assert abs(stats["norm_max"] - 2.0) < 1e-5


def test_parameter_delta_sparsity():
    """Test sparsity computation at different thresholds."""
    model = SimpleModel()

    snapshot1 = ParameterSnapshot(model)

    # Create varied changes
    with torch.no_grad():
        model.linear1.weight.data += 1e-5  # Above the 1e-6 threshold
        model.linear1.bias.data += 1e-10  # Below the 1e-6 threshold

    snapshot2 = ParameterSnapshot(model)
    delta = snapshot1.compute_delta(snapshot2)

    sparsity_1e6 = delta.sparsity_at_threshold(1e-6)

    # linear1.weight (200 params) should be kept (above threshold)
    # linear1.bias (20 params) should be dropped (below threshold)
    # linear2.* (105 params) should be dropped (zero)
    total_params = 200 + 20 + 100 + 5  # 325
    kept_params = 200

    assert sparsity_1e6["total_params"] == total_params
    assert sparsity_1e6["kept_params"] == kept_params
    assert abs(sparsity_1e6["kept_ratio"] - (kept_params / total_params)) < 1e-5


def test_parameter_delta_sparse_mask():
    """Test applying a sparse mask to a delta."""
    model = SimpleModel()

    snapshot1 = ParameterSnapshot(model)

    with torch.no_grad():
        model.linear1.weight.data += 1e-5

    snapshot2 = ParameterSnapshot(model)
    delta = snapshot1.compute_delta(snapshot2)

    # Apply a sparse mask at 1e-6
    sparse_delta = delta.apply_sparse_mask(threshold=1e-6)

    # Check that above-threshold changes were kept and small ones zeroed
    assert torch.allclose(
        sparse_delta.deltas["linear1.weight"],
        torch.full_like(sparse_delta.deltas["linear1.weight"], 1e-5),
    )
    assert torch.allclose(
        sparse_delta.deltas["linear1.bias"],
        torch.zeros_like(sparse_delta.deltas["linear1.bias"]),
    )


def test_parameter_delta_per_layer_stats():
    """Test per-layer statistics computation."""
    model = SimpleModel()

    snapshot1 = ParameterSnapshot(model)

    with torch.no_grad():
        model.linear1.weight.data += 1.0
        model.linear2.weight.data += 2.0

    snapshot2 = ParameterSnapshot(model)
    delta = snapshot1.compute_delta(snapshot2)

    per_layer = delta.per_layer_statistics()

    assert len(per_layer) == 4
    assert "linear1.weight" in per_layer
    assert "linear2.weight" in per_layer

    # Check that the means are correct
    assert abs(per_layer["linear1.weight"]["mean"] - 1.0) < 1e-5
    assert abs(per_layer["linear2.weight"]["mean"] - 2.0) < 1e-5
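Read together, these tests describe a snapshot -> delta -> threshold pipeline: capture parameters before and after an update, diff them, then summarize or sparsify the result. A minimal sketch of that flow, assuming only the methods exercised above (the Sequential model is illustrative):

import torch
import torch.nn as nn

from grail.trainer.analysis.primitives import ParameterSnapshot

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

before = ParameterSnapshot(model)  # independent copy of every named parameter

# One simulated update: only the first layer moves appreciably.
with torch.no_grad():
    model[0].weight.data += 1e-3

after = ParameterSnapshot(model)
delta = before.compute_delta(after)  # delta = after - before, per the tests above

stats = delta.statistics()  # norm_l2 / norm_l1 / norm_max / mean / std
sparsity = delta.sparsity_at_threshold(1e-6)  # total_params / kept_params / kept_ratio
sparse = delta.apply_sparse_mask(threshold=1e-6)  # sub-threshold entries zeroed

print(stats["norm_l2"], sparsity["kept_ratio"])

Here kept_ratio would be roughly 200/625, since only the first layer's weights cross the threshold; the masked delta keeps exactly those entries.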
