
AnnotateMgmt for programs with loops fails - LevelAnalysis yields infinite loop #1364

Open
@asraa

Description

Root cause: the analyses do not handle region-bearing ops (here, affine.for loops) correctly. Most likely we need an implementation of visitNonControlFlowArguments.
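For intuition, here is a minimal Python model (not HEIR's C++ code) of how a sparse forward analysis treats the iter_args of a loop like affine.for. The block argument's state starts from the init operand's state, the value yielded by each iteration is joined back into it, and the body is re-analyzed until that state stops changing. The boolean "is a multiplication result" lattice and all function names are hypothetical stand-ins for MulResultAnalysis.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def loop_fixpoint(
    init: T,
    body: Callable[[T], T],
    join: Callable[[T, T], T],
    max_iters: int = 100,
) -> Optional[T]:
    """Return the fixpoint state of a loop-carried block argument.

    Each round runs the body's transfer function on the current block
    argument state and joins the yielded state back into it. Returns None
    if no fixpoint is reached within max_iters rounds, i.e. the analysis
    would not terminate.
    """
    arg_state = init
    for _ in range(max_iters):
        yielded = body(arg_state)
        new_state = join(arg_state, yielded)
        if new_state == arg_state:
            return arg_state
        arg_state = new_state
    return None


# Hypothetical MulResult lattice: True means "this value is the result of a
# ciphertext-ciphertext multiplication". Join is logical OR.
def mul_body(arg_is_mul_result: bool) -> bool:
    # %3 = arith.mulf %arg3, %arg3 is always a mul result.
    return True


result = loop_fixpoint(False, mul_body, lambda a, b: a or b)
print(result)  # True
```

Even though the first iteration's argument is the fresh input %input0, the joined state marks the block argument as a mul result, which is consistent with the extra modreduce the pass inserts before the multiplication.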

Reproducer:

$ heir-opt --secret-insert-mgmt-ckks
  func.func @loop(%arg0: !secret.secret<tensor<1x1024xf32>>) -> !secret.secret<tensor<1x1024xf32>> {
    %0 = secret.generic ins(%arg0 : !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %1 = affine.for %arg1 = 0 to 3 iter_args(%arg3 = %input0) -> (tensor<1x1024xf32>) {
        %3 = arith.mulf %arg3, %arg3 : tensor<1x1024xf32>
        affine.yield %3 : tensor<1x1024xf32>
      }
      secret.yield %1 : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %0 : !secret.secret<tensor<1x1024xf32>>
  }

This program has a single ciphertext-ciphertext multiplication inside the loop, so a relinearization should be inserted after the mul so that the block argument and the yielded value always have dimension 2. The pass also inserts a modreduce before the multiplication even when the first-mul option is off, apparently because the analysis concludes that the loop's arguments are mul results too.
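The dimension argument can be checked with a similar toy model, again hypothetical Python rather than HEIR's DimensionAnalysis. Multiplying ciphertexts with d1 and d2 components yields d1 + d2 - 1 components, and relinearization brings the result back to 2. With a relinearize after the mul, the loop-carried dimension is stable at 2; without it, the dimension grows on every iteration and never reaches a fixpoint.

```python
from typing import Optional


def mul_dim(d1: int, d2: int) -> int:
    # Tensoring ciphertexts with d1 and d2 components gives d1 + d2 - 1.
    return d1 + d2 - 1


def loop_dimension(relinearize: bool, max_iters: int = 10) -> Optional[int]:
    """Fixpoint of the loop block argument's dimension, or None if it diverges."""
    arg_dim = 2  # %input0 is a fresh ciphertext of dimension 2
    for _ in range(max_iters):
        d = mul_dim(arg_dim, arg_dim)  # %3 = arith.mulf %arg3, %arg3
        if relinearize:
            d = 2  # mgmt.relinearize resets the dimension to 2
        new_dim = max(arg_dim, d)  # join previous state with the yielded value
        if new_dim == arg_dim:
            return arg_dim
        arg_dim = new_dim
    return None


print(loop_dimension(relinearize=True))   # 2
print(loop_dimension(relinearize=False))  # None
```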

So I think LevelAnalysis, DimensionAnalysis, and MulResultAnalysis all need to be aware of region-bearing operations.

When you run the reproducer, it enters an infinite loop when it re-runs the solver after inserting the relinearization. I added some debug prints to the level analysis and saw it continuing to increment the level of the multiplication over and over inside the loop.
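A simple model is consistent with the log below, where the printed state grows by one per round (2769, 2770, 2771). Suppose the block argument's level is the join (max) of its previous level and the yielded level, and mgmt.modreduce adds one to the level. Then every pass through the body yields a strictly larger level, and because integers under max have no upper bound, no fixpoint exists. This Python is a hypothetical illustration, not HEIR's LevelAnalysis; it caps the number of rounds so the divergence is visible.

```python
from typing import List


def level_rounds(init_level: int, rounds: int) -> List[int]:
    """Record the loop block argument's level after each solver round.

    Hypothetical model: modreduce increments the level, mulf and relinearize
    keep it, and the block argument joins (max) its old level with the
    level yielded by the body.
    """
    arg_level = init_level
    history = []
    for _ in range(rounds):
        mod = arg_level + 1  # %2 = mgmt.modreduce %arg2
        mul = mod            # %3 = arith.mulf %2, %2
        relin = mul          # %4 = mgmt.relinearize %3
        arg_level = max(arg_level, relin)  # join with affine.yield %4
        history.append(arg_level)
    return history


print(level_rounds(0, 5))  # [1, 2, 3, 4, 5]
```

One general way to guarantee termination is a lattice of finite height, or a widening step that jumps to an "unknown" level; whether either is the right fix for LevelAnalysis is not settled here.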

LevelAnalysis: Visiting operation: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
propogate: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>is changed 1 to state 2769
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 1 to state 2769
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 0 to state 2769
LevelAnalysis: Visiting operation: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
propogate: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>is changed 1 to state 2769
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: Visiting operation: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
propogate: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>is changed 1 to state 2770
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 1 to state 2770
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 0 to state 2770
LevelAnalysis: Visiting operation: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
propogate: %4 = mgmt.relinearize %3 : tensor<1x1024xf32>is changed 1 to state 2770
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: visitNonControlFlowArguments: %1 = affine.for %arg1 = 0 to 3 iter_args(%arg2 = %input0) -> (tensor<1x1024xf32>) {
  %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
  %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
  %4 = mgmt.relinearize %3 : tensor<1x1024xf32>
  affine.yield %4 : tensor<1x1024xf32>
}
LevelAnalysis: Visiting operation: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>
propogate: %2 = mgmt.modreduce %arg2 : tensor<1x1024xf32>is changed 1 to state 2771
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>
propogate: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>is changed 1 to state 2771
LevelAnalysis: Visiting operation: %3 = arith.mulf %2, %2 : tensor<1x1024xf32>

cc @ZenithalHourlyRate
