Skip to content

Commit b1436b4

Browse files
Merge pull request #1309 from j2kun:analysis-visit-external-call
PiperOrigin-RevId: 720257505
2 parents 28a0023 + e11160a commit b1436b4

File tree

15 files changed

+163
-6
lines changed

15 files changed

+163
-6
lines changed

lib/Analysis/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package(
2+
default_applicable_licenses = ["@heir//:license"],
3+
default_visibility = ["//visibility:public"],
4+
)
5+
6+
cc_library(
7+
name = "Utils",
8+
hdrs = ["Utils.h"],
9+
deps = [
10+
"@llvm-project//mlir:Analysis",
11+
"@llvm-project//mlir:CallOpInterfaces",
12+
"@llvm-project//mlir:Support",
13+
],
14+
)

lib/Analysis/DimensionAnalysis/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ cc_library(
88
srcs = ["DimensionAnalysis.cpp"],
99
hdrs = ["DimensionAnalysis.h"],
1010
deps = [
11+
"@heir//lib/Analysis:Utils",
1112
"@heir//lib/Analysis/SecretnessAnalysis",
1213
"@heir//lib/Dialect/Mgmt/IR:Dialect",
1314
"@heir//lib/Dialect/Secret/IR:Dialect",
1415
"@llvm-project//llvm:Support",
1516
"@llvm-project//mlir:Analysis",
1617
"@llvm-project//mlir:ArithDialect",
18+
"@llvm-project//mlir:FunctionInterfaces",
1719
"@llvm-project//mlir:IR",
1820
"@llvm-project//mlir:Support",
19-
"@llvm-project//mlir:TensorDialect",
2021
],
2122
)

lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"
22

3+
#include <algorithm>
4+
#include <functional>
5+
36
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
7+
#include "lib/Analysis/Utils.h"
48
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
59
#include "lib/Dialect/Secret/IR/SecretOps.h"
610
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
711
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
812
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
9-
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
1013
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1114
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
1215
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
13-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
16+
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
17+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1418

1519
namespace mlir {
1620
namespace heir {
@@ -80,6 +84,15 @@ LogicalResult DimensionAnalysis::visitOperation(
8084
return success();
8185
}
8286

87+
void DimensionAnalysis::visitExternalCall(
88+
CallOpInterface call, ArrayRef<const DimensionLattice *> argumentLattices,
89+
ArrayRef<DimensionLattice *> resultLattices) {
90+
auto callback = std::bind(&DimensionAnalysis::propagateIfChangedWrapper, this,
91+
std::placeholders::_1, std::placeholders::_2);
92+
::mlir::heir::visitExternalCall<DimensionState, DimensionLattice>(
93+
call, argumentLattices, resultLattices, callback);
94+
}
95+
8396
int getDimension(Value value, DataFlowSolver *solver) {
8497
auto *lattice = solver->lookupState<DimensionLattice>(value);
8598
if (!lattice) {

lib/Analysis/DimensionAnalysis/DimensionAnalysis.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class DimensionState {
2626
assert(isInitialized());
2727
return dimension.value();
2828
}
29+
DimensionType get() const { return getDimension(); }
2930

3031
bool operator==(const DimensionState &rhs) const {
3132
return dimension == rhs.dimension;
@@ -78,6 +79,14 @@ class DimensionAnalysis
7879
LogicalResult visitOperation(Operation *op,
7980
ArrayRef<const DimensionLattice *> operands,
8081
ArrayRef<DimensionLattice *> results) override;
82+
83+
void visitExternalCall(CallOpInterface call,
84+
ArrayRef<const DimensionLattice *> argumentLattices,
85+
ArrayRef<DimensionLattice *> resultLattices) override;
86+
87+
void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
88+
propagateIfChanged(state, changed);
89+
}
8190
};
8291

8392
// this function will assert false when Lattice does not exist or not

lib/Analysis/LevelAnalysis/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ cc_library(
88
srcs = ["LevelAnalysis.cpp"],
99
hdrs = ["LevelAnalysis.h"],
1010
deps = [
11+
"@heir//lib/Analysis:Utils",
1112
"@heir//lib/Analysis/SecretnessAnalysis",
1213
"@heir//lib/Dialect/Mgmt/IR:Dialect",
1314
"@heir//lib/Dialect/Secret/IR:Dialect",
1415
"@llvm-project//llvm:Support",
1516
"@llvm-project//mlir:Analysis",
1617
"@llvm-project//mlir:ArithDialect",
18+
"@llvm-project//mlir:FunctionInterfaces",
1719
"@llvm-project//mlir:IR",
1820
"@llvm-project//mlir:Support",
1921
"@llvm-project//mlir:TensorDialect",

lib/Analysis/LevelAnalysis/LevelAnalysis.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h"
22

33
#include <algorithm>
4+
#include <functional>
45

56
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
7+
#include "lib/Analysis/Utils.h"
68
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
79
#include "lib/Dialect/Secret/IR/SecretOps.h"
810
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
@@ -14,7 +16,8 @@
1416
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1517
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
1618
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
17-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
19+
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
20+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1821

1922
namespace mlir {
2023
namespace heir {
@@ -77,6 +80,15 @@ LogicalResult LevelAnalysis::visitOperation(
7780
return success();
7881
}
7982

83+
void LevelAnalysis::visitExternalCall(
84+
CallOpInterface call, ArrayRef<const LevelLattice *> argumentLattices,
85+
ArrayRef<LevelLattice *> resultLattices) {
86+
auto callback = std::bind(&LevelAnalysis::propagateIfChangedWrapper, this,
87+
std::placeholders::_1, std::placeholders::_2);
88+
::mlir::heir::visitExternalCall<LevelState, LevelLattice>(
89+
call, argumentLattices, resultLattices, callback);
90+
}
91+
8092
static int getMaxLevel(Operation *top, DataFlowSolver *solver) {
8193
auto maxLevel = 0;
8294
top->walk<WalkOrder::PreOrder>([&](secret::GenericOp genericOp) {

lib/Analysis/LevelAnalysis/LevelAnalysis.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class LevelState {
3333
assert(isInitialized());
3434
return level.value();
3535
}
36+
LevelType get() const { return getLevel(); }
3637

3738
bool operator==(const LevelState &rhs) const { return level == rhs.level; }
3839

@@ -82,6 +83,14 @@ class LevelAnalysis
8283
LogicalResult visitOperation(Operation *op,
8384
ArrayRef<const LevelLattice *> operands,
8485
ArrayRef<LevelLattice *> results) override;
86+
87+
void visitExternalCall(CallOpInterface call,
88+
ArrayRef<const LevelLattice *> argumentLattices,
89+
ArrayRef<LevelLattice *> resultLattices) override;
90+
91+
void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
92+
propagateIfChanged(state, changed);
93+
}
8594
};
8695

8796
void annotateLevel(Operation *top, DataFlowSolver *solver);

lib/Analysis/MulResultAnalysis/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ cc_library(
88
srcs = ["MulResultAnalysis.cpp"],
99
hdrs = ["MulResultAnalysis.h"],
1010
deps = [
11+
"@heir//lib/Analysis:Utils",
1112
"@heir//lib/Analysis/SecretnessAnalysis",
1213
"@heir//lib/Dialect/Secret/IR:Dialect",
1314
"@llvm-project//llvm:Support",
1415
"@llvm-project//mlir:Analysis",
1516
"@llvm-project//mlir:ArithDialect",
17+
"@llvm-project//mlir:FunctionInterfaces",
1618
"@llvm-project//mlir:IR",
1719
"@llvm-project//mlir:Support",
1820
"@llvm-project//mlir:TensorDialect",

lib/Analysis/MulResultAnalysis/MulResultAnalysis.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "lib/Analysis/MulResultAnalysis/MulResultAnalysis.h"
22

3-
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
3+
#include <functional>
4+
5+
#include "lib/Analysis/Utils.h"
46
#include "lib/Dialect/Secret/IR/SecretOps.h"
57
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
68
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
@@ -9,7 +11,8 @@
911
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1012
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
1113
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
12-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
14+
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
15+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1316

1417
namespace mlir {
1518
namespace heir {
@@ -74,5 +77,14 @@ LogicalResult MulResultAnalysis::visitOperation(
7477
return success();
7578
}
7679

80+
void MulResultAnalysis::visitExternalCall(
81+
CallOpInterface call, ArrayRef<const MulResultLattice *> argumentLattices,
82+
ArrayRef<MulResultLattice *> resultLattices) {
83+
auto callback = std::bind(&MulResultAnalysis::propagateIfChangedWrapper, this,
84+
std::placeholders::_1, std::placeholders::_2);
85+
::mlir::heir::visitExternalCall<MulResultState, MulResultLattice>(
86+
call, argumentLattices, resultLattices, callback);
87+
}
88+
7789
} // namespace heir
7890
} // namespace mlir

lib/Analysis/MulResultAnalysis/MulResultAnalysis.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ class MulResultAnalysis
8484
LogicalResult visitOperation(Operation *op,
8585
ArrayRef<const MulResultLattice *> operands,
8686
ArrayRef<MulResultLattice *> results) override;
87+
88+
void visitExternalCall(CallOpInterface call,
89+
ArrayRef<const MulResultLattice *> argumentLattices,
90+
ArrayRef<MulResultLattice *> resultLattices) override;
91+
92+
void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
93+
propagateIfChanged(state, changed);
94+
}
8795
};
8896

8997
} // namespace heir

lib/Analysis/SecretnessAnalysis/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ cc_library(
99
srcs = ["SecretnessAnalysis.cpp"],
1010
hdrs = ["SecretnessAnalysis.h"],
1111
deps = [
12+
"@heir//lib/Analysis:Utils",
1213
"@heir//lib/Dialect/Secret/IR:Dialect",
1314
"@llvm-project//mlir:Analysis",
15+
"@llvm-project//mlir:CallOpInterfaces",
1416
"@llvm-project//mlir:FuncDialect",
1517
"@llvm-project//mlir:IR",
1618
"@llvm-project//mlir:Support",

lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
22

33
#include <algorithm>
4+
#include <functional>
45
#include <string>
56

7+
#include "lib/Analysis/Utils.h"
68
#include "lib/Dialect/Secret/IR/SecretDialect.h"
79
#include "lib/Dialect/Secret/IR/SecretOps.h"
810
#include "lib/Dialect/Secret/IR/SecretTypes.h"
@@ -12,6 +14,7 @@
1214
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1315
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
1416
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
17+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
1518
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1619

1720
namespace mlir {
@@ -108,6 +111,15 @@ LogicalResult SecretnessAnalysis::visitOperation(
108111
return mlir::success();
109112
}
110113

114+
void SecretnessAnalysis::visitExternalCall(
115+
CallOpInterface call, ArrayRef<const SecretnessLattice *> argumentLattices,
116+
ArrayRef<SecretnessLattice *> resultLattices) {
117+
auto callback = std::bind(&SecretnessAnalysis::propagateIfChangedWrapper,
118+
this, std::placeholders::_1, std::placeholders::_2);
119+
::mlir::heir::visitExternalCall<Secretness, SecretnessLattice>(
120+
call, argumentLattices, resultLattices, callback);
121+
}
122+
111123
void annotateSecretness(Operation *top, DataFlowSolver *solver, bool verbose) {
112124
// Attribute "Printing" Helper
113125
auto getAttribute =
@@ -132,6 +144,9 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver, bool verbose) {
132144
top->walk([&](Operation *op) {
133145
// Custom Handling for `func.func`, which uses special attributes
134146
if (auto func = llvm::dyn_cast<func::FuncOp>(op)) {
147+
if (func.isDeclaration()) {
148+
return;
149+
}
135150
// Arguments
136151
for (unsigned i = 0; i < func.getNumArguments(); ++i) {
137152
auto arg = func.getArgument(i);

lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Secretness {
2525
assert(isInitialized());
2626
return *secretness;
2727
}
28+
const bool &get() const { return getSecretness(); }
2829
void setSecretness(bool value) { secretness = value; }
2930

3031
// Check if the Secretness state is initialized. It can be uninitialized if
@@ -104,6 +105,14 @@ class SecretnessAnalysis
104105
LogicalResult visitOperation(Operation *operation,
105106
ArrayRef<const SecretnessLattice *> operands,
106107
ArrayRef<SecretnessLattice *> results) override;
108+
109+
void visitExternalCall(CallOpInterface call,
110+
ArrayRef<const SecretnessLattice *> argumentLattices,
111+
ArrayRef<SecretnessLattice *> resultLattices) override;
112+
113+
void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
114+
propagateIfChanged(state, changed);
115+
}
107116
};
108117

109118
/**

lib/Analysis/Utils.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef LIB_ANALYSIS_UTILS_H_
2+
#define LIB_ANALYSIS_UTILS_H_
3+
4+
#include <functional>
5+
6+
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
7+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
8+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
9+
10+
namespace mlir {
11+
namespace heir {
12+
13+
// A generalized version of visitExternalCall that joins all arguments to the
14+
// func.call op, and then propagates this joined value to all results. This is
15+
// useful for debug functions where there is a single operand that is not
16+
// changed by the external call, but can also be useful for some analyses like
17+
// secretness, where the result of the call is secret if any operand is secret.
18+
template <typename StateT, typename LatticeT>
19+
void visitExternalCall(CallOpInterface call,
20+
ArrayRef<const LatticeT *> argumentLattices,
21+
ArrayRef<LatticeT *> resultLattices,
22+
const std::function<void(AnalysisState *, ChangeResult)>
23+
&propagateIfChanged) {
24+
StateT resultState = StateT();
25+
26+
for (const LatticeT *operand : argumentLattices) {
27+
const StateT operandState = operand->getValue();
28+
resultState = StateT::join(resultState, operandState);
29+
}
30+
31+
for (LatticeT *result : resultLattices) {
32+
propagateIfChanged(result, result->join(resultState));
33+
}
34+
}
35+
36+
} // namespace heir
37+
} // namespace mlir
38+
39+
#endif // LIB_ANALYSIS_UTILS_H_

tests/Transforms/annotate_secretness/annotate_secretness.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,13 @@ func.func @return_secretness(%s: i32 {secret.secret}, %p: i32) -> (i32) {
4343
//CHECK-NEXT: return {secret.secret} [[R:%.*]] : i32
4444
return %0 : i32
4545
}
46+
47+
func.func private @callee(!secret.secret<i32>) -> !secret.secret<i32>
48+
// CHECK: @func_call
49+
// CHECK-SAME: ([[S:%.*]]: [[ST:.*]])
50+
func.func @func_call(%s: !secret.secret<i32>) {
51+
// CHECK: call @callee
52+
// CHECK-SAME: {secret.secret}
53+
%1 = func.call @callee(%s) : (!secret.secret<i32>) -> !secret.secret<i32>
54+
func.return
55+
}

0 commit comments

Comments
 (0)