Skip to content

Relax non-negative-integer Pow to ipow (fix symbolic-exponent codegen)#2426

Open
ThrudPrimrose wants to merge 6 commits into
mainfrom
relax_ipow
Open

Relax non-negative-integer Pow to ipow (fix symbolic-exponent codegen)#2426
ThrudPrimrose wants to merge 6 commits into
mainfrom
relax_ipow

Conversation

@ThrudPrimrose

Copy link
Copy Markdown
Collaborator

Here I provide a pass relaxes pow to ipow if the exponent is nonnegative. It analyzes loop ranges / map ranges to check the bound of the expression, as appearing in stockham_fft. This way, we can compile stockham_fft.

DaCe's symbolic C++ printer emits dace::math::pow (libm, double) for a
non-constant exponent, which is illegal where a C++ integer is required -- an
array size, a subscript, or a loop bound. For example the radix decomposition
R**(K-i-1) used as a stride in stockham_fft fails to compile. dace::math::ipow
(repeated multiply, exact integer) is the correct lowering iff the exponent is
a non-negative integer (its C++ exponent is unsigned, so a negative one is
catastrophic).

- Add an ipow SymPy Function (dace/symbolic.py) lowering to dace::math::ipow;
  registered so it round-trips through parsing and serialization.
- Add the RelaxIntegerPowers pass rewriting Pow -> ipow for non-negative
  integer exponents: constants, integer-valued float literals, and symbolic
  exponents proven >= 0 by interval analysis over the enclosing loop / map
  iterator ranges. A scoped recursive descent binds each iterator (and each
  SDFG's declared symbol signs) as it enters loops, maps and nested SDFGs and
  drops them on exit; the proof reuses equalize_symbol and the assuming/ask
  idiom. Wired last in SIMPLIFY_PASSES.
- Teach infer_symbols_from_datadescriptor to read ipow as Pow when solving a
  shape for its symbols (ipow is opaque to the solver otherwise).
- Unskip the stockham_fft CPU test; add tests/passes/relax_integer_powers_test.py.
_SerializedSymbolicParser._functions (the serialize_symbolic/
deserialize_symbolic path used for SDFG properties like an Array's
total_size) lacked ipow, so a serialized 64*ipow(P, 2) deserialized to an
opaque Function('ipow') that never folds; int(evaluate(total_size, {P: n}))
then raised "Cannot convert symbols to int" (mpi subarrays_test BlockGather,
whose return shape is 8*P x 8*P). Register the real class so it round-trips
and folds.
…ove pass pre-codegen

Root cause of the stockham_fft heap corruption: ``dace::math::ipow`` seeded its
accumulator at ``a`` and looped ``b - 1`` times, so ``ipow(a, 0)`` returned ``a``
instead of ``1``. Every zero-exponent size/stride/bound was corrupted (e.g.
``ipow(R, 0)`` -> ``R``), doubling the innermost tile-transpose loop and writing one
element past the output array. Seed at ``1`` and loop ``b`` times so the base case is
correct; pass the scalar exponent (and base) by value.

Also close the gaps that left an integer-context power as a ``double`` ``dace::math::pow``:
- RelaxIntegerPowers now visits loop bounds/conditions, interstate-edge conditions +
  assignments, and nested-SDFG symbol mappings -- not just descriptors and memlet
  subsets. The scope analysis already proved these exponents non-negative; they were
  simply never offered to it.
- cppunparse qualifies an ``ipow`` call as ``dace::math::ipow`` (the symbolic printer
  already does for descriptor expressions; loop bounds/interstate edges unparse through
  the AST path instead).
- Move RelaxIntegerPowers out of SIMPLIFY_PASSES to a pre-codegen pass. A size is just a
  more complex expression that should not perturb simplification; keeping powers as
  ``Pow`` through simplify also lets SymPy fold ``R**i * R**(K-i-1) -> R**(K-1)`` before
  the opaque ``ipow`` freezes the form.
The single ipow template seeded the accumulator with T(1), but vector value
types (half4/half8, and the CPU vector types) have no scalar constructor, so
T(1) fails to compile -- regressing the CUDA half-vector gelu path that raises
a value to a constant integer power (a**3 -> ipow(a, 3)).

Split into two SFINAE overloads keyed on is_constructible<T, int>: scalar types
seed at T(1) (keeping the a**0 == 1 base case), vector types seed at a and loop
from 1. Vector ipow is only ever emitted with a compile-time exponent >= 1 (the
constant-power path writes a literal 1 for exponent 0), so the base case only
matters for scalars.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant