Skip to content

Commit f6472d4

Browse files
committed
Merge pull request #429 from asraa:test-pr
PiperOrigin-RevId: 604980250
2 parents d8874af + ee0dfe6 commit f6472d4

File tree

18 files changed

+772
-68
lines changed

18 files changed

+772
-68
lines changed

.github/actionlint.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
self-hosted-runner:
33
labels:
44
- ubuntu-20.04-16core
5+
- ubuntu-20.04-4core

.github/workflows/build_and_test.yml

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,42 @@ concurrency:
1111
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
1212
cancel-in-progress: true
1313
jobs:
14+
check-cache:
15+
runs-on:
16+
labels: ubuntu-20.04-4core
17+
outputs:
18+
runner: ${{ steps.runner.outputs.runner }}
19+
steps:
20+
- name: Check out repository code
21+
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3
22+
23+
- name: Cache bazel build artifacts
24+
id: cache
25+
uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # [email protected]
26+
with:
27+
path: |
28+
~/.cache/bazel
29+
key: ${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE') }}-${{ hashFiles('bazel/import_llvm.bzl') }}
30+
restore-keys: |
31+
${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE') }}-
32+
lookup-only: true
33+
- name: Select runner
34+
id: runner
35+
env:
36+
CACHE_HIT: ${{ steps.cache.outputs.cache-hit == 'true' }}
37+
run: |
38+
set -euo pipefail
39+
40+
if [[ "${CACHE_HIT}" == "true" ]]; then
41+
echo "runner=ubuntu-20.04-4core" >> "$GITHUB_OUTPUT"
42+
else
43+
echo "runner=ubuntu-20.04-16core" >> "$GITHUB_OUTPUT"
44+
fi
45+
1446
build-and-test:
47+
needs: check-cache
1548
runs-on:
16-
labels: ubuntu-20.04-16core
49+
labels: ${{ needs.check-cache.outputs.runner }}
1750
steps:
1851
- name: Check out repository code
1952
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3

.github/workflows/run_rust_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
set -eux
44
set -o pipefail
55

6-
bazel query "filter('.mlir.test$', //tests/tfhe_rust/end_to_end/...)" | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
6+
bazel query "filter('.mlir.test$', //tests/tfhe_rust/end_to_end/...)" | xargs bazel test -c fastbuild --sandbox_writable_path=$HOME/.cargo "$@"

bazel/import_llvm.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ load(
77

88
def import_llvm(name):
99
"""Imports LLVM."""
10-
LLVM_COMMIT = "c166a43c6e6157b1309ea757324cc0a71c078e66"
10+
LLVM_COMMIT = "88c830a1a5687bec597ca947159e4dd9a3f2ac2d"
1111

1212
new_git_repository(
1313
name = name,

include/Dialect/Secret/IR/SecretOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
66
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
77
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
8+
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
89
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
910
#include "mlir/include/mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
1011
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
@@ -37,6 +38,8 @@ namespace secret {
3738
std::pair<GenericOp, GenericOp> extractOpAfterGeneric(
3839
GenericOp genericOp, Operation *opToExtract, PatternRewriter &rewriter);
3940

41+
void populateGenericCanonicalizers(RewritePatternSet &patterns,
42+
MLIRContext *ctx);
4043
} // namespace secret
4144
} // namespace heir
4245
} // namespace mlir

include/Dialect/Secret/IR/SecretPatterns.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,20 @@ struct CaptureAmbientScope : public OpRewritePattern<GenericOp> {
128128
};
129129

130130
// Find two adjacent generic ops and merge them into one.
131+
// Accepts a parent op to apply this pattern only to generics descending from
132+
// that op.
131133
struct MergeAdjacentGenerics : public OpRewritePattern<GenericOp> {
132-
MergeAdjacentGenerics(mlir::MLIRContext *context)
133-
: OpRewritePattern<GenericOp>(context, /*benefit=*/1) {}
134+
MergeAdjacentGenerics(mlir::MLIRContext *context,
135+
std::optional<Operation *> parentOp = std::nullopt)
136+
: OpRewritePattern<GenericOp>(context, /*benefit=*/1),
137+
parentOp(parentOp) {}
134138

135139
public:
136140
LogicalResult matchAndRewrite(GenericOp op,
137141
PatternRewriter &rewriter) const override;
142+
143+
private:
144+
std::optional<Operation *> parentOp;
138145
};
139146

140147
// Find a memeref that is stored to in the body of the generic, but not
@@ -201,6 +208,17 @@ struct HoistOpAfterGeneric : public OpRewritePattern<GenericOp> {
201208
std::vector<std::string> opTypes;
202209
};
203210

211+
// Identify the earliest op inside a generic that relies only on plaintext
212+
// operands, and hoist it out of the generic.
213+
struct HoistPlaintextOps : public OpRewritePattern<GenericOp> {
214+
HoistPlaintextOps(mlir::MLIRContext *context)
215+
: OpRewritePattern<GenericOp>(context, /*benefit=*/1) {}
216+
217+
public:
218+
LogicalResult matchAndRewrite(GenericOp op,
219+
PatternRewriter &rewriter) const override;
220+
};
221+
204222
} // namespace secret
205223
} // namespace heir
206224
} // namespace mlir

include/Target/Verilog/VerilogEmitter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class VerilogEmitter {
6262

6363
/// Map from a Value to name of Verilog variable that is bound to the value.
6464
llvm::DenseMap<Value, std::string> value_to_wire_name_;
65+
llvm::SmallVector<std::string> output_wire_names_;
6566

6667
// Globally unique identifiers for values
6768
int64_t value_count_;
@@ -116,6 +117,7 @@ class VerilogEmitter {
116117
LogicalResult emitType(Type type);
117118
LogicalResult emitType(Type type, raw_ostream &os);
118119
LogicalResult emitIndexType(Value indexValue, raw_ostream &os);
120+
119121
// Emit a Verilog array shape specifier of the form `[width]`
120122
LogicalResult emitArrayShapeSuffix(Type type);
121123

@@ -130,6 +132,8 @@ class VerilogEmitter {
130132
StringRef getOrCreateName(BlockArgument arg);
131133
StringRef getOrCreateName(Value value);
132134
StringRef getOrCreateName(Value value, std::string_view prefix);
135+
StringRef getOrCreateOutputWireName(int resultIndex);
136+
StringRef getOutputWireName(int resultIndex);
133137
StringRef getName(Value value);
134138
};
135139

include/Transforms/YosysOptimizer/YosysOptimizer.h

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ namespace mlir {
77
namespace heir {
88

99
std::unique_ptr<mlir::Pass> createYosysOptimizer(
10-
const std::string &yosysFilesPath, const std::string &abcPath,
11-
bool abcFast);
10+
const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast,
11+
int unrollFactor = 0);
1212

1313
#define GEN_PASS_DECL
1414
#include "include/Transforms/YosysOptimizer/YosysOptimizer.h.inc"
@@ -18,6 +18,19 @@ struct YosysOptimizerPipelineOptions
1818
PassOptions::Option<bool> abcFast{*this, "abc-fast",
1919
llvm::cl::desc("Run abc in fast mode."),
2020
llvm::cl::init(false)};
21+
22+
PassOptions::Option<int> unrollFactor{
23+
*this, "unroll-factor",
24+
llvm::cl::desc("Unroll loops by a given factor before optimizing. A "
25+
"value of zero (default) prevents unrolling."),
26+
llvm::cl::init(0)};
27+
};
28+
29+
struct UnrollAndOptimizePipelineOptions
30+
: public PassPipelineOptions<UnrollAndOptimizePipelineOptions> {
31+
PassOptions::Option<bool> abcFast{*this, "abc-fast",
32+
llvm::cl::desc("Run abc in fast mode."),
33+
llvm::cl::init(false)};
2134
};
2235

2336
// registerYosysOptimizerPipeline registers a Yosys pipeline pass using
@@ -26,6 +39,12 @@ struct YosysOptimizerPipelineOptions
2639
void registerYosysOptimizerPipeline(const std::string &yosysFilesPath,
2740
const std::string &abcPath);
2841

42+
// Registers a pipeline that interleaves yosys-optimizer and loop unrolling and
43+
// prints statistics about the optimized circuits. Intended for offline analysis
44+
// to determine the best loop-unrolling factor.
45+
void registerUnrollAndOptimizeAnalysisPipeline(
46+
const std::string &yosysFilesPath, const std::string &abcPath);
47+
2948
} // namespace heir
3049
} // namespace mlir
3150

include/Transforms/YosysOptimizer/YosysOptimizer.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,19 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
1313
Note that booleanization changes the function signature: multi-bit integers
1414
are transformed to a tensor of booleans, for example, an `i8` is converted
1515
to `tensor<8xi1>`.
16+
17+
The optimizer will be applied to each `secret.generic` op containing
18+
arithmetic ops that can be optimized.
19+
20+
Optional parameters:
21+
22+
- `abc-fast`: Run the abc optimizer in "fast" mode, getting faster compile
23+
time at the expense of a possibly larger output circuit.
24+
- `unroll-factor`: Before optimizing the circuit, unroll loops by a given
25+
factor. If unset, this pass will not unroll any loops.
1626
}];
27+
// TODO(https://github.com/google/heir/issues/257): add option for the pass to select
28+
// the unroll factor automatically.
1729

1830
let dependentDialects = [
1931
"mlir::arith::ArithDialect",
@@ -23,4 +35,22 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
2335
];
2436
}
2537

38+
def UnrollAndOptimizeAnalysis : Pass<"unroll-and-optimize-analysis"> {
39+
let summary = "Iteratively unroll and optimize an IR, printing optimization stats.";
40+
41+
let description = [{
42+
This pass invokes the `--yosys-optimizer` pass while iteratively applying
43+
loop-unrolling. Along the way, it prints statistics about the optimized
44+
circuits, which can be used to determine an optimal loop-unrolling factor
45+
for a given program.
46+
}];
47+
48+
let dependentDialects = [
49+
"mlir::arith::ArithDialect",
50+
"mlir::heir::comb::CombDialect",
51+
"mlir::heir::secret::SecretDialect",
52+
"mlir::tensor::TensorDialect"
53+
];
54+
}
55+
2656
#endif // INCLUDE_TRANSFORMS_YOSYSOPTIMIZER_YOSYSOPTIMIZER_TD_

lib/Dialect/Secret/IR/SecretOps.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ YieldOp GenericOp::getYieldOp() {
275275
return *getBody()->getOps<YieldOp>().begin();
276276
}
277277

278-
GenericOp cloneWithNewTypes(GenericOp op, TypeRange newTypes,
279-
PatternRewriter &rewriter) {
278+
GenericOp cloneWithNewResultTypes(GenericOp op, TypeRange newTypes,
279+
PatternRewriter &rewriter) {
280280
return rewriter.create<GenericOp>(
281281
op.getLoc(), op.getOperands(), newTypes,
282282
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
@@ -299,7 +299,7 @@ std::pair<GenericOp, ValueRange> GenericOp::addNewYieldedValues(
299299
SecretType newTy = secret::SecretType::get(t);
300300
return newTy;
301301
}));
302-
GenericOp newOp = cloneWithNewTypes(*this, newTypes, rewriter);
302+
GenericOp newOp = cloneWithNewResultTypes(*this, newTypes, rewriter);
303303

304304
auto newResultStartIter = newOp.getResults().drop_front(
305305
newOp.getNumResults() - newValuesToYield.size());
@@ -338,7 +338,7 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
338338
return newTy;
339339
}));
340340

341-
return cloneWithNewTypes(*this, newResultTypes, rewriter);
341+
return cloneWithNewResultTypes(*this, newResultTypes, rewriter);
342342
}
343343

344344
GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
@@ -369,12 +369,16 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
369369
return newTy;
370370
}));
371371

372-
return cloneWithNewTypes(*this, newResultTypes, rewriter);
372+
return cloneWithNewResultTypes(*this, newResultTypes, rewriter);
373373
}
374374

375375
GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
376376
PatternRewriter &rewriter) {
377377
assert(opToExtract->getParentOp() == *this);
378+
LLVM_DEBUG({
379+
llvm::dbgs() << "Extracting:\n";
380+
opToExtract->dump();
381+
});
378382

379383
// Result types are secret versions of the results of the op, since the
380384
// secret will yield all of this op's results immediately.
@@ -394,6 +398,10 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
394398
auto *newOp = b.clone(*opToExtract, mp);
395399
b.create<YieldOp>(loc, newOp->getResults());
396400
});
401+
LLVM_DEBUG({
402+
llvm::dbgs() << "After adding new single-op generic:\n";
403+
newGeneric->getParentOp()->dump();
404+
});
397405

398406
// Once the op is split off into a new generic op, we need to add new
399407
// operands to the old generic op, add new corresponding block arguments, and
@@ -412,6 +420,13 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
412420
return newGeneric;
413421
}
414422

423+
void populateGenericCanonicalizers(RewritePatternSet &patterns,
424+
MLIRContext *ctx) {
425+
patterns.add<CollapseSecretlessGeneric, RemoveUnusedYieldedValues,
426+
RemoveUnusedGenericArgs, RemoveNonSecretGenericArgs,
427+
HoistPlaintextOps>(ctx);
428+
}
429+
415430
// When replacing a generic op with a new one, and given an op in the original
416431
// generic op, find the corresponding op in the new generic op.
417432
//
@@ -573,8 +588,7 @@ void GenericOp::inlineInPlaceDroppingSecrets(PatternRewriter &rewriter,
573588

574589
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
575590
MLIRContext *context) {
576-
results.add<CollapseSecretlessGeneric, RemoveUnusedYieldedValues,
577-
RemoveUnusedGenericArgs, RemoveNonSecretGenericArgs>(context);
591+
populateGenericCanonicalizers(results, context);
578592
}
579593

580594
} // namespace secret

0 commit comments

Comments
 (0)