Skip to content

Fix RefineCallOpPattern in StablehloRefineShapes to preserve side-effecting callees#2934

Open
christopherbate wants to merge 2 commits intoopenxla:mainfrom
christopherbate:fix-refine-shapes-side-effects
Open

Fix RefineCallOpPattern in StablehloRefineShapes to preserve side-effecting callees#2934
christopherbate wants to merge 2 commits intoopenxla:mainfrom
christopherbate:fix-refine-shapes-side-effects

Conversation

@christopherbate
Copy link
Copy Markdown
Contributor

The constant-function replacement in RefineCallOpPattern did not check
whether the callee contains side-effecting operations. This caused calls
to functions with side effects (e.g. custom_call with
has_side_effect=true) to be silently erased. Guard the replacement with
a walk that checks for MemoryEffectOpInterface declarations.

…ecting callees

The constant-function replacement in RefineCallOpPattern did not check
whether the callee contains side-effecting operations. This caused calls
to functions with side effects (e.g. custom_call with
has_side_effect=true) to be silently erased. Guard the replacement with
a walk that checks for MemoryEffectOpInterface declarations.
@christopherbate christopherbate requested a review from GleasonK April 9, 2026 00:01
for (auto constAttr : constantAttrs.value()) {
constants.push_back(
ConstantOp::create(rewriter, op.getLoc(), constAttr));
auto sideEffectResult = callee.walk([](Operation* nestedOp) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - bool hasSideEffects = callee.walk(...).wasInterrupted(); to make it clear that interruption indicates side effect in the same line

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I'm not sure if this is a complete solution. If there's a call op with a nested call op I don't think this catches its side effects. We had this problem in simplification as well and made a similar method:

bool hasAnyDeclaredSideEffects(Operation* op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
// Return true if the op explicitly declares any memory effects of its own.
if (!memInterface.hasNoEffect()) return true;
// The op has no direct memory effects. Return false if it has no recursive
// memory effects, either.
if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) return false;
}
// The op doesn't declare any side effects of its own, but its regions could
// still contain ops that do declare side effects. Recursively check them.
for (Region& region : op->getRegions()) {
for (Operation& nestedOp : region.getOps()) {
if (hasAnyDeclaredSideEffects(&nestedOp)) return true;
}
}
return false;
}

A complete solution would likely require an analysis pass that maps call->side effects and does a full pass to calculate transitive side effects

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let me see if there's a way I can address this by computing it up front... I already didn't like the walk because the result should be cached

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants