diff --git a/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp b/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp
index 58b0386e11..a246cfbeff 100644
--- a/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp
+++ b/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp
@@ -17,10 +17,14 @@
 #include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
 #include "llvm/include/llvm/Support/raw_ostream.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/ValueRange.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/include/mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -98,11 +102,40 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
   // of the ciphertext at that point in the computation, as well as the decision
   // variable to track whether to insert a relinearization operation after the
   // operation.
-  opToRunOn->walk([&](Operation* op) {
+  opToRunOn->walk([&](Operation* op) -> WalkResult {
+    // Skip inner loop bodies; they are handled by their own ILP solves in
+    // bottom-up order. However, we still need to create variables for the
+    // inner loop's results so the outer solver knows about the inner loop.
+    if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
+      std::string name = uniqueName(op);
+      if (isSecret(op->getResults(), solver)) {
+        auto decisionVar = model.AddBinaryVariable("InsertRelin_" + name);
+        decisionVariables.insert(std::make_pair(op, decisionVar));
+      }
+
+      for (OpResult opResult : op->getOpResults()) {
+        Value result = opResult;
+        if (!isSecret(result, solver)) {
+          continue;
+        }
+        std::string varName =
+            "Degree_" + name + "_" + std::to_string(opResult.getResultNumber());
+        auto keyBasisVar =
+            model.AddContinuousVariable(0, MAX_KEY_BASIS_DEGREE, varName);
+        keyBasisVars.insert(std::make_pair(result, keyBasisVar));
+
+        std::string brVarName = varName + "_br";
+        auto brKeyBasisVar =
+            model.AddContinuousVariable(0, MAX_KEY_BASIS_DEGREE, brVarName);
+        beforeRelinVars.insert(std::make_pair(result, brKeyBasisVar));
+      }
+      return WalkResult::skip();
+    }
+
     std::string name = uniqueName(op);
     if (isa<ModuleOp>(op)) {
-      return;
+      return WalkResult::advance();
     }
 
     // skip secret generic op; we decide inside generic op block
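
For reference, here is the variable layout this first walk creates, in symbols of my own choosing (not names from the code): each op o with secret results gets one binary decision variable r_o (the "InsertRelin" variable), and each secret result v gets two continuous degree variables, one before and one after the relinearization decision is applied:

```latex
r_o \in \{0, 1\}, \qquad
d_v^{\mathrm{br}},\; d_v \in [0,\; D_{\max}], \qquad
D_{\max} = \mathrm{MAX\_KEY\_BASIS\_DEGREE}.
```

The DecisionDynamics constraints added later tie these together so that d_v = 1 when r_o = 1 and d_v = d_v^br otherwise. An inner loop gets the same result variables as any other op, but its body is skipped here because it is solved by its own ILP first.
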
@@ -138,7 +171,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
     // linearized, though this could be generalized to read the degree from the
     // type.
     if (op->getNumRegions() == 0) {
-      return;
+      return WalkResult::advance();
     }
 
     LLVM_DEBUG(llvm::dbgs()
@@ -146,8 +179,22 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
     for (Region& region : op->getRegions()) {
       for (Block& block : region.getBlocks()) {
         for (BlockArgument arg : block.getArguments()) {
-          if (!isSecret(arg, solver)) {
-            continue;
+          bool argIsSecret = isSecret(arg, solver);
+
+          // Handle iter_args that only become secret via the value yielded
+          // back to them.
+          if (!argIsSecret) {
+            if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op)) {
+              auto iterArgs = loopOp.getRegionIterArgs();
+              auto it = llvm::find(iterArgs, arg);
+              if (it != iterArgs.end()) {
+                unsigned idx = std::distance(iterArgs.begin(), it);
+                auto yieldedValues = loopOp.getYieldedValues();
+                if (idx < yieldedValues.size()) {
+                  argIsSecret = isSecret(yieldedValues[idx], solver);
+                }
+              }
+            }
+            if (!argIsSecret) continue;
           }
 
           std::stringstream ss;
@@ -159,14 +206,22 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
         }
       }
     }
+    return WalkResult::advance();
   });
 
   // Constraints to initialize the key basis degree variables at the start of
   // the computation.
   for (auto& [value, var] : keyBasisVars) {
     if (llvm::isa<BlockArgument>(value)) {
-      // If the dimension is 3, the key basis is [0, 1, 2] and the degree is 2.
-      auto constrainedDegree = getDimension(value, solver).value_or(2) - 1;
+      auto blockArg = llvm::cast<BlockArgument>(value);
+      int constrainedDegree;
+      // Loop iter_args are always assumed to have degree 1, since
+      // getDimension diverges on loop-carried values.
+      if (isa<LoopLikeOpInterface>(blockArg.getOwner()->getParentOp())) {
+        constrainedDegree = 1;
+      } else {
+        // If the dimension is 3, the key basis is [0, 1, 2] and the degree is 2.
+        constrainedDegree = getDimension(value, solver).value_or(2) - 1;
+      }
       model.AddLinearConstraint(var == constrainedDegree, "");
     }
   }
@@ -179,15 +234,20 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
   // through from the input unchanged. If we don't require this, the output
   // of the addition must be a max over the input degrees.
   if (!allowMixedDegreeOperands) {
-    opToRunOn->walk([&](Operation* op) {
+    opToRunOn->walk([&](Operation* op) -> WalkResult {
+      // Skip loop bodies; they are constrained by their own bottom-up solves.
+      if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
+        return WalkResult::skip();
+      }
+
       if (op->getNumOperands() <= 1) {
-        return;
+        return WalkResult::advance();
       }
 
       // secret generic op arguments are not constrained
       // instead their block arguments are constrained
       if (isa<secret::GenericOp>(op)) {
-        return;
+        return WalkResult::advance();
       }
 
       std::string name = uniqueName(op);
@@ -196,7 +256,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
       SmallVector<OpOperand*> secretOperands;
       getSecretOperands(op, secretOperands, solver);
       if (secretOperands.size() <= 1) {
-        return;
+        return WalkResult::advance();
       }
 
       auto anchorVar = keyBasisVars.at(secretOperands[0]->get());
@@ -215,13 +275,17 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
             << name;
         model.AddLinearConstraint(operandDegreeVar == anchorVar, ss.str());
       }
+      return WalkResult::advance();
     });
   }
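
A quick worked example of the degree bookkeeping these initializations rely on (standard BGV/CKKS facts, not specific to this change): the key basis degree is the ciphertext dimension minus one, and ct-ct multiplication adds degrees:

```latex
\deg(v) = \dim(v) - 1, \qquad \deg(a \cdot b) = \deg(a) + \deg(b).
```

A fresh dimension-2 ciphertext has key basis (1, s) and degree 1; a dimension-3 ciphertext has key basis (1, s, s^2) and degree 2, which is what the getDimension(...) - 1 computation above encodes (the value_or(2) fallback assumes a linear ciphertext). Pinning loop iter_args to degree 1 asserts that every iteration starts from a linear ciphertext; the loop-carried constraint added below is what makes that assertion sound.
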
 
   // Some ops require a linear key basis. Yield is a special case
   // where we require returned values from funcs to be linearized.
   // TODO(#1398): determine whether we need linear key basis for modreduce.
-  opToRunOn->walk([&](Operation* op) {
+  opToRunOn->walk([&](Operation* op) -> WalkResult {
+    if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
+      return WalkResult::skip();
+    }
     llvm::TypeSwitch<Operation&>(*op)
         .Case(
             [&](auto op) {
@@ -244,7 +308,35 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
                 << operand.getOperandNumber();
             model.AddLinearConstraint(operandDegreeVar == 1, ss.str());
           }
-        });
+        })
+        .Case<affine::AffineYieldOp, scf::YieldOp>([&](auto op) {
+          // For loop yield ops, the degree returned must not exceed the
+          // degree of the corresponding iter_arg block argument at the start
+          // of the loop. This prevents unbounded growth across loop
+          // iterations.
+          auto parentLoop = op->getParentOp();
+          auto loopLike = dyn_cast<LoopLikeOpInterface>(parentLoop);
+          if (!loopLike) return;
+
+          auto iterArgs = loopLike.getRegionIterArgs();
+
+          // The number of iter_args should match the number of yielded
+          // operands.
+          if (iterArgs.size() != op.getNumOperands()) return;
+
+          for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+            auto yieldOperand = op.getOperand(i);
+            if (!isSecret(yieldOperand, solver)) continue;
+            if (!keyBasisVars.contains(yieldOperand)) continue;
+
+            auto yieldDegreeVar = keyBasisVars.at(yieldOperand);
+            auto iterArgDegreeVar = keyBasisVars.at(iterArgs[i]);
+
+            model.AddLinearConstraint(
+                yieldDegreeVar <= iterArgDegreeVar,
+                "LoopCarriedDependency_" + std::to_string(i) + "_" +
+                    uniqueName(op));
+          }
+        });
+    return WalkResult::advance();
   });
 
   // When mixed-degree ops are enabled, the default result degree of an op is
@@ -254,7 +346,27 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
   std::unordered_set extraVarsForObjective;
 
   // Add constraints that set the before_relin variables appropriately
-  opToRunOn->walk([&](Operation* op) {
+  opToRunOn->walk([&](Operation* op) -> WalkResult {
+    // For nested inner loops, apply the LoopOutputDegree constraint using
+    // the degree solved by the inner loop's ILP, then skip the loop body.
+    if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
+      for (auto [idx, result] : llvm::enumerate(op->getResults())) {
+        if (!isSecret(result, solver)) continue;
+        auto resultBeforeRelinVar = beforeRelinVars.at(result);
+        // Use the degree solved by the inner loop's ILP. Default to 1 if it
+        // is not yet populated (a loop containing no secret ops).
+        int solvedDegree = 1;
+        auto it = loopBoundaryDegrees.find(op);
+        if (it != loopBoundaryDegrees.end() && idx < it->second.size()) {
+          solvedDegree = it->second[idx];
+        }
+        model.AddLinearConstraint(
+            resultBeforeRelinVar == solvedDegree,
+            "LoopOutputDegree_" + uniqueName(op) + "_" + std::to_string(idx));
+      }
+      return WalkResult::skip();
+    }
+
     llvm::TypeSwitch<Operation&>(*op)
         .Case<arith::MulIOp>([&](auto op) {
           // if plain mul, skip
@@ -299,7 +411,8 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
           }
         })
         .Default([&](Operation& op) {
+          // Loop yields are constrained separately above; skip them here.
+          if (isa<affine::AffineYieldOp, scf::YieldOp>(op)) return;
+
           // For any other op, the key basis does not change unless we insert
           // a relin op. The operands may have the same basis degree, if that
           // is required by the backend and allowMixedDegreeOperands is false,
           // in which case we can just forward the degree of the first secret
@@ -360,7 +473,8 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
             }
           }
         });
-  });
+    return WalkResult::advance();
+  });
 
   // The objective is to minimize the number of relinearization ops.
   // TODO(#1018): improve the objective function to account for differing costs
@@ -373,7 +487,48 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
   model.Minimize(obj);
 
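
To see how the LoopCarriedDependency constraint composes with the iter_arg initialization, take the accumulator loop from the new test below: the iter_arg acc is pinned to degree 1, squaring it produces a before-relin degree of 2, and the yielded value may not exceed the iter_arg's degree:

```latex
d_{\mathrm{acc}} = 1, \qquad
d^{\mathrm{br}}_{\mathrm{mul}} = d_{\mathrm{acc}} + d_{\mathrm{acc}} = 2, \qquad
d_{\mathrm{yield}} \le d_{\mathrm{acc}} = 1.
```

Since the yielded value carries the mul result's post-decision degree, the only feasible assignment sets the mul's relin decision to 1: the solver is forced to relinearize inside the loop body, which is exactly what the loop_accumulator test checks.
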
   // Add constraints that control the effect of relinearization insertion.
-  opToRunOn->walk([&](Operation* op) {
+  opToRunOn->walk([&](Operation* op) -> WalkResult {
+    // Helper to add the DecisionDynamics constraints for an op's results.
+    auto addDecisionDynamics = [&](Operation* targetOp) {
+      if (!isSecret(targetOp->getResults(), solver)) return;
+      for (OpResult opResult : targetOp->getResults()) {
+        Value result = opResult;
+        if (!isSecret(result, solver)) continue;
+
+        auto resultBeforeRelinVar = beforeRelinVars.at(result);
+        auto resultAfterRelinVar = keyBasisVars.at(result);
+        auto insertRelinOpDecision = decisionVariables.at(targetOp);
+
+        std::string opName = uniqueName(targetOp);
+        std::string ddPrefix = "DecisionDynamics_" + opName + "_" +
+                               std::to_string(opResult.getResultNumber());
+
+        model.AddLinearConstraint(resultAfterRelinVar >= insertRelinOpDecision,
+                                  ddPrefix + "_1");
+
+        model.AddLinearConstraint(
+            resultAfterRelinVar <=
+                1 + IF_THEN_AUX * (1 - insertRelinOpDecision),
+            ddPrefix + "_2");
+
+        model.AddLinearConstraint(
+            resultAfterRelinVar >=
+                resultBeforeRelinVar - IF_THEN_AUX * insertRelinOpDecision,
+            ddPrefix + "_3");
+
+        model.AddLinearConstraint(
+            resultAfterRelinVar <=
+                resultBeforeRelinVar + IF_THEN_AUX * insertRelinOpDecision,
+            ddPrefix + "_4");
+      }
+    };
+
+    // For nested inner loops, apply the DecisionDynamics constraints to the
+    // loop's results, then skip the loop body.
+    if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
+      addDecisionDynamics(op);
+      return WalkResult::skip();
+    }
+
     // We don't need a type switch here because the only difference
     // between mul and other ops is how the before_relin variable is related to
     // the operand variables.
@@ -386,42 +541,11 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
     // secret generic op arguments are not constrained
     // instead their block arguments are constrained
     if (isa<secret::GenericOp>(op)) {
-      return;
-    }
-    if (!isSecret(op->getResults(), solver)) {
-      return;
+      return WalkResult::advance();
     }
 
-    for (OpResult opResult : op->getResults()) {
-      Value result = opResult;
-      auto resultBeforeRelinVar = beforeRelinVars.at(result);
-      auto resultAfterRelinVar = keyBasisVars.at(result);
-      auto insertRelinOpDecision = decisionVariables.at(op);
-      std::string opName = uniqueName(op);
-      std::string ddPrefix = "DecisionDynamics_" + opName + "_" +
-                             std::to_string(opResult.getResultNumber());
-
-      cstName = ddPrefix + "_1";
-      model.AddLinearConstraint(resultAfterRelinVar >= insertRelinOpDecision,
-                                cstName);
-
-      cstName = ddPrefix + "_2";
-      model.AddLinearConstraint(
-          resultAfterRelinVar <= 1 + IF_THEN_AUX * (1 - insertRelinOpDecision),
-          cstName);
-
-      cstName = ddPrefix + "_3";
-      model.AddLinearConstraint(
-          resultAfterRelinVar >=
-              resultBeforeRelinVar - IF_THEN_AUX * insertRelinOpDecision,
-          cstName);
-
-      cstName = ddPrefix + "_4";
-      model.AddLinearConstraint(
-          resultAfterRelinVar <=
-              resultBeforeRelinVar + IF_THEN_AUX * insertRelinOpDecision,
-          cstName);
-    }
+    addDecisionDynamics(op);
+    return WalkResult::advance();
   });
 
   // Dump the model
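
The four DecisionDynamics constraints are a standard big-M linearization of "if we relinearize, the degree becomes 1; otherwise it is unchanged". Writing r for the binary decision, d_br for the degree before relinearization, d for the degree after, and M for IF_THEN_AUX (any constant at least MAX_KEY_BASIS_DEGREE works):

```latex
d \ge r, \qquad
d \le 1 + M(1 - r), \qquad
d \ge d_{\mathrm{br}} - Mr, \qquad
d \le d_{\mathrm{br}} + Mr.
```

When r = 1 the first two inequalities pin d = 1 and the last two are slack; when r = 0 the last two pin d = d_br and the first two are slack. Factoring this into addDecisionDynamics lets the same four inequalities apply unchanged to an inner loop's results, which is what makes a solved loop usable as a single op in the outer model.
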
diff --git a/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.h b/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.h
index 49c9a76a95..713bc3f28f 100644
--- a/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.h
+++ b/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.h
@@ -38,6 +38,10 @@ class OptimizeRelinearizationAnalysis {
     return solutionKeyBasisDegreeBeforeRelin.lookup(value);
   }
 
+  /// Maps a loop operation to its output degrees (one int per loop result).
+  /// Populated by the inner solver and read by the outer solver.
+  llvm::DenseMap<Operation*, llvm::SmallVector<int>> loopBoundaryDegrees;
+
  private:
   Operation* opToRunOn;
   DataFlowSolver* solver;
diff --git a/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.cpp b/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.cpp
index d085a2fab2..9619c90b23 100644
--- a/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.cpp
+++ b/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.cpp
@@ -19,6 +19,10 @@
 #include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/include/mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
+#include "mlir/include/mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
+
 namespace mlir {
 namespace heir {
@@ -32,27 +36,60 @@ struct OptimizeRelinearization
     : impl::OptimizeRelinearizationBase<OptimizeRelinearization> {
   using OptimizeRelinearizationBase::OptimizeRelinearizationBase;
 
-  void processSecretGenericOp(secret::GenericOp genericOp,
-                              DataFlowSolver* solver) {
-    // Remove all relin ops. This makes the IR invalid, because the key basis
-    // sizes are incorrect. However, the correctness of the ILP ensures the key
-    // basis sizes are made correct at the end.
-    genericOp->walk([&](mgmt::RelinearizeOp op) {
-      op.getResult().replaceAllUsesWith(op.getOperand());
-      op.erase();
-    });
+  // Solve the ILP for one region-holding op: either a secret.generic or a
+  // loop nested inside one. Previously solved inner-loop degrees are read
+  // from innerLoopDegrees, and this op's own solved yield degrees are
+  // recorded into outLoopDegrees for its parent to consume.
+  void processBlock(
+      Operation* parentOp, DataFlowSolver* solver,
+      const DenseMap<Operation*, SmallVector<int>>& innerLoopDegrees,
+      DenseMap<Operation*, SmallVector<int>>& outLoopDegrees) {
     OptimizeRelinearizationAnalysis analysis(
-        genericOp, solver, useLocBasedVariableNames, allowMixedDegreeOperands);
+        parentOp, solver, useLocBasedVariableNames, allowMixedDegreeOperands);
+
+    // Pass the previously solved inner loop degrees to the outer solver.
+    analysis.loopBoundaryDegrees = innerLoopDegrees;
+
     if (failed(analysis.solve())) {
-      genericOp->emitError("Failed to solve the optimization problem");
+      parentOp->emitError("Failed to solve the optimization problem");
       return signalPassFailure();
     }
+
+    parentOp->walk([&](Operation* op) {
+      if (isa<LoopLikeOpInterface>(op) && op != parentOp) {
+        return WalkResult::skip();
+      }
+      // If this is a yield op for the current loop, record its solved
+      // degrees so this loop's parent can use them.
+      if (isa<affine::AffineYieldOp, scf::YieldOp>(op)) {
+        SmallVector<int> yieldDegrees;
+        for (Value operand : op->getOperands()) {
+          if (!isSecret(operand, solver)) {
+            // Pad the array so the index lines up with the loop results.
+            yieldDegrees.push_back(0);
+            continue;
+          }
+          // The output degree of the loop is the degree of the yield operand
+          // after any relinearization decisions inside the loop are applied.
+          int degree = analysis.keyBasisDegreeBeforeRelin(operand);
+          if (auto definingOp = operand.getDefiningOp()) {
+            if (analysis.shouldInsertRelin(definingOp)) {
+              degree = 1;
+            }
+          }
+          yieldDegrees.push_back(degree);
+        }
+        if (!yieldDegrees.empty()) {
+          outLoopDegrees[parentOp] = yieldDegrees;
+        }
+      }
+      return WalkResult::advance();
+    });
 
     OpBuilder b(&getContext());
 
-    genericOp->walk([&](Operation* op) {
-      if (!analysis.shouldInsertRelin(op)) return;
+    parentOp->walk([&](Operation* op) {
+      if (isa<LoopLikeOpInterface>(op) && op != parentOp) {
+        return WalkResult::skip();
+      }
+
+      if (!analysis.shouldInsertRelin(op)) return WalkResult::advance();
 
       LLVM_DEBUG(llvm::dbgs() << "Inserting relin after: " << op->getName()
                               << "\n");
@@ -62,12 +99,21 @@ struct OptimizeRelinearization
         auto reduceOp = mgmt::RelinearizeOp::create(b, op->getLoc(), result);
         result.replaceAllUsesExcept(reduceOp.getResult(), {reduceOp});
       }
+      return WalkResult::advance();
     });
   }
 
   void runOnOperation() override {
     Operation* module = getOperation();
+
+    // Remove all relin ops up front, before the dataflow solver runs. This
+    // temporarily makes the IR invalid, because the key basis sizes are
+    // incorrect, but the correctness of the ILP ensures the key basis sizes
+    // are made correct at the end.
+    module->walk([&](secret::GenericOp genericOp) {
+      genericOp->walk([&](mgmt::RelinearizeOp op) {
+        op.getResult().replaceAllUsesWith(op.getOperand());
+        op.erase();
+      });
+    });
+
     DataFlowSolver solver;
     dataflow::loadBaselineAnalyses(solver);
     solver.load<SecretnessAnalysis>();
@@ -79,8 +125,23 @@ struct OptimizeRelinearization
       return;
     }
 
-    module->walk(
-        [&](secret::GenericOp op) { processSecretGenericOp(op, &solver); });
+    // Maps a loop operation to its output degrees.
+    DenseMap<Operation*, SmallVector<int>> loopDegrees;
+
+    // Process all loops bottom-up; the default post-order walk visits inner
+    // loops before their enclosing loops.
+    module->walk([&](Operation* op) {
+      if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op)) {
+        // Only process loops inside a secret.generic.
+        if (loopOp->getParentOfType<secret::GenericOp>()) {
+          processBlock(loopOp, &solver, loopDegrees, loopDegrees);
+        }
+      }
+    });
+
+    // Finally, process the top-level generic ops.
+    module->walk([&](secret::GenericOp op) {
+      processBlock(op, &solver, loopDegrees, loopDegrees);
+    });
 
     // optimize-relinearization will invalidate mgmt attr
     // so re-annotate it
diff --git a/tests/Transforms/optimize_relinearization/loop_accumulator.mlir b/tests/Transforms/optimize_relinearization/loop_accumulator.mlir
new file mode 100644
index 0000000000..62e645bc97
--- /dev/null
+++ b/tests/Transforms/optimize_relinearization/loop_accumulator.mlir
@@ -0,0 +1,89 @@
+// RUN: heir-opt --optimize-relinearization %s | FileCheck %s
+
+// An accumulator loop with a ct-ct multiplication.
+// The relinearize op inside the loop is essential for correctness:
+// without it, the degree of %acc grows without bound across iterations.
+
+// CHECK-LABEL: func.func @loop_accumulator
+// CHECK: secret.generic
+// CHECK: affine.for
+// CHECK: arith.muli
+// CHECK: mgmt.relinearize
+// CHECK: affine.yield
+// CHECK: secret.yield
+
+func.func @loop_accumulator(%arg0: !secret.secret<tensor<8xi16>>) -> !secret.secret<tensor<8xi16>> {
+  %0 = secret.generic(%arg0: !secret.secret<tensor<8xi16>>) {
+  ^body(%input0: tensor<8xi16>):
+    %result = affine.for %i = 0 to 10 iter_args(%acc = %input0) -> (tensor<8xi16>) {
+      // ct-ct multiplication: degree goes from 1 to 2
+      %mul = arith.muli %acc, %acc : tensor<8xi16>
+      %relin = mgmt.relinearize %mul : tensor<8xi16>
+      affine.yield %relin : tensor<8xi16>
+    }
+    secret.yield %result : tensor<8xi16>
+  } -> !secret.secret<tensor<8xi16>>
+  return %0 : !secret.secret<tensor<8xi16>>
+}
+
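
As a sanity check on the comment above, here is a standalone sketch (plain C++, not HEIR code) of the unbounded growth that the loop-carried constraint rules out:

```cpp
#include <cstdio>

int main() {
  // A fresh ciphertext has key basis (1, s), i.e. degree 1.
  int degree = 1;
  for (int i = 0; i < 10; ++i) {
    // ct-ct multiplication adds key basis degrees: deg(a*b) = deg(a) + deg(b),
    // so squaring the accumulator doubles its degree each iteration...
    degree = degree + degree;
    // ...unless a relinearization resets it to 1, as the pass inserts here.
  }
  // Without relinearization, ten iterations reach degree 2^10 = 1024.
  std::printf("degree without relinearization: %d\n", degree);
  return 0;
}
```
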
+// A nested loop where both loops do ct-ct multiplications.
+
+// CHECK-LABEL: func.func @nested_loop_both_mul
+// CHECK: secret.generic
+// CHECK: affine.for
+// CHECK: affine.for
+// CHECK: arith.muli
+// CHECK: mgmt.relinearize
+// CHECK: affine.yield
+// CHECK: arith.muli
+// CHECK: mgmt.relinearize
+// CHECK: affine.yield
+// CHECK: secret.yield
+
+func.func @nested_loop_both_mul(%arg0: !secret.secret<tensor<8xi16>>) -> !secret.secret<tensor<8xi16>> {
+  %0 = secret.generic(%arg0: !secret.secret<tensor<8xi16>>) {
+  ^body(%input0: tensor<8xi16>):
+    %outer_result = affine.for %i = 0 to 8 iter_args(%outer_acc = %input0) -> (tensor<8xi16>) {
+      %inner_result = affine.for %j = 0 to 4 iter_args(%inner_acc = %outer_acc) -> (tensor<8xi16>) {
+        %inner_mul = arith.muli %inner_acc, %inner_acc : tensor<8xi16>
+        affine.yield %inner_mul : tensor<8xi16>
+      }
+      %outer_mul = arith.muli %inner_result, %inner_result : tensor<8xi16>
+      affine.yield %outer_mul : tensor<8xi16>
+    }
+    secret.yield %outer_result : tensor<8xi16>
+  } -> !secret.secret<tensor<8xi16>>
+  return %0 : !secret.secret<tensor<8xi16>>
+}
+
+// A nested loop where the inner loop does ct-pt multiplications.
+
+// CHECK-LABEL: func.func @nested_loop_inner_ct_pt
+// CHECK: secret.generic
+// CHECK: affine.for
+// CHECK: affine.for
+// CHECK: arith.muli
+// CHECK-NOT: mgmt.relinearize
+// CHECK: affine.yield
+// CHECK: arith.muli
+// CHECK: mgmt.relinearize
+// CHECK: affine.yield
+// CHECK: secret.yield
+
+func.func @nested_loop_inner_ct_pt(%arg0: !secret.secret<tensor<8xi16>>) -> !secret.secret<tensor<8xi16>> {
+  %0 = secret.generic(%arg0: !secret.secret<tensor<8xi16>>) {
+  ^body(%input0: tensor<8xi16>):
+    %cst = arith.constant dense<2> : tensor<8xi16>
+    %outer_result = affine.for %i = 0 to 8 iter_args(%outer_acc = %input0) -> (tensor<8xi16>) {
+      %inner_result = affine.for %j = 0 to 4 iter_args(%inner_acc = %outer_acc) -> (tensor<8xi16>) {
+        // ct-pt multiplication: the degree stays at 1, so no relin is needed
+        %inner_mul = arith.muli %inner_acc, %cst : tensor<8xi16>
+        affine.yield %inner_mul : tensor<8xi16>
+      }
+      %outer_mul = arith.muli %inner_result, %inner_result : tensor<8xi16>
+      affine.yield %outer_mul : tensor<8xi16>
+    }
+    secret.yield %outer_result : tensor<8xi16>
+  } -> !secret.secret<tensor<8xi16>>
+  return %0 : !secret.secret<tensor<8xi16>>
+}
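
Finally, a minimal standalone sketch of the bottom-up composition the driver relies on, using nested_loop_both_mul as the running example (all names here are illustrative, not HEIR APIs):

```cpp
#include <cstdio>

// Toy model of one loop's ILP: iter_args are pinned to degree 1, and the
// loop-carried constraint forbids yielding a degree above the iter_arg's,
// so a ct-ct mul in the body forces a relinearization (degree resets to 1).
struct LoopModel {
  int iterArgDegree;
  bool bodyHasCtCtMul;

  int boundaryDegree() const {
    int yielded = bodyHasCtCtMul ? iterArgDegree + iterArgDegree  // squaring
                                 : iterArgDegree;
    bool insertRelin = yielded > iterArgDegree;  // loop-carried constraint
    return insertRelin ? 1 : yielded;            // degree the parent sees
  }
};

int main() {
  // The inner loop is solved first, mirroring the pass's post-order walk.
  LoopModel inner{/*iterArgDegree=*/1, /*bodyHasCtCtMul=*/true};
  int innerBoundary = inner.boundaryDegree();  // relin forced inside: 1
  std::printf("inner loop boundary degree: %d\n", innerBoundary);

  // The outer ILP then treats the inner loop as a single op whose result
  // has the solved boundary degree, mirroring loopBoundaryDegrees.
  LoopModel outer{/*iterArgDegree=*/innerBoundary, /*bodyHasCtCtMul=*/true};
  std::printf("outer loop boundary degree: %d\n", outer.boundaryDegree());
  return 0;
}
```

In this toy, both loops end up needing a relin, matching the CHECK lines for nested_loop_both_mul; in nested_loop_inner_ct_pt the inner body's ct-pt mul leaves the degree at 1, so only the outer loop gets one.
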