|
41 | 41 | DistributeMapperBase)
|
42 | 42 | from pymbolic.mapper.stringifier import (StringifyMapper as
|
43 | 43 | StringifyMapperBase)
|
| 44 | +from pymbolic.mapper.equality import (EqualityMapper as |
| 45 | + EqualityMapperBase) |
44 | 46 | from pymbolic.mapper import CombineMapper as CombineMapperBase
|
45 | 47 | from pymbolic.mapper.collector import TermCollector as TermCollectorBase
|
46 | 48 | import pymbolic.primitives as prim
|
@@ -178,6 +180,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
|
178 | 180 | bounds_expr = "{" + bounds_expr + "}"
|
179 | 181 | return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")
|
180 | 182 |
|
| 183 | + |
| 184 | +class EqualityMapper(EqualityMapperBase): |
| 185 | + def map_reduce(self, expr: Reduce, other: Reduce) -> bool: |
| 186 | + return ( |
| 187 | + len(expr.bounds) == len(other.bounds) |
| 188 | + and all(k == other_k |
| 189 | + and self.rec(lb, other_lb) and self.rec(ub, other_ub) |
| 190 | + for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip( |
| 191 | + sorted(expr.bounds.items()), |
| 192 | + sorted(other.bounds.items()))) |
| 193 | + and expr.op == other.op |
| 194 | + and self.rec(expr.inner_expr, other.inner_expr) |
| 195 | + ) |
| 196 | + |
181 | 197 | # }}}
|
182 | 198 |
|
183 | 199 |
|
@@ -234,6 +250,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
|
234 | 250 | # {{{ custom scalar expression nodes
|
235 | 251 |
|
236 | 252 | class ExpressionBase(prim.Expression):
|
| 253 | + def make_equality_mapper(self) -> EqualityMapper: |
| 254 | + return EqualityMapper() |
| 255 | + |
237 | 256 | def make_stringifier(self, originating_stringifier: Any = None) -> str:
|
238 | 257 | return StringifyMapper()
|
239 | 258 |
|
|
0 commit comments