Skip to content

Commit 9ac2980

Browse files
committed
add support for pymbolic.EqualityMapper
1 parent 8016656 commit 9ac2980

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pytato/scalar_expr.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
DistributeMapperBase)
4242
from pymbolic.mapper.stringifier import (StringifyMapper as
4343
StringifyMapperBase)
44+
from pymbolic.mapper.equality import (EqualityMapper as
45+
EqualityMapperBase)
4446
from pymbolic.mapper import CombineMapper as CombineMapperBase
4547
from pymbolic.mapper.collector import TermCollector as TermCollectorBase
4648
import pymbolic.primitives as prim
@@ -178,6 +180,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
178180
bounds_expr = "{" + bounds_expr + "}"
179181
return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")
180182

183+
184+
class EqualityMapper(EqualityMapperBase):
185+
def map_reduce(self, expr: Reduce, other: Reduce) -> bool:
186+
return (
187+
len(expr.bounds) == len(other.bounds)
188+
and all(k == other_k
189+
and self.rec(lb, other_lb) and self.rec(ub, other_ub)
190+
for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip(
191+
sorted(expr.bounds.items()),
192+
sorted(other.bounds.items())))
193+
and expr.op == other.op
194+
and self.rec(expr.inner_expr, other.inner_expr)
195+
)
196+
181197
# }}}
182198

183199

@@ -234,6 +250,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
234250
# {{{ custom scalar expression nodes
235251

236252
class ExpressionBase(prim.Expression):
253+
def make_equality_mapper(self) -> EqualityMapper:
254+
return EqualityMapper()
255+
237256
def make_stringifier(self, originating_stringifier: Any = None) -> str:
238257
return StringifyMapper()
239258

test/test_pytato.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ def test_userscollector():
355355

356356

357357
def test_asciidag():
358+
pytest.importorskip("asciidag")
359+
358360
n = pt.make_size_param("n")
359361
array = pt.make_placeholder(name="array", shape=n, dtype=np.float64)
360362
stack = pt.stack([array, 2*array, array + 6])

0 commit comments

Comments
 (0)