Skip to content

Commit b23d94b

Browse files
committed
Unit Tests: Small Matrix
1 parent 0588ed0 commit b23d94b

File tree

1 file changed

+260
-0
lines changed

1 file changed

+260
-0
lines changed

tests/test_smallmatrix.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import numpy as np
4+
import pytest
5+
6+
import amrex.space3d as amr
7+
8+
9+
def test_smallmatrix():
10+
m66 = amr.SmallMatrix_6x6_F_SI1_double(
11+
[
12+
[1, 2, 3, 4, 5, 6],
13+
[7, 8, 9, 10, 11, 12],
14+
[13, 14, 15, 16, 17, 18],
15+
[19, 20, 21, 22, 23, 24],
16+
[25, 26, 27, 28, 29, 30],
17+
[31, 32, 33, 34, 35, 36],
18+
]
19+
)
20+
v = 1
21+
for j in range(1, 7):
22+
for i in range(1, 7):
23+
assert m66[i, j] == v
24+
v += 1
25+
26+
27+
def test_smallvector():
28+
cv1 = amr.SmallMatrix_6x1_F_SI1_double()
29+
rv1 = amr.SmallMatrix_1x6_F_SI1_double()
30+
cv2 = amr.SmallMatrix_6x1_F_SI1_double([1, 2, 3, 4, 5, 6])
31+
rv2 = amr.SmallMatrix_1x6_F_SI1_double([0, 10, 20, 30, 40, 50])
32+
cv3 = amr.SmallMatrix_6x1_F_SI1_double([0, 1, 2, 3, 4, 5])
33+
34+
for j in range(1, 7):
35+
assert cv1[j] == 0.0
36+
assert rv1[j] == 0.0
37+
assert cv2[j] == j
38+
assert amr.almost_equal(rv2[j], (j - 1) * 10.0)
39+
assert amr.almost_equal(cv3[j], j - 1.0)
40+
41+
42+
def test_smallmatrix_zero():
43+
zero = amr.SmallMatrix_6x6_F_SI1_double()
44+
45+
# Check properties
46+
assert zero.size == 36
47+
assert zero.row_size == 6
48+
assert zero.column_size == 6
49+
assert zero.order == "F"
50+
assert zero.starting_index == 1
51+
52+
# Check values
53+
assert zero.sum() == 0
54+
assert zero.prod() == 0
55+
assert zero.trace() == 0
56+
57+
# assign empty
58+
zeroc = amr.SmallMatrix_6x6_F_SI1_double(zero)
59+
60+
# Check values
61+
assert zeroc.sum() == 0
62+
assert zeroc.prod() == 0
63+
assert zeroc.trace() == 0
64+
65+
# create zero
66+
zerov = amr.SmallMatrix_6x6_F_SI1_double.zero()
67+
68+
# Check values
69+
assert zerov.sum() == 0
70+
assert zerov.prod() == 0
71+
assert zerov.trace() == 0
72+
73+
74+
def test_smallmatrix_identity():
75+
iden = amr.SmallMatrix_6x6_F_SI1_double.identity()
76+
77+
# Check properties
78+
assert iden.size == 36
79+
assert iden.row_size == 6
80+
assert iden.column_size == 6
81+
assert iden.order == "F"
82+
assert iden.starting_index == 1
83+
84+
# Check values
85+
assert iden.sum() == 6
86+
assert iden.prod() == 0
87+
assert iden.trace() == 6
88+
89+
90+
def test_smallmatrix_from_np():
91+
# from numpy (copy)
92+
x = np.ones(
93+
(
94+
6,
95+
6,
96+
)
97+
)
98+
print(f"\nx: {x.__array_interface__} {x.dtype}")
99+
sm = amr.SmallMatrix_6x6_F_SI1_double(x)
100+
print(f"sm: {sm.__array_interface__}")
101+
print(sm)
102+
103+
assert sm.sum() == 36
104+
assert sm.prod() == 1
105+
assert sm.trace() == 6
106+
107+
108+
def test_smallmatrix_to_np():
109+
iden = amr.SmallMatrix_6x6_F_SI1_double.identity()
110+
111+
x = iden.to_numpy()
112+
print(x)
113+
114+
assert x.sum() == 6
115+
assert x.prod() == 0
116+
assert x.trace() == 6
117+
assert not x.flags["C_CONTIGUOUS"]
118+
assert x.flags["F_CONTIGUOUS"]
119+
120+
121+
def test_smallmatrix_smallvector():
122+
v3 = amr.SmallMatrix_6x1_F_SI1_double.zero()
123+
v3[1] = 1.0
124+
v3[2] = 2.0
125+
v3[3] = 3.0
126+
v3[4] = 4.0
127+
v3[5] = 5.0
128+
v3[6] = 6.0
129+
m66 = amr.SmallMatrix_6x6_F_SI1_double.identity()
130+
r = m66 * v3
131+
132+
for i in range(1, 7):
133+
assert amr.almost_equal(r[i], v3[i])
134+
135+
136+
def test_smallmatrix_smallmatrix():
137+
A = amr.SmallMatrix_6x6_F_SI1_double(
138+
[
139+
[1, 0, 1, 0, 1, 0],
140+
[2, 1, 1, 1, 1, 2],
141+
[0, 1, 1, 1, 1, 0],
142+
[1, 1, 2, 2, 1, 1],
143+
[2, 1, 2, 2, 1, 2],
144+
[0, 1, 1, 1, 1, 0],
145+
]
146+
)
147+
B = amr.SmallMatrix_6x6_F_SI1_double(
148+
[
149+
[1, 2, 2, 2, 1, 1],
150+
[2, 3, 1, 1, 1, 3],
151+
[4, 2, 2, 2, 2, 0],
152+
[1, 4, 3, 2, 0, 1],
153+
[2, 3, 1, 0, 0, 2],
154+
[0, 1, 1, 1, 4, 0],
155+
]
156+
)
157+
C = amr.SmallMatrix_6x1_F_SI1_double([10, 8, 6, 4, 2, 0])
158+
ABC = A * B * C
159+
assert ABC[1, 1] == 322
160+
assert ABC[2, 1] == 252
161+
assert ABC[3, 1] == 388
162+
assert ABC[4, 1] == 330
163+
assert ABC[5, 1] == 310
164+
assert ABC[6, 1] == 264
165+
166+
# transpose
167+
CR = amr.SmallMatrix_1x6_F_SI1_double([10, 8, 6, 4, 2, 0])
168+
ABC_T = A.T * B.transpose_in_place() * CR.T
169+
assert ABC_T[1, 1] == 178
170+
assert ABC_T[2, 1] == 402
171+
assert ABC_T[3, 1] == 254
172+
assert ABC_T[4, 1] == 476
173+
assert ABC_T[5, 1] == 550
174+
assert ABC_T[6, 1] == 254
175+
176+
177+
def test_smallmatrix_sum_prod():
178+
m = amr.SmallMatrix_6x6_F_SI1_double()
179+
m.set_val(2.0)
180+
181+
assert m.prod() == 2 ** (m.row_size * m.column_size)
182+
assert m.sum() == 2 * m.row_size * m.column_size
183+
184+
185+
def test_smallmatrix_trace():
186+
m = amr.SmallMatrix_6x6_F_SI1_double(
187+
[
188+
[1.0, 3.4, 4.5, 5.6, 6.7, 7.8],
189+
[1.3, 2.0, 3.4, 4.5, 5.6, 6.7],
190+
[1.3, 1.0, 3.0, 4.5, 5.6, 6.7],
191+
[1.3, 1.4, 4.5, 4.0, 5.6, 6.7],
192+
[1.3, 1.0, 4.5, 5.6, 5.0, 6.7],
193+
[1.3, 1.4, 3.0, 4.5, 6.7, 6.0],
194+
]
195+
)
196+
assert m.trace() == 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0
197+
198+
199+
def test_smallmatrix_scalar():
200+
A = amr.SmallMatrix_6x6_F_SI1_double(
201+
[
202+
[+1.0, +2, +3, +4, +5, +6],
203+
[+7, +8, +9, +10, +11, +12],
204+
[+13, +14, +15, +16, +17, +18],
205+
[+19, +20, +21, +22, +23, +24],
206+
[+25, +26, +27, +28, +29, +30],
207+
[+31, +32, +33, +34, +35, +36],
208+
]
209+
)
210+
B = amr.SmallMatrix_6x6_F_SI1_double(A)
211+
B *= -1.0
212+
213+
# test matrix-scalar and scalar-matrix
214+
C = A * 2.0 + 2.0 * B
215+
assert np.allclose(C.to_numpy(), 0.0)
216+
217+
# test unary- operator and point-wise minus
218+
D = -A - B
219+
assert np.allclose(D.to_numpy(), 0.0)
220+
221+
# dot product
222+
E = amr.SmallMatrix_6x6_F_SI1_double()
223+
E.set_val(-1.0)
224+
assert A.dot(E) == -666
225+
226+
227+
def test_smallmatrix_rangecheck():
228+
cv = amr.SmallMatrix_6x1_F_SI1_double()
229+
rv = amr.SmallMatrix_1x6_F_SI1_double()
230+
m66 = amr.SmallMatrix_6x6_F_SI1_double(
231+
[
232+
[1, 2, 3, 4, 5, 6],
233+
[7, 8, 9, 10, 11, 12],
234+
[13, 14, 15, 16, 17, 18],
235+
[19, 20, 21, 22, 23, 24],
236+
[25, 26, 27, 28, 29, 30],
237+
[31, 32, 33, 34, 35, 36],
238+
]
239+
)
240+
241+
with pytest.raises(RuntimeError):
242+
cv[0]
243+
with pytest.raises(RuntimeError):
244+
cv[7]
245+
with pytest.raises(RuntimeError):
246+
rv[0]
247+
with pytest.raises(RuntimeError):
248+
rv[7]
249+
with pytest.raises(RuntimeError):
250+
m66[0, 0]
251+
with pytest.raises(RuntimeError):
252+
m66[0, 1]
253+
with pytest.raises(RuntimeError):
254+
m66[1, 0]
255+
with pytest.raises(RuntimeError):
256+
m66[7, 7]
257+
with pytest.raises(RuntimeError):
258+
m66[6, 7]
259+
with pytest.raises(RuntimeError):
260+
m66[7, 6]

0 commit comments

Comments
 (0)