-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathtest_constructor.py
More file actions
68 lines (48 loc) · 1.64 KB
/
test_constructor.py
File metadata and controls
68 lines (48 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pytest
import torch
from transformer_lens import FactoredMatrix
def test_factored_matrix():
    """Build a FactoredMatrix from two plain 2-D factors and check its metadata.

    The factors have shapes (5, 3) and (3, 7). The test expects both factors
    to be stored unchanged, the left/middle/right dimensions to be 5/3/7,
    no leading dimensions to be reported, and the overall shape to be (5, 7).
    """
    left = torch.randn(5, 3)
    right = torch.randn(3, 7)

    factored = FactoredMatrix(left, right)

    # The stored factors must be element-wise identical to the inputs.
    assert torch.equal(factored.A, left)
    assert torch.equal(factored.B, right)

    dims = (factored.ldim, factored.mdim, factored.rdim)
    assert dims == (5, 3, 7)

    # Neither factor has more than two dimensions here.
    assert not factored.has_leading_dims
    assert factored.shape == (5, 7)
def test_factored_matrix_b_leading_dims():
    """Check construction when only the right factor has leading dimensions.

    The left factor is (5, 3) and the right factor is (2, 4, 3, 7). The test
    expects the stored left factor to take the shape (2, 4, 5, 3), the right
    factor to be stored unchanged, and the overall shape to be (2, 4, 5, 7).
    """
    left = torch.ones((5, 3))
    right = torch.ones((2, 4, 3, 7))

    factored = FactoredMatrix(left, right)

    # The left factor is expected to pick up the right factor's (2, 4) prefix.
    assert factored.A.shape == (2, 4, 5, 3)
    assert torch.equal(factored.B, right)

    dims = (factored.ldim, factored.mdim, factored.rdim)
    assert dims == (5, 3, 7)

    assert factored.has_leading_dims
    assert factored.shape == (2, 4, 5, 7)
def test_factored_matrix_a_b_leading_dims():
    """Check construction when both factors carry leading dimensions.

    The left factor is (4, 5, 3) and the right factor is (2, 4, 3, 7), so the
    two have a different number of leading dimensions. The test expects the
    stored left factor to take the shape (2, 4, 5, 3), the right factor to be
    stored unchanged, and the overall shape to be (2, 4, 5, 7).
    """
    left = torch.ones((4, 5, 3))
    right = torch.ones((2, 4, 3, 7))

    factored = FactoredMatrix(left, right)

    # The left factor's single leading dim (4) lines up with the right
    # factor's trailing leading dim; the extra (2) is expected to be added.
    assert factored.A.shape == (2, 4, 5, 3)
    assert torch.equal(factored.B, right)

    dims = (factored.ldim, factored.mdim, factored.rdim)
    assert dims == (5, 3, 7)

    assert factored.has_leading_dims
    assert factored.shape == (2, 4, 5, 7)
def test_factored_matrix_broadcast_mismatch():
    """Expect a RuntimeError when the factors' leading dimensions disagree.

    The left factor has leading dims (9,) and the right factor has (2, 4).
    Construction is expected to raise a RuntimeError whose message matches
    "Mismatch" or "mismatch".
    """
    left = torch.ones((9, 5, 3))
    right = torch.ones((2, 4, 3, 7))
    expected_message = r"[Mm]ismatch"

    with pytest.raises(RuntimeError, match=expected_message):
        FactoredMatrix(left, right)
@pytest.mark.skip(
"""
AssertionError will not be reached due to jaxtyping argument consistency
checks, which are enabled at test time but not run time.
See https://github.com/TransformerLensOrg/TransformerLens/issues/190
"""
)
def test_factored_matrix_inner_mismatch():
    """Expect an AssertionError when the inner dimensions do not agree.

    The left factor is (2, 3, 4) and the right factor is (2, 3, 5), so the
    left factor's last dimension (4) differs from the right factor's
    second-to-last dimension (3). The error message is expected to mention
    "inner dimension". The test is currently skipped; see the skip reason.
    """
    left = torch.ones((2, 3, 4))
    right = torch.ones((2, 3, 5))

    with pytest.raises(AssertionError) as excinfo:
        FactoredMatrix(left, right)

    # Inspect the captured exception only after the context manager exits.
    message = str(excinfo.value)
    assert "inner dimension" in message