Skip to content

Commit

Permalink
fix: fix LUT input pin orderings to be MSB to LSB in comb truth table…
Browse files Browse the repository at this point in the history
… operations

Previously, Yosys Optimizer's RTLIL importer collected alphabetical pins A, B, C, (...) for truth table operations when importing the optimized RTLIL. A is the LSB and C is the MSB top bit. When converting to comb truth table ops, we place A, B, C as the inputs, while comb's truth table op expects inputs ordered from MSB to LSB. This resulted in correctness issues later in lowerings to CGGI (whose LUT op also expects MSB to LSB) and then tfhe_rust.

PiperOrigin-RevId: 603044828
  • Loading branch information
asraa authored and copybara-github committed Jan 31, 2024
1 parent 6293a83 commit 0df1895
Show file tree
Hide file tree
Showing 19 changed files with 207 additions and 66 deletions.
45 changes: 45 additions & 0 deletions include/Dialect/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Built-in common HEIR definitions

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"HEIRInterfaces.h",
],
)

td_library(
name = "td_files",
srcs = [
"HEIRInterfaces.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../.."],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
],
)

gentbl_cc_library(
name = "interfaces_inc_gen",
tbl_outs = [
(
["-gen-op-interface-decls"],
"HEIRInterfaces.h.inc",
),
(
["-gen-op-interface-defs"],
"HEIRInterfaces.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "HEIRInterfaces.td",
deps = [
":td_files",
],
)
1 change: 1 addition & 0 deletions include/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ gentbl_cc_library(
deps = [
":dialect_inc_gen",
":td_files",
"@heir//include/Dialect:td_files",
"@heir//include/Dialect/LWE/IR:td_files",
"@heir//include/Dialect/Polynomial/IR:td_files",
],
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/CGGI/IR/CGGIOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define HEIR_INCLUDE_DIALECT_CGGI_IR_CGGIOPS_H_

#include "include/Dialect/CGGI/IR/CGGIDialect.h"
#include "include/Dialect/HEIRInterfaces.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
Expand Down
5 changes: 4 additions & 1 deletion include/Dialect/CGGI/IR/CGGIOps.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef HEIR_INCLUDE_DIALECT_CGGI_IR_CGGIOPS_TD_
#define HEIR_INCLUDE_DIALECT_CGGI_IR_CGGIOPS_TD_

include "include/Dialect/HEIRInterfaces.td"
include "include/Dialect/CGGI/IR/CGGIDialect.td"

include "include/Dialect/Polynomial/IR/PolynomialAttributes.td"
Expand Down Expand Up @@ -53,7 +54,9 @@ class CGGI_LutOp<string mnemonic, list<Trait> traits = []>
Pure,
Commutative,
Elementwise,
Scalarizable
Scalarizable,
DeclareOpInterfaceMethods<LUTOpInterface>

]> {
let results = (outs LWECiphertextLike:$output);
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/Comb/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ td_library(
],
includes = ["../../../.."],
deps = [
"@heir//include/Dialect:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/Comb/IR/CombOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define HEIR_INCLUDE_DIALECT_COMB_COMBOPS_H

#include "include/Dialect/Comb/IR/CombDialect.h"
#include "include/Dialect/HEIRInterfaces.h"
#include "mlir/include/mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
Expand Down
4 changes: 3 additions & 1 deletion include/Dialect/Comb/IR/Combinational.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#ifndef HEIR_INCLUDE_DIALECT_COMB_COMBINATIONAL_TD
#define HEIR_INCLUDE_DIALECT_COMB_COMBINATIONAL_TD

include "include/Dialect/HEIRInterfaces.td"

include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
Expand Down Expand Up @@ -270,7 +272,7 @@ def MuxOp : CombOp<"mux",

}

def TruthTableOp : CombOp<"truth_table", [Pure]> {
def TruthTableOp : CombOp<"truth_table", [Pure, DeclareOpInterfaceMethods<LUTOpInterface>]> {
let summary = "Return a true/false based on a lookup table";
let description = [{
```
Expand Down
12 changes: 12 additions & 0 deletions include/Dialect/HEIRInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef HEIR_INCLUDE_DIALECT_HEIRINTERFACES_H_
#define HEIR_INCLUDE_DIALECT_HEIRINTERFACES_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// Pull in HEIR interfaces
#include "include/Dialect/HEIRInterfaces.h.inc"

#endif // HEIR_INCLUDE_DIALECT_HEIRINTERFACES_H_
14 changes: 14 additions & 0 deletions include/Dialect/HEIRInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
include "mlir/IR/OpBase.td"

def LUTOpInterface : OpInterface<"LUTOpInterface"> {
let description = [{
This is an example interface definition.
}];

let methods = [
InterfaceMethod<
"Gets lookup table inputs from most significant bit to least.",
"mlir::ValueRange", "getLookupTableInputs"
>,
];
}
20 changes: 20 additions & 0 deletions lib/Dialect/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Built in HEIR declarations

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "HEIRInterfaces",
srcs = [
"HEIRInterfaces.cpp",
],
hdrs = [
"@heir//include/Dialect:HEIRInterfaces.h",
],
deps = [
"@heir//include/Dialect:interfaces_inc_gen",
"@llvm-project//mlir:IR",
],
)
2 changes: 2 additions & 0 deletions lib/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cc_library(
name = "Dialect",
srcs = [
"CGGIDialect.cpp",
"CGGIOps.cpp",
],
hdrs = [
"@heir//include/Dialect/CGGI/IR:CGGIAttributes.h",
Expand All @@ -17,6 +18,7 @@ cc_library(
"@heir//include/Dialect/CGGI/IR:attributes_inc_gen",
"@heir//include/Dialect/CGGI/IR:dialect_inc_gen",
"@heir//include/Dialect/CGGI/IR:ops_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
Expand Down
19 changes: 19 additions & 0 deletions lib/Dialect/CGGI/IR/CGGIOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "include/Dialect/CGGI/IR/CGGIOps.h"

#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace cggi {

mlir::ValueRange Lut2Op::getLookupTableInputs() {
return mlir::ValueRange{getB(), getA()};
}

mlir::ValueRange Lut3Op::getLookupTableInputs() {
return mlir::ValueRange{getC(), getB(), getA()};
}

} // namespace cggi
} // namespace heir
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Dialect/Comb/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"@heir//include/Dialect/Comb/IR:enum_inc_gen",
"@heir//include/Dialect/Comb/IR:ops_inc_gen",
"@heir//include/Dialect/Comb/IR:type_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:FunctionInterfaces",
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Comb/IR/CombOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ LogicalResult TruthTableOp::verify() {
return success();
}

mlir::ValueRange TruthTableOp::getLookupTableInputs() {
return mlir::ValueRange{getInputs()};
}

} // namespace comb
} // namespace heir
} // namespace mlir
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/HEIRInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Block clang-format from thinking the header is unused
// IWYU pragma: begin_keep
#include "include/Dialect/HEIRInterfaces.h"
// IWYU pragma: end_keep

#include "include/Dialect/HEIRInterfaces.cpp.inc"
9 changes: 8 additions & 1 deletion lib/Transforms/YosysOptimizer/LUTImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ SmallVector<Yosys::RTLIL::SigSpec, 4> LUTImporter::getInputs(
}
inputs.push_back(conn.second);
}
return inputs;
// Alphabetical order gives LSB to MSB, but LUT operations order their inputs
// from MSB to LSB.
SmallVector<Yosys::RTLIL::SigSpec, 4> reversed;
reversed.reserve(inputs.size());
for (unsigned i = 0; i < inputs.size(); i++) {
reversed.push_back(inputs[inputs.size() - i - 1]);
}
return reversed;
}

Yosys::RTLIL::SigSpec LUTImporter::getOutput(Yosys::RTLIL::Cell *cell) const {
Expand Down
10 changes: 6 additions & 4 deletions tests/tfhe_rust/end_to_end/src/main_add_one.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use clap::Parser;
use tfhe::shortint::*;
use tfhe::shortint::parameters::get_parameters_from_message_and_carry;
use tfhe::shortint::*;

mod fn_under_test;

Expand Down Expand Up @@ -33,20 +33,22 @@ pub fn decrypt(ciphertexts: &[Ciphertext], client_key: &ClientKey) -> u8 {
let mut accum = 0u8;
for (i, ct) in ciphertexts.iter().enumerate() {
let bit = client_key.decrypt(ct);
accum |= (bit as u8) << i;
// TODO(403): Fix the ordering of the output bits in Yosys Optimizer
accum |= (bit as u8) << (7 - i);
}
accum
}

fn main() {
let flags = Args::parse();
let parameters = get_parameters_from_message_and_carry((1 << flags.message_bits) - 1, flags.carry_bits);
let parameters =
get_parameters_from_message_and_carry((1 << flags.message_bits) - 1, flags.carry_bits);
let (client_key, server_key) = tfhe::shortint::gen_keys(parameters);

let ct_1 = encrypt(flags.input1.into(), &client_key);

let result = fn_under_test::fn_under_test(&server_key, &ct_1);
let output = decrypt(&result, &client_key);

println!("{:?}", output);
println!("{:08b}", output);
}
90 changes: 31 additions & 59 deletions tests/tfhe_rust/end_to_end/test_add_one.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>

// CHECK: 3
// CHECK: 00000011
module {
func.func @fn_under_test(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
%true = arith.constant true
%false = arith.constant false
Expand All @@ -17,63 +18,34 @@
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%extracted = tensor.extract %arg0[%c0] : tensor<8x!ct_ty>
%0 = lwe.encode %true {encoding = #encoding} : i1 to !pt_ty
%1 = lwe.trivial_encrypt %0 : !pt_ty to !ct_ty
%2 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%3 = lwe.trivial_encrypt %2 : !pt_ty to !ct_ty
%4 = cggi.lut3(%extracted, %1, %3) {lookup_table = 8 : ui8} : !ct_ty
%extracted_0 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
%5 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%6 = lwe.trivial_encrypt %5 : !pt_ty to !ct_ty
%7 = cggi.lut3(%4, %extracted_0, %6) {lookup_table = 150 : ui8} : !ct_ty
%8 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%9 = lwe.trivial_encrypt %8 : !pt_ty to !ct_ty
%10 = cggi.lut3(%4, %extracted_0, %9) {lookup_table = 23 : ui8} : !ct_ty
%extracted_1 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>
%11 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%12 = lwe.trivial_encrypt %11 : !pt_ty to !ct_ty
%13 = cggi.lut3(%10, %extracted_1, %12) {lookup_table = 43 : ui8} : !ct_ty
%extracted_2 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty>
%14 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%15 = lwe.trivial_encrypt %14 : !pt_ty to !ct_ty
%16 = cggi.lut3(%13, %extracted_2, %15) {lookup_table = 43 : ui8} : !ct_ty
%extracted_3 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty>
%17 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%18 = lwe.trivial_encrypt %17 : !pt_ty to !ct_ty
%19 = cggi.lut3(%16, %extracted_3, %18) {lookup_table = 43 : ui8} : !ct_ty
%extracted_4 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty>
%20 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%21 = lwe.trivial_encrypt %20 : !pt_ty to !ct_ty
%22 = cggi.lut3(%19, %extracted_4, %21) {lookup_table = 43 : ui8} : !ct_ty
%extracted_5 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty>
%23 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%24 = lwe.trivial_encrypt %23 : !pt_ty to !ct_ty
%25 = cggi.lut3(%22, %extracted_5, %24) {lookup_table = 105 : ui8} : !ct_ty
%26 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%27 = lwe.trivial_encrypt %26 : !pt_ty to !ct_ty
%28 = cggi.lut3(%22, %extracted_5, %27) {lookup_table = 43 : ui8} : !ct_ty
%extracted_6 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty>
%29 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%30 = lwe.trivial_encrypt %29 : !pt_ty to !ct_ty
%31 = cggi.lut3(%28, %extracted_6, %30) {lookup_table = 105 : ui8} : !ct_ty
%32 = lwe.encode %true {encoding = #encoding} : i1 to !pt_ty
%33 = lwe.trivial_encrypt %32 : !pt_ty to !ct_ty
%34 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%35 = lwe.trivial_encrypt %34 : !pt_ty to !ct_ty
%36 = cggi.lut3(%extracted, %33, %35) {lookup_table = 6 : ui8} : !ct_ty
%37 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%38 = lwe.trivial_encrypt %37 : !pt_ty to !ct_ty
%39 = cggi.lut3(%10, %extracted_1, %38) {lookup_table = 105 : ui8} : !ct_ty
%40 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%41 = lwe.trivial_encrypt %40 : !pt_ty to !ct_ty
%42 = cggi.lut3(%13, %extracted_2, %41) {lookup_table = 105 : ui8} : !ct_ty
%43 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%44 = lwe.trivial_encrypt %43 : !pt_ty to !ct_ty
%45 = cggi.lut3(%16, %extracted_3, %44) {lookup_table = 105 : ui8} : !ct_ty
%46 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%47 = lwe.trivial_encrypt %46 : !pt_ty to !ct_ty
%48 = cggi.lut3(%19, %extracted_4, %47) {lookup_table = 105 : ui8} : !ct_ty
%from_elements = tensor.from_elements %31, %25, %48, %45, %42, %39, %7, %36 : tensor<8x!ct_ty>
%0 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty>
%1 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
%2 = lwe.trivial_encrypt %1 : !pt_ty to !ct_ty
%3 = lwe.encode %true {encoding = #encoding} : i1 to !pt_ty
%4 = lwe.trivial_encrypt %3 : !pt_ty to !ct_ty
%5 = cggi.lut3(%2, %4, %0) {lookup_table = 8 : ui8} : !ct_ty
%6 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
%7 = cggi.lut3(%2, %6, %5) {lookup_table = 150 : ui8} : !ct_ty
%8 = cggi.lut3(%2, %6, %5) {lookup_table = 23 : ui8} : !ct_ty
%9 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>
%10 = cggi.lut3(%2, %9, %8) {lookup_table = 43 : ui8} : !ct_ty
%11 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty>
%12 = cggi.lut3(%2, %11, %10) {lookup_table = 43 : ui8} : !ct_ty
%13 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty>
%14 = cggi.lut3(%2, %13, %12) {lookup_table = 43 : ui8} : !ct_ty
%15 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty>
%16 = cggi.lut3(%2, %15, %14) {lookup_table = 43 : ui8} : !ct_ty
%17 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty>
%18 = cggi.lut3(%2, %17, %16) {lookup_table = 105 : ui8} : !ct_ty
%19 = cggi.lut3(%2, %17, %16) {lookup_table = 43 : ui8} : !ct_ty
%20 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty>
%21 = cggi.lut3(%2, %20, %19) {lookup_table = 105 : ui8} : !ct_ty
%22 = cggi.lut3(%2, %4, %0) {lookup_table = 6 : ui8} : !ct_ty
%23 = cggi.lut3(%2, %9, %8) {lookup_table = 105 : ui8} : !ct_ty
%24 = cggi.lut3(%2, %11, %10) {lookup_table = 105 : ui8} : !ct_ty
%25 = cggi.lut3(%2, %13, %12) {lookup_table = 105 : ui8} : !ct_ty
%26 = cggi.lut3(%2, %15, %14) {lookup_table = 105 : ui8} : !ct_ty
%from_elements = tensor.from_elements %21, %18, %26, %25, %24, %23, %7, %22 : tensor<8x!ct_ty>
return %from_elements : tensor<8x!ct_ty>
}
}
Loading

0 comments on commit 0df1895

Please sign in to comment.