-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_torch.py
More file actions
174 lines (128 loc) · 5.17 KB
/
test_torch.py
File metadata and controls
174 lines (128 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Tests for PyTorch mHC module.
Tests verify:
1. SinkhornKnopp produces doubly stochastic matrices
2. Gradients flow through all operations
3. mHCResidual and mHCBlock work correctly
4. Initialization produces near-identity behavior
"""
import pytest
import torch
import torch.nn as nn
from mhc.torch_module import SinkhornKnopp, mHCResidual, mHCBlock, create_mhc_mlp
class TestSinkhornKnopp:
    """Tests for differentiable Sinkhorn-Knopp."""

    def setup_method(self, method):
        # Seed torch's global RNG before every test so the random inputs, and
        # therefore the tolerance-based assertions, are reproducible.
        torch.manual_seed(0)

    def test_produces_doubly_stochastic(self):
        """Output should have row and column sums close to 1."""
        sinkhorn = SinkhornKnopp(iterations=20)
        M = torch.randn(4, 4)
        P = sinkhorn(M)
        assert torch.allclose(P.sum(dim=1), torch.ones(4), atol=0.01)
        assert torch.allclose(P.sum(dim=0), torch.ones(4), atol=0.01)

    def test_non_negative_entries(self):
        """Output should have all non-negative entries."""
        sinkhorn = SinkhornKnopp(iterations=20)
        M = torch.randn(4, 4)
        P = sinkhorn(M)
        assert (P >= 0).all()

    def test_differentiable(self):
        """Gradients should flow through Sinkhorn and be finite and non-zero."""
        sinkhorn = SinkhornKnopp(iterations=20)
        M = torch.randn(4, 4, requires_grad=True)
        P = sinkhorn(M)
        # For a (near-)doubly-stochastic P, P.sum() is (near-)constant, so its
        # gradient can be ~0 even if backprop is broken. A random weighting
        # makes the loss depend on how mass is distributed inside P.
        weights = torch.randn(4, 4)
        loss = (P * weights).sum()
        loss.backward()
        assert M.grad is not None
        assert M.grad.shape == M.shape
        assert torch.isfinite(M.grad).all()
        assert M.grad.abs().sum() > 0

    def test_batched_input(self):
        """Should work with batched input, normalizing each matrix separately."""
        sinkhorn = SinkhornKnopp(iterations=20)
        M = torch.randn(8, 4, 4)  # Batch of 8 matrices
        P = sinkhorn(M)
        assert P.shape == (8, 4, 4)
        # Check every matrix in the batch at once: row sums reduce the last
        # axis, column sums the second-to-last; the batch axis is never summed.
        ones = torch.ones(8, 4)
        assert torch.allclose(P.sum(dim=-1), ones, atol=0.01)
        assert torch.allclose(P.sum(dim=-2), ones, atol=0.01)


class TestMHCResidual:
    """Tests for mHCResidual module."""

    def setup_method(self, method):
        # Seed torch's global RNG so each test sees the same random inputs.
        torch.manual_seed(0)

    def test_output_shape(self):
        """Output should match input shape."""
        mhc = mHCResidual(dim=64, n_streams=4)
        x = torch.randn(8, 4, 64)
        layer_out = torch.randn(8, 64)
        y = mhc(x, layer_out)
        assert y.shape == x.shape

    def test_gradients_flow(self):
        """Gradients should reach both inputs and the H_res / alpha_res parameters."""
        mhc = mHCResidual(dim=64, n_streams=4)
        x = torch.randn(8, 4, 64, requires_grad=True)
        layer_out = torch.randn(8, 64, requires_grad=True)
        y = mhc(x, layer_out)
        # If the stream mixing is doubly stochastic (column sums of 1), it
        # preserves y.sum(), so a plain sum can hand H_res a ~zero gradient.
        # Random weights make the loss sensitive to how streams are mixed.
        loss = (y * torch.randn_like(y)).sum()
        loss.backward()
        # Check input gradients
        assert x.grad is not None
        assert layer_out.grad is not None
        # Check parameter gradients
        assert mhc.H_res.grad is not None
        assert mhc.alpha_res.grad is not None
        for grad in (x.grad, layer_out.grad, mhc.H_res.grad, mhc.alpha_res.grad):
            assert torch.isfinite(grad).all()

    def test_aggregated_input(self):
        """get_aggregated_input should reduce streams to single vector."""
        mhc = mHCResidual(dim=64, n_streams=4)
        x = torch.randn(8, 4, 64)
        agg = mhc.get_aggregated_input(x)
        assert agg.shape == (8, 64)


class TestMHCBlock:
    """Behavioral checks for the mHCBlock layer wrapper."""

    def test_wraps_linear_layer(self):
        """A wrapped nn.Linear should map the streams back to the same shape."""
        inner = nn.Linear(64, 64)
        wrapped = mHCBlock(inner, dim=64, n_streams=4)
        streams = torch.randn(8, 4, 64)
        assert wrapped(streams).shape == streams.shape

    def test_gradients_to_wrapped_layer(self):
        """Backprop through the block must reach the inner layer's parameters."""
        inner = nn.Linear(64, 64)
        wrapped = mHCBlock(inner, dim=64, n_streams=4)
        streams = torch.randn(8, 4, 64, requires_grad=True)
        wrapped(streams).sum().backward()
        # Both the weight and the bias of the wrapped layer need a .grad.
        for param in (inner.weight, inner.bias):
            assert param.grad is not None


class TestInitialization:
    """Tests for proper initialization."""

    def setup_method(self, method):
        # Seed torch's global RNG so the threshold check below is reproducible.
        torch.manual_seed(0)

    def test_alpha_values_start_small(self):
        """Alpha values should start small (0.01)."""
        mhc = mHCResidual(dim=64, n_streams=4)
        assert mhc.alpha_res.item() == pytest.approx(0.01)
        assert mhc.alpha_pre.item() == pytest.approx(0.01)
        assert mhc.alpha_post.item() == pytest.approx(0.01)

    def test_initial_behavior_near_identity(self):
        """Initial mHC should behave close to identity (small perturbation)."""
        mhc = mHCResidual(dim=64, n_streams=4)
        x = torch.randn(8, 4, 64)
        layer_out = torch.zeros(8, 64)  # Zero layer output
        y = mhc(x, layer_out)
        # With small alphas and zero layer output, y should be close to x
        # (identity + small mixing perturbation).
        diff = (y - x).abs().mean()
        # Measure the perturbation against the input's own scale. A fixed
        # bound of 1.0 is vacuous for unit-variance input: even y == 0 gives
        # diff == mean|x| (about 0.8). Requiring y to be closer to x than the
        # zero tensor is rules out such degenerate outputs.
        assert diff < x.abs().mean()


class TestCreateMHCMLP:
    """Checks for the ``create_mhc_mlp`` factory helper."""

    def test_creates_correct_structure(self):
        """A 3-layer mHC MLP should accept streams and preserve their shape."""
        net = create_mhc_mlp(dim=64, n_layers=3, n_streams=4)
        inp = torch.randn(8, 4, 64)
        assert net(inp).shape == inp.shape

    def test_gradients_flow_through_mlp(self):
        """Backprop through the whole MLP should produce an input gradient."""
        net = create_mhc_mlp(dim=64, n_layers=3, n_streams=4)
        inp = torch.randn(8, 4, 64, requires_grad=True)
        net(inp).sum().backward()
        assert inp.grad is not None