Skip to content

Commit

Permalink
Merge pull request #429 from asraa:test-pr
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604980250
  • Loading branch information
asraa committed Feb 7, 2024
2 parents d8874af + ee0dfe6 commit f6472d4
Show file tree
Hide file tree
Showing 18 changed files with 772 additions and 68 deletions.
1 change: 1 addition & 0 deletions .github/actionlint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
self-hosted-runner:
labels:
- ubuntu-20.04-16core
- ubuntu-20.04-4core
35 changes: 34 additions & 1 deletion .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,42 @@ concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
check-cache:
runs-on:
labels: ubuntu-20.04-4core
outputs:
runner: ${{ steps.runner.outputs.runner }}
steps:
- name: Check out repository code
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3

- name: Cache bazel build artifacts
id: cache
uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # [email protected]
with:
path: |
~/.cache/bazel
key: ${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE') }}-${{ hashFiles('bazel/import_llvm.bzl') }}
restore-keys: |
${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE') }}-
lookup-only: true
- name: Select runner
id: runner
env:
CACHE_HIT: ${{ steps.cache.outputs.cache-hit == 'true' }}
run: |
set -euo pipefail
if [[ "${CACHE_HIT}" == "true" ]]; then
echo "runner=ubuntu-20.04-4core" >> "$GITHUB_OUTPUT"
else
echo "runner=ubuntu-20.04-16core" >> "$GITHUB_OUTPUT"
fi
build-and-test:
needs: check-cache
runs-on:
labels: ubuntu-20.04-16core
labels: ${{ needs.check-cache.outputs.runner }}
steps:
- name: Check out repository code
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_rust_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
set -eux
set -o pipefail

bazel query "filter('.mlir.test$', //tests/tfhe_rust/end_to_end/...)" | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
bazel query "filter('.mlir.test$', //tests/tfhe_rust/end_to_end/...)" | xargs bazel test -c fastbuild --sandbox_writable_path=$HOME/.cargo "$@"
2 changes: 1 addition & 1 deletion bazel/import_llvm.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ load(

def import_llvm(name):
"""Imports LLVM."""
LLVM_COMMIT = "c166a43c6e6157b1309ea757324cc0a71c078e66"
LLVM_COMMIT = "88c830a1a5687bec597ca947159e4dd9a3f2ac2d"

new_git_repository(
name = name,
Expand Down
3 changes: 3 additions & 0 deletions include/Dialect/Secret/IR/SecretOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
Expand Down Expand Up @@ -37,6 +38,8 @@ namespace secret {
std::pair<GenericOp, GenericOp> extractOpAfterGeneric(
GenericOp genericOp, Operation *opToExtract, PatternRewriter &rewriter);

void populateGenericCanonicalizers(RewritePatternSet &patterns,
MLIRContext *ctx);
} // namespace secret
} // namespace heir
} // namespace mlir
Expand Down
22 changes: 20 additions & 2 deletions include/Dialect/Secret/IR/SecretPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,20 @@ struct CaptureAmbientScope : public OpRewritePattern<GenericOp> {
};

// Find two adjacent generic ops and merge them into one.
// Accepts a parent op to apply this pattern only to generics descending from
// that op.
struct MergeAdjacentGenerics : public OpRewritePattern<GenericOp> {
MergeAdjacentGenerics(mlir::MLIRContext *context)
: OpRewritePattern<GenericOp>(context, /*benefit=*/1) {}
MergeAdjacentGenerics(mlir::MLIRContext *context,
std::optional<Operation *> parentOp = std::nullopt)
: OpRewritePattern<GenericOp>(context, /*benefit=*/1),
parentOp(parentOp) {}

public:
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override;

private:
std::optional<Operation *> parentOp;
};

// Find a memeref that is stored to in the body of the generic, but not
Expand Down Expand Up @@ -201,6 +208,17 @@ struct HoistOpAfterGeneric : public OpRewritePattern<GenericOp> {
std::vector<std::string> opTypes;
};

// Identify the earliest op inside a generic that relies only on plaintext
// operands, and hoist it out of the generic.
struct HoistPlaintextOps : public OpRewritePattern<GenericOp> {
HoistPlaintextOps(mlir::MLIRContext *context)
: OpRewritePattern<GenericOp>(context, /*benefit=*/1) {}

public:
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override;
};

} // namespace secret
} // namespace heir
} // namespace mlir
Expand Down
4 changes: 4 additions & 0 deletions include/Target/Verilog/VerilogEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class VerilogEmitter {

/// Map from a Value to name of Verilog variable that is bound to the value.
llvm::DenseMap<Value, std::string> value_to_wire_name_;
llvm::SmallVector<std::string> output_wire_names_;

// Globally unique identifiers for values
int64_t value_count_;
Expand Down Expand Up @@ -116,6 +117,7 @@ class VerilogEmitter {
LogicalResult emitType(Type type);
LogicalResult emitType(Type type, raw_ostream &os);
LogicalResult emitIndexType(Value indexValue, raw_ostream &os);

// Emit a Verilog array shape specifier of the form `[width]`
LogicalResult emitArrayShapeSuffix(Type type);

Expand All @@ -130,6 +132,8 @@ class VerilogEmitter {
StringRef getOrCreateName(BlockArgument arg);
StringRef getOrCreateName(Value value);
StringRef getOrCreateName(Value value, std::string_view prefix);
StringRef getOrCreateOutputWireName(int resultIndex);
StringRef getOutputWireName(int resultIndex);
StringRef getName(Value value);
};

Expand Down
23 changes: 21 additions & 2 deletions include/Transforms/YosysOptimizer/YosysOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ namespace mlir {
namespace heir {

std::unique_ptr<mlir::Pass> createYosysOptimizer(
const std::string &yosysFilesPath, const std::string &abcPath,
bool abcFast);
const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast,
int unrollFactor = 0);

#define GEN_PASS_DECL
#include "include/Transforms/YosysOptimizer/YosysOptimizer.h.inc"
Expand All @@ -18,6 +18,19 @@ struct YosysOptimizerPipelineOptions
PassOptions::Option<bool> abcFast{*this, "abc-fast",
llvm::cl::desc("Run abc in fast mode."),
llvm::cl::init(false)};

PassOptions::Option<int> unrollFactor{
*this, "unroll-factor",
llvm::cl::desc("Unroll loops by a given factor before optimizing. A "
"value of zero (default) prevents unrolling."),
llvm::cl::init(0)};
};

struct UnrollAndOptimizePipelineOptions
: public PassPipelineOptions<UnrollAndOptimizePipelineOptions> {
PassOptions::Option<bool> abcFast{*this, "abc-fast",
llvm::cl::desc("Run abc in fast mode."),
llvm::cl::init(false)};
};

// registerYosysOptimizerPipeline registers a Yosys pipeline pass using
Expand All @@ -26,6 +39,12 @@ struct YosysOptimizerPipelineOptions
void registerYosysOptimizerPipeline(const std::string &yosysFilesPath,
const std::string &abcPath);

// Registers a pipeline that interleaves yosys-optimizer and loop unrolling and
// prints statistics about the optimized circuits. Intended for offline analysis
// to determine the best loop-unrolling factor.
void registerUnrollAndOptimizeAnalysisPipeline(
const std::string &yosysFilesPath, const std::string &abcPath);

} // namespace heir
} // namespace mlir

Expand Down
30 changes: 30 additions & 0 deletions include/Transforms/YosysOptimizer/YosysOptimizer.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,19 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
Note that booleanization changes the function signature: multi-bit integers
are transformed to a tensor of booleans, for example, an `i8` is converted
to `tensor<8xi1>`.

The optimizer will be applied to each `secret.generic` op containing
arithmetic ops that can be optimized.

Optional parameters:

- `abc-fast`: Run the abc optimizer in "fast" mode, getting faster compile
time at the expense of a possibly larger output circuit.
- `unroll-factor`: Before optimizing the circuit, unroll loops by a given
factor. If unset, this pass will not unroll any loops.
}];
// TODO(https://github.com/google/heir/issues/257): add option for the pass to select
// the unroll factor automatically.

let dependentDialects = [
"mlir::arith::ArithDialect",
Expand All @@ -23,4 +35,22 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
];
}

def UnrollAndOptimizeAnalysis : Pass<"unroll-and-optimize-analysis"> {
let summary = "Iteratively unroll and optimize an IR, printing optimization stats.";

let description = [{
This pass invokes the `--yosys-optimizer` pass while iteratively applying
loop-unrolling. Along the way, it prints statistics about the optimized
circuits, which can be used to determine an optimal loop-unrolling factor
for a given program.
}];

let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::comb::CombDialect",
"mlir::heir::secret::SecretDialect",
"mlir::tensor::TensorDialect"
];
}

#endif // INCLUDE_TRANSFORMS_YOSYSOPTIMIZER_YOSYSOPTIMIZER_TD_
28 changes: 21 additions & 7 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ YieldOp GenericOp::getYieldOp() {
return *getBody()->getOps<YieldOp>().begin();
}

GenericOp cloneWithNewTypes(GenericOp op, TypeRange newTypes,
PatternRewriter &rewriter) {
GenericOp cloneWithNewResultTypes(GenericOp op, TypeRange newTypes,
PatternRewriter &rewriter) {
return rewriter.create<GenericOp>(
op.getLoc(), op.getOperands(), newTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
Expand All @@ -299,7 +299,7 @@ std::pair<GenericOp, ValueRange> GenericOp::addNewYieldedValues(
SecretType newTy = secret::SecretType::get(t);
return newTy;
}));
GenericOp newOp = cloneWithNewTypes(*this, newTypes, rewriter);
GenericOp newOp = cloneWithNewResultTypes(*this, newTypes, rewriter);

auto newResultStartIter = newOp.getResults().drop_front(
newOp.getNumResults() - newValuesToYield.size());
Expand Down Expand Up @@ -338,7 +338,7 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
return newTy;
}));

return cloneWithNewTypes(*this, newResultTypes, rewriter);
return cloneWithNewResultTypes(*this, newResultTypes, rewriter);
}

GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
Expand Down Expand Up @@ -369,12 +369,16 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
return newTy;
}));

return cloneWithNewTypes(*this, newResultTypes, rewriter);
return cloneWithNewResultTypes(*this, newResultTypes, rewriter);
}

GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
PatternRewriter &rewriter) {
assert(opToExtract->getParentOp() == *this);
LLVM_DEBUG({
llvm::dbgs() << "Extracting:\n";
opToExtract->dump();
});

// Result types are secret versions of the results of the op, since the
// secret will yield all of this op's results immediately.
Expand All @@ -394,6 +398,10 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
});
LLVM_DEBUG({
llvm::dbgs() << "After adding new single-op generic:\n";
newGeneric->getParentOp()->dump();
});

// Once the op is split off into a new generic op, we need to add new
// operands to the old generic op, add new corresponding block arguments, and
Expand All @@ -412,6 +420,13 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
return newGeneric;
}

void populateGenericCanonicalizers(RewritePatternSet &patterns,
MLIRContext *ctx) {
patterns.add<CollapseSecretlessGeneric, RemoveUnusedYieldedValues,
RemoveUnusedGenericArgs, RemoveNonSecretGenericArgs,
HoistPlaintextOps>(ctx);
}

// When replacing a generic op with a new one, and given an op in the original
// generic op, find the corresponding op in the new generic op.
//
Expand Down Expand Up @@ -573,8 +588,7 @@ void GenericOp::inlineInPlaceDroppingSecrets(PatternRewriter &rewriter,

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseSecretlessGeneric, RemoveUnusedYieldedValues,
RemoveUnusedGenericArgs, RemoveNonSecretGenericArgs>(context);
populateGenericCanonicalizers(results, context);
}

} // namespace secret
Expand Down
Loading

0 comments on commit f6472d4

Please sign in to comment.