Skip to content

Commit

Permalink
Merge pull request #1309 from j2kun:analysis-visit-external-call
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720257505
  • Loading branch information
copybara-github committed Jan 27, 2025
2 parents 28a0023 + e11160a commit b1436b4
Show file tree
Hide file tree
Showing 15 changed files with 163 additions and 6 deletions.
14 changes: 14 additions & 0 deletions lib/Analysis/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Utils",
hdrs = ["Utils.h"],
deps = [
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:Support",
],
)
3 changes: 2 additions & 1 deletion lib/Analysis/DimensionAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ cc_library(
srcs = ["DimensionAnalysis.cpp"],
hdrs = ["DimensionAnalysis.h"],
deps = [
"@heir//lib/Analysis:Utils",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)
17 changes: 15 additions & 2 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"

#include <algorithm>
#include <functional>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -80,6 +84,15 @@ LogicalResult DimensionAnalysis::visitOperation(
return success();
}

void DimensionAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const DimensionLattice *> argumentLattices,
ArrayRef<DimensionLattice *> resultLattices) {
auto callback = std::bind(&DimensionAnalysis::propagateIfChangedWrapper, this,
std::placeholders::_1, std::placeholders::_2);
::mlir::heir::visitExternalCall<DimensionState, DimensionLattice>(
call, argumentLattices, resultLattices, callback);
}

int getDimension(Value value, DataFlowSolver *solver) {
auto *lattice = solver->lookupState<DimensionLattice>(value);
if (!lattice) {
Expand Down
9 changes: 9 additions & 0 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DimensionState {
assert(isInitialized());
return dimension.value();
}
DimensionType get() const { return getDimension(); }

bool operator==(const DimensionState &rhs) const {
return dimension == rhs.dimension;
Expand Down Expand Up @@ -78,6 +79,14 @@ class DimensionAnalysis
LogicalResult visitOperation(Operation *op,
ArrayRef<const DimensionLattice *> operands,
ArrayRef<DimensionLattice *> results) override;

void visitExternalCall(CallOpInterface call,
ArrayRef<const DimensionLattice *> argumentLattices,
ArrayRef<DimensionLattice *> resultLattices) override;

void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
propagateIfChanged(state, changed);
}
};

// this function will assert false when Lattice does not exist or not
Expand Down
2 changes: 2 additions & 0 deletions lib/Analysis/LevelAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ cc_library(
srcs = ["LevelAnalysis.cpp"],
hdrs = ["LevelAnalysis.h"],
deps = [
"@heir//lib/Analysis:Utils",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
Expand Down
14 changes: 13 additions & 1 deletion lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h"

#include <algorithm>
#include <functional>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
Expand All @@ -14,7 +16,8 @@
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -77,6 +80,15 @@ LogicalResult LevelAnalysis::visitOperation(
return success();
}

void LevelAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const LevelLattice *> argumentLattices,
ArrayRef<LevelLattice *> resultLattices) {
auto callback = std::bind(&LevelAnalysis::propagateIfChangedWrapper, this,
std::placeholders::_1, std::placeholders::_2);
::mlir::heir::visitExternalCall<LevelState, LevelLattice>(
call, argumentLattices, resultLattices, callback);
}

static int getMaxLevel(Operation *top, DataFlowSolver *solver) {
auto maxLevel = 0;
top->walk<WalkOrder::PreOrder>([&](secret::GenericOp genericOp) {
Expand Down
9 changes: 9 additions & 0 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LevelState {
assert(isInitialized());
return level.value();
}
LevelType get() const { return getLevel(); }

bool operator==(const LevelState &rhs) const { return level == rhs.level; }

Expand Down Expand Up @@ -82,6 +83,14 @@ class LevelAnalysis
LogicalResult visitOperation(Operation *op,
ArrayRef<const LevelLattice *> operands,
ArrayRef<LevelLattice *> results) override;

void visitExternalCall(CallOpInterface call,
ArrayRef<const LevelLattice *> argumentLattices,
ArrayRef<LevelLattice *> resultLattices) override;

void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
propagateIfChanged(state, changed);
}
};

void annotateLevel(Operation *top, DataFlowSolver *solver);
Expand Down
2 changes: 2 additions & 0 deletions lib/Analysis/MulResultAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ cc_library(
srcs = ["MulResultAnalysis.cpp"],
hdrs = ["MulResultAnalysis.h"],
deps = [
"@heir//lib/Analysis:Utils",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
Expand Down
16 changes: 14 additions & 2 deletions lib/Analysis/MulResultAnalysis/MulResultAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "lib/Analysis/MulResultAnalysis/MulResultAnalysis.h"

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include <functional>

#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
Expand All @@ -9,7 +11,8 @@
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -74,5 +77,14 @@ LogicalResult MulResultAnalysis::visitOperation(
return success();
}

void MulResultAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const MulResultLattice *> argumentLattices,
ArrayRef<MulResultLattice *> resultLattices) {
auto callback = std::bind(&MulResultAnalysis::propagateIfChangedWrapper, this,
std::placeholders::_1, std::placeholders::_2);
::mlir::heir::visitExternalCall<MulResultState, MulResultLattice>(
call, argumentLattices, resultLattices, callback);
}

} // namespace heir
} // namespace mlir
8 changes: 8 additions & 0 deletions lib/Analysis/MulResultAnalysis/MulResultAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ class MulResultAnalysis
LogicalResult visitOperation(Operation *op,
ArrayRef<const MulResultLattice *> operands,
ArrayRef<MulResultLattice *> results) override;

void visitExternalCall(CallOpInterface call,
ArrayRef<const MulResultLattice *> argumentLattices,
ArrayRef<MulResultLattice *> resultLattices) override;

void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
propagateIfChanged(state, changed);
}
};

} // namespace heir
Expand Down
2 changes: 2 additions & 0 deletions lib/Analysis/SecretnessAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ cc_library(
srcs = ["SecretnessAnalysis.cpp"],
hdrs = ["SecretnessAnalysis.h"],
deps = [
"@heir//lib/Analysis:Utils",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
15 changes: 15 additions & 0 deletions lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"

#include <algorithm>
#include <functional>
#include <string>

#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Secret/IR/SecretDialect.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/Secret/IR/SecretTypes.h"
Expand All @@ -12,6 +14,7 @@
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
Expand Down Expand Up @@ -108,6 +111,15 @@ LogicalResult SecretnessAnalysis::visitOperation(
return mlir::success();
}

void SecretnessAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const SecretnessLattice *> argumentLattices,
ArrayRef<SecretnessLattice *> resultLattices) {
auto callback = std::bind(&SecretnessAnalysis::propagateIfChangedWrapper,
this, std::placeholders::_1, std::placeholders::_2);
::mlir::heir::visitExternalCall<Secretness, SecretnessLattice>(
call, argumentLattices, resultLattices, callback);
}

void annotateSecretness(Operation *top, DataFlowSolver *solver, bool verbose) {
// Attribute "Printing" Helper
auto getAttribute =
Expand All @@ -132,6 +144,9 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver, bool verbose) {
top->walk([&](Operation *op) {
// Custom Handling for `func.func`, which uses special attributes
if (auto func = llvm::dyn_cast<func::FuncOp>(op)) {
if (func.isDeclaration()) {
return;
}
// Arguments
for (unsigned i = 0; i < func.getNumArguments(); ++i) {
auto arg = func.getArgument(i);
Expand Down
9 changes: 9 additions & 0 deletions lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Secretness {
assert(isInitialized());
return *secretness;
}
const bool &get() const { return getSecretness(); }
void setSecretness(bool value) { secretness = value; }

// Check if the Secretness state is initialized. It can be uninitialized if
Expand Down Expand Up @@ -104,6 +105,14 @@ class SecretnessAnalysis
LogicalResult visitOperation(Operation *operation,
ArrayRef<const SecretnessLattice *> operands,
ArrayRef<SecretnessLattice *> results) override;

void visitExternalCall(CallOpInterface call,
ArrayRef<const SecretnessLattice *> argumentLattices,
ArrayRef<SecretnessLattice *> resultLattices) override;

void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
propagateIfChanged(state, changed);
}
};

/**
Expand Down
39 changes: 39 additions & 0 deletions lib/Analysis/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef LIB_ANALYSIS_UTILS_H_
#define LIB_ANALYSIS_UTILS_H_

#include <functional>

#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

// A generalized version of visitExternalCall that joins all arguments to the
// func.call op, and then propagates this joined value to all results. This is
// useful for debug functions where there is a single operand that is not
// changed by the external call, but can also be useful for some analyses like
// secretness, where the result of the call is secret if any operand is secret.
template <typename StateT, typename LatticeT>
void visitExternalCall(CallOpInterface call,
ArrayRef<const LatticeT *> argumentLattices,
ArrayRef<LatticeT *> resultLattices,
const std::function<void(AnalysisState *, ChangeResult)>
&propagateIfChanged) {
StateT resultState = StateT();

for (const LatticeT *operand : argumentLattices) {
const StateT operandState = operand->getValue();
resultState = StateT::join(resultState, operandState);
}

for (LatticeT *result : resultLattices) {
propagateIfChanged(result, result->join(resultState));
}
}

} // namespace heir
} // namespace mlir

#endif // LIB_ANALYSIS_UTILS_H_
10 changes: 10 additions & 0 deletions tests/Transforms/annotate_secretness/annotate_secretness.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,13 @@ func.func @return_secretness(%s: i32 {secret.secret}, %p: i32) -> (i32) {
//CHECK-NEXT: return {secret.secret} [[R:%.*]] : i32
return %0 : i32
}

func.func private @callee(!secret.secret<i32>) -> !secret.secret<i32>
// CHECK: @func_call
// CHECK-SAME: ([[S:%.*]]: [[ST:.*]])
func.func @func_call(%s: !secret.secret<i32>) {
// CHECK: call @callee
// CHECK-SAME: {secret.secret}
%1 = func.call @callee(%s) : (!secret.secret<i32>) -> !secret.secret<i32>
func.return
}

0 comments on commit b1436b4

Please sign in to comment.