Description
Root cause: the analyses do not properly handle region-bearing ops (e.g. for loops). Most likely we need an implementation of visitNonControlFlowArguments.
Reproducer:
$ heir-opt --secret-insert-mgmt-ckks
```mlir
func.func @loop(%arg0: !secret.secret<tensor<1x1024xf32>>) -> !secret.secret<tensor<1x1024xf32>> {
  %0 = secret.generic ins(%arg0 : !secret.secret<tensor<1x1024xf32>>) {
  ^body(%input0: tensor<1x1024xf32>):
    %1 = affine.for %arg1 = 0 to 3 iter_args(%arg3 = %input0) -> (tensor<1x1024xf32>) {
      %3 = arith.mulf %arg3, %arg3 : tensor<1x1024xf32>
      affine.yield %3 : tensor<1x1024xf32>
    }
    secret.yield %1 : tensor<1x1024xf32>
  } -> !secret.secret<tensor<1x1024xf32>>
  return %0 : !secret.secret<tensor<1x1024xf32>>
}
```
This program contains a single ciphertext-ciphertext multiplication, inside the loop. A relinearization should be inserted after the mul so that the loop's block argument and yielded value always have dimension 2. The pass also inserts a modreduce before the multiplication even when the first-mul option is off, apparently because the analysis treats the loop arguments as coming from a mul result too.
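Tracking ciphertext dimension through the loop shows why the relinearization matters. The sketch below is a toy model, not HEIR's DimensionAnalysis: it assumes the usual rule that a ct-ct product of dimensions a and b has dimension a + b - 1, and that relinearization returns dimension 2. blockArgDims is a hypothetical helper.

```cpp
#include <cassert>
#include <vector>

// Toy model: a ciphertext's dimension is its number of polynomial
// components. A fresh ciphertext has dimension 2.
int mulDim(int a, int b) { return a + b - 1; }  // ct-ct multiplication
int relinDim(int) { return 2; }                 // relinearization

// Dimension of the block argument at the start of each of `trips`
// iterations of: %3 = arith.mulf %arg3, %arg3 ; [relinearize] ; yield,
// followed by the dimension of the value the loop returns.
std::vector<int> blockArgDims(int trips, bool relinearizeAfterMul) {
  assert(trips >= 0);
  std::vector<int> dims;
  int arg = 2;  // %input0 is a fresh ciphertext
  for (int i = 0; i < trips; ++i) {
    dims.push_back(arg);
    int result = mulDim(arg, arg);
    arg = relinearizeAfterMul ? relinDim(result) : result;
  }
  dims.push_back(arg);
  return dims;
}
```

Here blockArgDims(3, false) returns {2, 3, 5, 9}, so the block argument's dimension changes on every iteration, while blockArgDims(3, true) returns {2, 2, 2, 2}, the loop-invariant dimension the pass should establish.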
So I think LevelAnalysis, DimensionAnalysis, and MulResultAnalysis all need to be aware of region-bearing operations.
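For a loop-carried value, the state of the block argument has to combine the init operand with the value yielded by the previous iteration. The following is a minimal sketch using a hypothetical boolean "may be a mul result" state (join = OR) and a hand-written solver for the reproducer's loop. It is not HEIR's MulResultAnalysis, and solveLoop is an illustrative name.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical "may be a mul result" state:
// false = definitely not a mul result, true = possibly one. Join is OR.
bool join(bool a, bool b) { return a || b; }

// Hand-written fixpoint for the reproducer's loop body:
//   %3 = arith.mulf %arg3, %arg3 ; affine.yield %3
// The loop-carried block argument %arg3 joins the init operand
// (%input0) with the value yielded by the previous iteration (%3).
std::map<std::string, bool> solveLoop(bool inputIsMulResult) {
  std::map<std::string, bool> state{
      {"%input0", inputIsMulResult}, {"%arg3", false}, {"%3", false}};
  int rounds = 0;
  bool changed = true;
  while (changed) {
    ++rounds;
    // Each tracked value can only flip false -> true once: at most 3 rounds.
    assert(rounds <= 3);
    bool arg = join(state["%input0"], state["%3"]);  // loop-header join
    bool mul = true;  // the result of arith.mulf is always a mul result
    changed = arg != state["%arg3"] || mul != state["%3"];
    state["%arg3"] = arg;
    state["%3"] = mul;
  }
  return state;
}
```

Even with a fresh input, solveLoop(false) marks %arg3 as a possible mul result in the second round, because the yielded %3 is one. This may be why the pass treats the args as mul results. The boolean state has finite height, so this solver always stops.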
Running the reproducer leads to an infinite loop when the solver is re-run after the relinearization is inserted. I added some debug output to the level analysis and saw it keep incrementing the level of the multiplication on every trip around the loop.
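The non-termination can be reproduced in a toy model. The loop body raises the level of the value it yields by one, so joining the yield back into the block argument with max never reaches a fixpoint on an unbounded integer lattice. The update rules below are assumptions read off the debug output, not HEIR's actual transfer functions, and iterArgLevels is a hypothetical helper.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy model of the loop in the debug output (not HEIR's LevelAnalysis):
//   %2 = mgmt.modreduce %arg2   -> level(%2) = level(%arg2) + 1
//   %3 = arith.mulf %2, %2      -> level(%3) = level(%2)
//   %4 = mgmt.relinearize %3    -> level(%4) = level(%3)
//   affine.yield %4
// The iter_arg %arg2 joins the init operand and the yielded value with max.
// Returns the level of %arg2 after each round, stopping at a fixpoint or
// after maxRounds rounds, whichever comes first.
std::vector<int> iterArgLevels(int initLevel, int maxRounds) {
  assert(maxRounds > 0);
  std::vector<int> history;
  int arg2 = initLevel;
  int yielded = initLevel;
  for (int round = 0; round < maxRounds; ++round) {
    int next = std::max(initLevel, yielded);  // join at the loop header
    if (!history.empty() && next == arg2) break;  // fixpoint reached
    arg2 = next;
    history.push_back(arg2);
    int l2 = arg2 + 1;  // modreduce
    int l3 = l2;        // mulf of two operands at the same level
    int l4 = l3;        // relinearize keeps the level
    yielded = l4;
  }
  return history;
}
```

iterArgLevels(0, 5) returns {0, 1, 2, 3, 4}: the round cap is hit before any fixpoint, mirroring the ever-increasing state in the log below. A monotone solver only terminates when such ascending chains are bounded.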
```
LevelAnalysis: Visiting operation: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
propogate: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>is changed 1 to state 2769
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 1 to state 2769
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 0 to state 2769
LevelAnalysis: Visiting operation: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
propogate: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>is changed 1 to state 2769
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: Visiting operation: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
propogate: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>is changed 1 to state 2770
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 1 to state 2770
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 0 to state 2770
LevelAnalysis: Visiting operation: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
propogate: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>is changed 1 to state 2770
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: Visiting operation: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
propogate: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>is changed 1 to state 2771
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 1 to state 2771
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
```