Skip to content

Commit ac05808

Browse files
kaushikcfdinducer
authored andcommitted
test apply_distributive_property_to_einsums
flake8: disable N806 for test_linalg Using capital letters for matrices is common
1 parent 2c443ce commit ac05808

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ multiline-quotes = """
88
99
per-file-ignores =
1010
examples/advection.py:B023
11+
test/test_linalg.py:N806
1112
1213
# enable-flake8-bugbear
1314

test/test_linalg.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#!/usr/bin/env python
2+
3+
__copyright__ = """
4+
Copyright (C) 2023 Kaushik Kulkarni
5+
"""
6+
7+
__license__ = """
8+
Permission is hereby granted, free of charge, to any person obtaining a copy
9+
of this software and associated documentation files (the "Software"), to deal
10+
in the Software without restriction, including without limitation the rights
11+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12+
copies of the Software, and to permit persons to whom the Software is
13+
furnished to do so, subject to the following conditions:
14+
15+
The above copyright notice and this permission notice shall be included in
16+
all copies or substantial portions of the Software.
17+
18+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24+
THE SOFTWARE.
25+
"""
26+
27+
import sys
28+
import numpy as np
29+
import pytato as pt
30+
from pyopencl.tools import ( # noqa: F401
31+
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
32+
33+
34+
def test_apply_einsum_distributive_law_0():
35+
from pytato.transform.einsum_distributive_law import (
36+
EinsumDistributiveLawDescriptor,
37+
DoDistribute, DoNotDistribute,
38+
apply_distributive_property_to_einsums,
39+
)
40+
41+
def how_to_distribute(
42+
expr: pt.Einsum) -> EinsumDistributiveLawDescriptor:
43+
if pt.analysis.is_einsum_similar_to_subscript(
44+
expr, "ij,j->i"):
45+
return DoDistribute(ioperand=1)
46+
else:
47+
return DoNotDistribute()
48+
49+
x1 = pt.make_placeholder("x1", 4, np.float64)
50+
x2 = pt.make_placeholder("x2", 4, np.float64)
51+
A1 = pt.make_placeholder("A1", (10, 4), np.float64)
52+
A2 = pt.make_placeholder("A2", (10, 4), np.float64)
53+
y = (7*A1 + 8*A2) @ (2*x1-3*x2)
54+
y_transformed = apply_distributive_property_to_einsums(y, how_to_distribute)
55+
56+
assert y_transformed == ((2 * ((7*A1 + 8*A2) @ x1) - 3 * ((7*A1 + 8*A2) @
57+
x2)))
58+
59+
60+
def test_apply_einsum_distributive_law_1():
61+
from pytato.transform.einsum_distributive_law import (
62+
EinsumDistributiveLawDescriptor,
63+
DoDistribute, DoNotDistribute,
64+
apply_distributive_property_to_einsums,
65+
)
66+
67+
def how_to_distribute(
68+
expr: pt.Einsum) -> EinsumDistributiveLawDescriptor:
69+
if pt.analysis.is_einsum_similar_to_subscript(
70+
expr, "ij,j->i"):
71+
return DoDistribute(ioperand=0)
72+
else:
73+
return DoNotDistribute()
74+
75+
x1 = pt.make_placeholder("x1", 4, np.float64)
76+
x2 = pt.make_placeholder("x2", 4, np.float64)
77+
A1 = pt.make_placeholder("A1", (10, 4), np.float64)
78+
A2 = pt.make_placeholder("A2", (10, 4), np.float64)
79+
y = (7*A1 + 8*pt.sin(A2)) @ (2*x1-3*x2)
80+
y_transformed = apply_distributive_property_to_einsums(y, how_to_distribute)
81+
print(y_transformed)
82+
assert y_transformed == (7 * (A1 @ (2*x1-3*x2)) + 8 * (pt.sin(A2) @ (2*x1-3*x2)))
83+
84+
85+
def test_apply_einsum_distributive_law_2():
86+
from pytato.transform.einsum_distributive_law import (
87+
EinsumDistributiveLawDescriptor,
88+
DoDistribute, DoNotDistribute,
89+
apply_distributive_property_to_einsums,
90+
)
91+
92+
def how_to_distribute(
93+
expr: pt.Einsum) -> EinsumDistributiveLawDescriptor:
94+
if (pt.analysis.is_einsum_similar_to_subscript(
95+
expr, "ij,j->i") and
96+
pt.utils.are_shape_components_equal(expr.args[1].shape[0],
97+
10)):
98+
return DoDistribute(ioperand=1)
99+
else:
100+
return DoNotDistribute()
101+
102+
x1 = pt.make_placeholder("x1", 4, np.float64)
103+
x2 = pt.make_placeholder("x2", 4, np.float64)
104+
A1 = pt.make_placeholder("A1", (10, 10), np.float64)
105+
A2 = pt.make_placeholder("A2", (10, 10), np.float64)
106+
B = pt.make_placeholder("B", (10, 4), np.float64)
107+
y = (7*A1 + 8*A2) @ (2*(B@x1)-3*(B@x2))
108+
y_transformed = apply_distributive_property_to_einsums(y, how_to_distribute)
109+
110+
assert y_transformed == (2 * ((7*A1 + 8*A2) @ (B@x1))
111+
- 3 * ((7*A1 + 8*A2) @ (B@x2)))
112+
113+
114+
if __name__ == "__main__":
115+
if len(sys.argv) > 1:
116+
exec(sys.argv[1])
117+
else:
118+
from pytest import main
119+
main([__file__])
120+
121+
# vim: fdm=marker

0 commit comments

Comments
 (0)