Description
I'm encountering an issue when trying to integrate custom kernels—specifically those built with pallas—and SPMD operations (like xs.mark_sharding) into the scan function. The current design of scan relies on AOTAutograd to capture the forward and backward passes. Because of that, only standard PyTorch ATen operations are recognized in the captured graph.
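For context, here is a minimal sketch of the kind of scan usage in question (assuming the scan(fn, init, xs) API from torch_xla.experimental.scan, where fn maps (carry, x) to (new_carry, y)). With only plain ATen ops in the step function, everything works:

```python
# Minimal sketch, assuming scan(fn, init, xs) from torch_xla.experimental.scan.
# With only standard ATen ops in `step`, AOTAutograd captures the graph fine.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

device = xm.xla_device()

def step(carry, x):
    new_carry = torch.tanh(carry + x)   # plain ATen ops: captured without issue
    y = new_carry * 2.0
    return new_carry, y

init = torch.zeros(4, device=device)
xs_seq = torch.randn(10, 4, device=device)
final_carry, ys = scan(step, init, xs_seq)
```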
pallas Kernels Excluded
- Because pallas kernels are neither standard ATen ops nor automatically recognized by AOTAutograd, they don’t appear in the traced graph. This makes them invisible to scan, even if the kernel code is correct and runs fine outside of a scan.
- One workaround is to wrap each pallas kernel in a custom ATen op. But that approach adds friction, requiring additional boilerplate to register the op so that PyTorch's dispatch and AOTAutograd can see it (sketched below).
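For reference, a rough sketch of that workaround using the torch.library custom-op API (PyTorch 2.4+). Here my_pallas_attention is a hypothetical stand-in for an actual pallas kernel wrapper:

```python
# Sketch of wrapping a (hypothetical) pallas kernel in a custom op so that
# AOTAutograd can at least see it as an opaque ATen-level op.
import torch

def my_pallas_attention(q, k, v):
    # Placeholder standing in for a real pallas kernel invocation.
    return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

@torch.library.custom_op("my_lib::pallas_attention", mutates_args=())
def pallas_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return my_pallas_attention(q, k, v)

@pallas_attention.register_fake
def _(q, k, v):
    # Shape/dtype propagation so the op can be traced without running the kernel.
    return torch.empty_like(q)

# A backward formula would also need to be registered (e.g. via
# torch.library.register_autograd) before the op can show up in the backward
# graph that scan partitions. All of this is the boilerplate referred to above.
```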
SPMD Ops Unrecognized
- Similar to pallas kernels, calls such as xs.mark_sharding(...) do not appear in the captured graph: the AOTAutograd tracing step sees them as Python calls that do not translate into recognized ATen ops.
- This prevents us from assigning SPMD partitioning attributes within a scan function, making it impossible to apply sharding or other SPMD strategies inside the scanned layer (illustrated in the sketch after this list).
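To make this concrete, a sketch of the kind of call that gets dropped. The mesh setup and tensor shapes here are illustrative only:

```python
# Illustrative only: the mesh/shape details are made up, but xs.mark_sharding
# is the standard torch_xla.distributed.spmd call that scan currently drops.
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

def step(carry, x):
    h = torch.tanh(carry + x)               # ATen op: captured by AOTAutograd
    # Not an ATen op: AOTAutograd sees an opaque Python call, so the sharding
    # annotation never reaches the graph that scan replays.
    xs.mark_sharding(h, mesh, ('data', 'model'))
    return h, h
```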
Tracing annotations are skipped
Similarly, xp.trace_me(...) annotates the LazyTensor IR and has no corresponding ATen representation. As a result, the scanned layer won't carry any tracing annotations.
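For reference, a typical use of the annotation that gets lost (xp being torch_xla.debug.profiler):

```python
# Typical trace annotation that is lost when the layer runs under scan.
import torch
import torch_xla.debug.profiler as xp

class Block(torch.nn.Module):
    @xp.trace_me("block_forward")   # annotates the lazy IR / profiler trace only
    def forward(self, x):
        return torch.relu(x)
```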
AOTAutograd vs. LazyTensor C++ approaches
- scan leverages AOTAutograd to partition the forward and backward passes, enabling advanced features like gradient checkpointing or user-defined partitioning strategies. However, this means the captured graph must be composed entirely of ATen ops recognized by AOTAutograd.
- On the other hand, the LazyTensor C++ backend (in the XLA stack) can sometimes capture more exotic or low-level operations by intercepting them at the IR level, but that path isn’t used by AOTAutograd-based flows. There’s a trade-off: AOTAutograd provides a pure-PyTorch approach for graph capture and transformation, yet it effectively filters out non-ATen ops.
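A quick way to see what the AOTAutograd path actually ends up with is to dump the captured forward graph. A small sketch using functorch.compile.aot_function; the printed graph contains only torch.ops.aten.* calls, which is exactly why the ops above vanish:

```python
# Sketch: dump the forward graph AOTAutograd captures. Only torch.ops.aten.*
# nodes appear; non-ATen Python calls (pallas wrappers, mark_sharding,
# trace_me) simply never make it into this graph.
import torch
from functorch.compile import aot_function, make_boxed_func

def fw_compiler(gm, example_inputs):
    gm.graph.print_tabular()          # inspect the captured ATen-only graph
    return make_boxed_func(gm.forward)

def fn(x):
    return torch.sin(x) * torch.cos(x)

compiled = aot_function(fn, fw_compiler=fw_compiler)
compiled(torch.randn(3, requires_grad=True))
```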
Desired Behavior
- Ideally, we want scan to handle any operation recognized by the underlying XLA or LazyTensor stack, including pallas kernels and SPMD operations.
- If there is a route for AOTAutograd to allow extension ops—like a stable mechanism for capturing pallas kernels or calls like mark_sharding—that would solve the problem. Otherwise, a different capture/trace mechanism might be needed to allow these ops in a scanned function.
Discussion Points
- Can we extend AOTAutograd so it recognizes pallas kernels and SPMD ops (e.g., by whitelisting custom ops or hooking into the IR generation)?
- Do we need a custom ATen registration for each pallas kernel? If so, can we streamline that process or document it?
- Is there a recommended workaround to ensure xs.mark_sharding(...) is picked up in the graph capture?
- Are there plans for broader extensibility in the AOTAutograd stack, allowing custom or otherwise “non-ATen” ops?