Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,7 @@ cc_library(
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:Rewrite",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,22 @@ module @refine_call_dimension_argument_not_integer {

// -----

// CHECK-LABEL: module @refine_call_side_effecting_callee
module @refine_call_side_effecting_callee {
func.func public @main() {
// CHECK: call @callee()
call @callee() : () -> ()
return
}
func.func private @callee() {
%0 = stablehlo.constant dense<0> : tensor<i32>
stablehlo.custom_call @side_effect(%0) {has_side_effect = true} : (tensor<i32>) -> ()
return
}
}

// -----

// CHECK-LABEL: func @refine_convert
func.func @refine_convert(%arg0 : tensor<4xf32>) -> tensor<?xi32> {
// CHECK: stablehlo.convert{{.*}} -> tensor<4xi32>
Expand Down
21 changes: 15 additions & 6 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
Expand Down Expand Up @@ -566,13 +567,21 @@ struct RefineCallOpPattern : public OpRewritePattern<func::CallOp> {
std::optional<SmallVector<DenseIntElementsAttr>> constantAttrs =
isConstantFunction(callee);
if (constantAttrs.has_value()) {
SmallVector<Value> constants;
for (auto constAttr : constantAttrs.value()) {
constants.push_back(
ConstantOp::create(rewriter, op.getLoc(), constAttr));
auto sideEffectResult = callee.walk([](Operation* nestedOp) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - bool hasSideEffects = callee.walk(...).wasInterrupted(); to make it clear that interruption indicates side effect in the same line

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I'm not sure if this is a complete solution. If there's a call op with a nested call op I don't think this catches its side effects. We had this problem in simplification as well and made a similar method:

bool hasAnyDeclaredSideEffects(Operation* op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
// Return true if the op explicitly declares any memory effects of its own.
if (!memInterface.hasNoEffect()) return true;
// The op has no direct memory effects. Return false if it has no recursive
// memory effects, either.
if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) return false;
}
// The op doesn't declare any side effects of its own, but its regions could
// still contain ops that do declare side effects. Recursively check them.
for (Region& region : op->getRegions()) {
for (Operation& nestedOp : region.getOps()) {
if (hasAnyDeclaredSideEffects(&nestedOp)) return true;
}
}
return false;
}

A complete solution would likely require an analysis pass that maps call->side effects and does a full pass to calculate transitive side effects

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let me see if there's a way I can address this by computing it up front... I already didn't like the walk because the result should be cached

auto effectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
if (effectInterface && !effectInterface.hasNoEffect())
return WalkResult::interrupt();
return WalkResult::advance();
});
if (!sideEffectResult.wasInterrupted()) {
SmallVector<Value> constants;
for (auto constAttr : constantAttrs.value()) {
constants.push_back(
ConstantOp::create(rewriter, op.getLoc(), constAttr));
}
rewriter.replaceOp(op, constants);
return success();
}
rewriter.replaceOp(op, constants);
return success();
}
if (!refinementKey->getGlobalConstants().empty()) {
// Drop the global-constant arguments, but only if necessary, or else we
Expand Down
Loading