linalg-to-tensor-ext: reorder kernel lowering to avoid plaintext/ciphertext type issues

I don't think there's a way to avoid this: I moved the final computations from after the loop to before the loop, so that the iter_arg of the loop can be a ciphertext type and match the yielded ciphertext type.

Before, the iter_arg of the loop was the initial bias, a plaintext type, which caused a type mismatch at secret-to-ckks: the iter_arg of a for loop must match the yielded type, and the yielded value becomes a ciphertext once the partial matmul result is added to it.
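For concreteness, here is a minimal sketch of the reordered lowering, written to mirror the updated FileCheck expectations below for the 1x4 vector times 4x4 diagonalized matrix case. The SSA names (%diag, %bias, %vec, %c1) are illustrative, and the tensor_ext.rotate type signature is abbreviated rather than exact:

```mlir
// Illustrative sketch (not the generated IR verbatim): %diag is the plaintext
// diagonalized matrix, %bias the plaintext bias, %vec the secret-derived input
// inside the secret.generic body, and %c1 an index constant 1.
%first_slice = tensor.extract_slice %diag[0, 0] [1, 4] [1, 1]
    : tensor<4x4xf16> to tensor<1x4xf16>
// The first partial product and sum are hoisted above the loop, so the
// running-sum iter_arg already derives from the secret input and matches the
// (ciphertext) type yielded by the loop body after secret-to-ckks.
%first_mul = arith.mulf %vec, %first_slice : tensor<1x4xf16>
%first_sum = arith.addf %first_mul, %bias : tensor<1x4xf16>
%res:2 = affine.for %i = 1 to 4
    iter_args(%sum = %first_sum, %rot = %vec)
    -> (tensor<1x4xf16>, tensor<1x4xf16>) {
  // Rotate first, then multiply by the i-th diagonal and accumulate.
  %next_rot = tensor_ext.rotate %rot, %c1 : tensor<1x4xf16>, index
  %slice = tensor.extract_slice %diag[%i, 0] [1, 4] [1, 1]
      : tensor<4x4xf16> to tensor<1x4xf16>
  %mul = arith.mulf %next_rot, %slice : tensor<1x4xf16>
  %new_sum = arith.addf %sum, %mul : tensor<1x4xf16>
  affine.yield %new_sum, %next_rot : tensor<1x4xf16>, tensor<1x4xf16>
}
// %res#0 holds the accumulated sum; any remaining rotate-and-sum reduction
// happens after the loop, as before.
```

Because the rotation now happens at the top of the loop body and the rotated vector is yielded, the loop runs from 1 to N, and the trailing extract/multiply/add that used to follow the loop is no longer needed.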

Related to #1338

PiperOrigin-RevId: 728774577
asraa authored and copybara-github committed Feb 19, 2025
1 parent c10f98f commit 6b8f58f
Showing 7 changed files with 69 additions and 76 deletions.
@@ -187,22 +187,36 @@ Value multiplyDiagonalizedMatrixWithVector(
}
SmallVector<OpFoldResult> strides(2, builder.getIndexAttr(1));

// Setup the offsets for the first ExtractSliceOp and build the
// ExtractSliceOp.
SmallVector<OpFoldResult> firstOffsets(2, builder.getIndexAttr(0));
auto firstExtracted = builder.create<tensor::ExtractSliceOp>(
diagonalizedMatrix, firstOffsets, sizes, strides);

// Calculates the first scalar multiplication and sum.
auto firstMultiplied = builder.create<MulOp>(secretValues, firstExtracted);
auto firstSumWithoutRotateAndSum =
builder.create<AddOp>(bias, firstMultiplied);

// Build the affine for loop.
// Setup parameters for the affine for loop.
SmallVector<Value> iterArgs({bias, secretValues});
int numLoops = originalMatrixDimensions[0];
if (numLoops > originalMatrixDimensions[1]) {
numLoops = originalMatrixDimensions[1];
}

// Build the affine for loop.
SmallVector<Value> iterArgs({firstSumWithoutRotateAndSum, secretValues});
auto forOp =
builder.create<mlir::affine::AffineForOp>(0, numLoops - 1, 1, iterArgs);
builder.create<mlir::affine::AffineForOp>(1, numLoops, 1, iterArgs);

// Now, we are inside for loop.
builder.setInsertionPointToStart(forOp.getBody());
auto index = forOp.getInductionVar();
auto sum = forOp.getRegionIterArgs()[0];
auto rotatedVector = forOp.getRegionIterArgs()[1];

// Rotate first
auto rotatedVector = builder.create<tensor_ext::RotateOp>(
forOp.getRegionIterArgs()[1], indexOne);

// Setup the offsets for the ExtractSliceOp and build the ExtractSliceOp.
SmallVector<OpFoldResult> offsets(2);
@@ -223,32 +237,11 @@ Value multiplyDiagonalizedMatrixWithVector(

auto multiplied = builder.create<MulOp>(rotatedVector, extracted);
auto newSum = builder.create<AddOp>(sum, multiplied);
auto newRotatedVector =
builder.create<tensor_ext::RotateOp>(rotatedVector, indexOne);
builder.create<affine::AffineYieldOp>(ValueRange({newSum, newRotatedVector}));
builder.create<affine::AffineYieldOp>(ValueRange({newSum, rotatedVector}));

// Now outside for loop.
builder.setInsertionPointAfter(forOp);

// Setup the offsets for the last ExtractSliceOp and build the
// ExtractSliceOp.
SmallVector<OpFoldResult> lastOffsets(2);
if (isLeftOperandSecret) {
lastOffsets = {builder.getIndexAttr(originalMatrixDimensions[0] - 1),
builder.getIndexAttr(0)};
} else {
lastOffsets = {builder.getIndexAttr(0),
builder.getIndexAttr(originalMatrixDimensions[0] - 1)};
}
auto lastExtracted = builder.create<tensor::ExtractSliceOp>(
diagonalizedMatrix, lastOffsets, sizes, strides);

// Calculates the final scalar multiplication and sum.
auto lastMultiplied =
builder.create<MulOp>(forOp.getResults()[1], lastExtracted);
auto finalSumWithoutRotateAndSum =
builder.create<AddOp>(forOp.getResults()[0], lastMultiplied);

int numRotationsAndSums;
if (isLeftOperandSecret) {
numRotationsAndSums = llvm::APInt(32, originalMatrixDimensions[0] /
@@ -261,7 +254,7 @@ Value multiplyDiagonalizedMatrixWithVector(
}

// Rotate and sum if needed
Value sumInProgress = finalSumWithoutRotateAndSum;
Value sumInProgress = forOp.getResults()[0];
int rotationValue = maxTilingSize;
for (int i = 0; i < numRotationsAndSums; ++i) {
rotationValue /= 2;

@@ -8,7 +8,7 @@
// CHECK-SAME{LITERAL}: <[[
// CHECK-SAME: 1.{{0*}}e+00, 2.{{0*}}e+00, 3.{{0*}}e+00, 4.{{0*}}e+00], [2.{{0*}}e+00, 3.{{0*}}e+00, 4.{{0*}}e+00, 1.{{0*}}e+00], [3.{{0*}}e+00, 4.{{0*}}e+00, 1.{{0*}}e+00, 2.{{0*}}e+00], [4.{{0*}}e+00, 1.{{0*}}e+00, 2.{{0*}}e+00, 3.{{0*}}e+00
// CHECK-SAME{LITERAL}: ]]>
// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1]
// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 0] [1, 4] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<1x4xf32>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<1x4xf32>):
// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG_CONVERTED]], %[[SLICE]]

@@ -10,18 +10,18 @@
// CHECK-SAME{LITERAL}: <[[
// CHECK-SAME: 1.7{{0*}}e+01, 1.8{{0*}}e+01, 1.9{{0*}}e+01, 2.{{0*}}e+01
// CHECK-SAME{LITERAL}: ]]>
// CHECK: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1]
// CHECK: %[[FIRST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 0] [1, 4] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<1x4xf16>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<1x4xf16>):
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[FIRST_MUL:.*]] = arith.mulf %[[ARG_CONVERTED]], %[[FIRST_SLICE]]
// CHECK: %[[FIRST_SUM:.*]] = arith.addf %[[FIRST_MUL]], %[[BIAS]]
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 1 to 4 iter_args(%[[RUNNING_SUM:.*]] = %[[FIRST_SUM]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1]
// CHECK: %[[MUL:.*]] = arith.mulf %[[ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[MUL:.*]] = arith.mulf %[[UPDATED_ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addf %[[RUNNING_SUM]], %[[MUL]]
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: %[[LAST_MUL:.*]] = arith.mulf %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]]
// CHECK: %[[FINAL_SUM:.*]] = arith.addf %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]]
// CHECK: secret.yield %[[FINAL_SUM]]
// CHECK: secret.yield %[[FOR_LOOP_OUT]]#0
// CHECK: return %[[OUT]]
module {
func.func @test_float_vector_square_matrix_matmul(%vec : !secret.secret<tensor<1x4xf16>>) -> !secret.secret<tensor<1x4xf16>> {

@@ -7,19 +7,19 @@
// CHECK-SAME{LITERAL}: <[[17], [18], [17], [18]]> : tensor<4x1xi16>
// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense
// CHECK-SAME{LITERAL}: <[[1, 2, 3, 4], [6, 7, 8, 5], [3, 4, 1, 2], [8, 5, 6, 7]]> : tensor<4x4xi16>
// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 1] [4, 1] [1, 1]
// CHECK-DAG: %[[FIRST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 0] [4, 1] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<4x1xi16>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>):
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 1 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[FIRST_MUL:.*]] = arith.muli %[[ARG_CONVERTED]], %[[FIRST_SLICE]]
// CHECK: %[[FIRST_SUM:.*]] = arith.addi %[[FIRST_MUL]], %[[BIAS]]
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 1 to 2 iter_args(%[[RUNNING_SUM:.*]] = %[[FIRST_SUM]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK-DAG: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[UPDATED_ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]]
// CHECK: %[[BEFORE_ROTATE_AND_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]]
// CHECK: %[[ROTATED_SUM:.*]] = tensor_ext.rotate %[[BEFORE_ROTATE_AND_SUM]], %[[TWO]]
// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[BEFORE_ROTATE_AND_SUM]], %[[ROTATED_SUM]]
// CHECK: %[[ROTATED_SUM:.*]] = tensor_ext.rotate %[[FOR_LOOP_OUT]]#0, %[[TWO]]
// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[ROTATED_SUM]]
// CHECK: secret.yield %[[FINAL_SUM]]
// CHECK: return %[[OUT]]
module {

@@ -6,18 +6,18 @@
// CHECK-SAME{LITERAL}: <[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]> : tensor<4x4xi16>
// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense
// CHECK-SAME{LITERAL}: <[[17], [18], [19], [20]]> : tensor<4x1xi16>
// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 3] [4, 1] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<4x1xi16>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>):
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]]
// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]]
// CHECK: secret.yield %[[FINAL_SUM]]
// CHECK-DAG: %[[FIRST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 0] [4, 1] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<4x1xi16>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>):
// CHECK: %[[FIRST_MUL:.*]] = arith.muli %[[ARG_CONVERTED]], %[[FIRST_SLICE]]
// CHECK: %[[FIRST_SUM:.*]] = arith.addi %[[FIRST_MUL]], %[[BIAS]]
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 1 to 4 iter_args(%[[RUNNING_SUM:.*]] = %[[FIRST_SUM]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK-DAG: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[UPDATED_ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: secret.yield %[[FOR_LOOP_OUT]]#0
// CHECK: return %[[OUT]]
module {
func.func @test_integer_square_matrix_vector_matmul(%vec : !secret.secret<tensor<4x1xi16>>) -> !secret.secret<tensor<4x1xi16>> {

@@ -6,18 +6,18 @@
// CHECK-SAME{LITERAL}: <[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]> : tensor<4x4xi16>
// CHECK: %[[BIAS:.*]] = arith.constant dense
// CHECK-SAME{LITERAL}: <[[17], [18], [19], [20]]> : tensor<4x1xi16>
// CHECK: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 3] [4, 1] [1, 1]
// CHECK: %[[FIRST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 0] [4, 1] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<4x1xi16>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>):
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]]
// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]]
// CHECK: secret.yield %[[FINAL_SUM]]
// CHECK: %[[FIRST_MUL:.*]] = arith.muli %[[ARG_CONVERTED]], %[[FIRST_SLICE]]
// CHECK: %[[FIRST_SUM:.*]] = arith.addi %[[FIRST_MUL]], %[[BIAS]]
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 1 to 4 iter_args(%[[RUNNING_SUM:.*]] = %[[FIRST_SUM]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK-DAG: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[UPDATED_ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: secret.yield %[[FOR_LOOP_OUT]]#0
// CHECK: return %[[OUT]]
module {
func.func @test_integer_vector_square_matrix_matmul(%vec : !secret.secret<tensor<4x1xi16>>) -> !secret.secret<tensor<4x1xi16>> {

@@ -6,18 +6,18 @@
// CHECK-SAME{LITERAL}: <[[1, 6, 11, 16], [5, 10, 15, 4], [9, 14, 3, 8], [13, 2, 7, 12]]> : tensor<4x4xi16>
// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense
// CHECK-SAME{LITERAL}: <[[17, 18, 19, 20]]> : tensor<1x4xi16>
// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1]
// CHECK-DAG: %[[FIRST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 0] [1, 4] [1, 1]
// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret<tensor<1x4xi16>>)
// CHECK: ^body(%[[ARG_CONVERTED:.*]]: tensor<1x4xi16>):
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]]
// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]]
// CHECK: secret.yield %[[FINAL_SUM]]
// CHECK: %[[FIRST_MUL:.*]] = arith.muli %[[ARG_CONVERTED]], %[[FIRST_SLICE]]
// CHECK: %[[FIRST_SUM:.*]] = arith.addi %[[FIRST_MUL]], %[[BIAS]]
// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 1 to 4 iter_args(%[[RUNNING_SUM:.*]] = %[[FIRST_SUM]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]])
// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]]
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1]
// CHECK: %[[MUL:.*]] = arith.muli %[[UPDATED_ROTATED_VEC]], %[[SLICE]]
// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]]
// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]]
// CHECK: secret.yield %[[FOR_LOOP_OUT]]#0
// CHECK: return %[[OUT]]
module {
func.func @test_integer_vector_square_matrix_matmul(%vec : !secret.secret<tensor<1x4xi16>>) -> !secret.secret<tensor<1x4xi16>> {