diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 3ba0d9a5a..1d3305570 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -78,6 +78,9 @@ from pytato.transform import Mapper +from loopy.symbolic import IdentityMapper as LoopyIdentityMapper +from pymbolic.mapper.subst_applier import SubstitutionApplier + # set in doc/conf.py if getattr(sys, "_BUILDING_SPHINX_DOCS", False): # Avoid import unless building docs to avoid creating a hard @@ -108,10 +111,16 @@ """ -def loopy_substitute(expression: Any, variable_assignments: Mapping[str, Any]) -> Any: - from loopy.symbolic import SubstitutionMapper - from pymbolic.mapper.substitutor import make_subst_func +# type-ignore-reason: superclasses have no type information +class LoopySubstitutionApplier( + SubstitutionApplier, LoopyIdentityMapper): # type: ignore + def get_cache_key(self, expr: ScalarExpression, + current_substs: Dict[ScalarExpression, ScalarExpression])\ + -> Tuple[Any, ScalarExpression, Any]: + return (type(expr), expr, tuple(sorted(current_substs.items()))) + +def loopy_substitute(expression: Any, variable_assignments: Mapping[str, Any]) -> Any: # {{{ early exit for identity substitution if all(isinstance(v, prim.Variable) and v.name == k @@ -121,7 +130,7 @@ def loopy_substitute(expression: Any, variable_assignments: Mapping[str, Any]) - # }}} - return SubstitutionMapper(make_subst_func(variable_assignments))(expression) + return prim.Substitution(expression, *zip(*variable_assignments.items())) # SymbolicIndex and ShapeType are semantically distinct but identical at the @@ -876,7 +885,8 @@ def add_store(name: str, expr: Array, result: ImplementedResult, for d in range(expr.ndim)) indices = tuple(prim.Variable(iname) for iname in inames) loopy_expr_context = PersistentExpressionContext(state) - loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) + loopy_expr = LoopySubstitutionApplier()( + result.to_loopy_expression(indices, loopy_expr_context)) # Make the instruction from loopy.kernel.instruction import make_assignment