@@ -308,7 +308,12 @@ LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) {
308
308
309
309
os << " FheUint" << getTfheRustBitWidth (op.getResult ().getType ())
310
310
<< " ::try_encrypt_trivial("
311
- << variableNames->getNameForValue (op.getValue ()) << " ).unwrap();\n " ;
311
+ << variableNames->getNameForValue (op.getValue ());
312
+
313
+ if (op.getValue ().getType ().isSigned ())
314
+ os << " as u" << getTfheRustBitWidth (op.getResult ().getType ());
315
+
316
+ os << " ).unwrap();\n " ;
312
317
return success ();
313
318
}
314
319
@@ -359,7 +364,7 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) {
359
364
// By default, it emits an unsigned integer.
360
365
emitAssignPrefix (op.getResult ());
361
366
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
362
- os << intAttr.getValue ().abs () << " u64 ;\n " ;
367
+ os << intAttr.getValue ().abs () << convertType (op. getType ()) << " ;\n " ;
363
368
} else {
364
369
return op.emitError () << " Unknown constant type " << valueAttr.getType ();
365
370
}
@@ -383,6 +388,17 @@ LogicalResult TfheRustHLEmitter::printBinaryOp(::mlir::Value result,
383
388
std::string_view op) {
384
389
emitAssignPrefix (result);
385
390
391
+ if (auto cteOp = dyn_cast<mlir::arith::ConstantOp>(rhs.getDefiningOp ())) {
392
+ auto intValue =
393
+ cast<IntegerAttr>(cteOp.getValue ()).getValue ().getZExtValue ();
394
+ os << checkOrigin (lhs) << variableNames->getNameForValue (lhs) << " " << op
395
+ << " " << intValue << " u" << cteOp.getType ().getIntOrFloatBitWidth ()
396
+ << " ;\n " ;
397
+ return success ();
398
+ }
399
+
400
+ // Note: arith.constant op requires signless integer types, but here we
401
+ // manually emit an unsigned integer type.
386
402
os << checkOrigin (lhs) << variableNames->getNameForValue (lhs) << " " << op
387
403
<< " " << checkOrigin (rhs) << variableNames->getNameForValue (rhs) << " ;\n " ;
388
404
return success ();
@@ -430,8 +446,8 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) {
430
446
if (failed (emitType (op.getMemref ().getType ().getElementType ()))) {
431
447
return op.emitOpError () << " Failed to get memref element type" ;
432
448
}
433
-
434
449
os << " > = BTreeMap::new();\n " ;
450
+
435
451
return success ();
436
452
}
437
453
@@ -463,12 +479,11 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::LoadOp op) {
463
479
// We assume here that the indices are SSA values (not integer attributes).
464
480
if (isa<BlockArgument>(op.getMemref ())) {
465
481
emitAssignPrefix (op.getResult ());
466
- os << " &" << variableNames->getNameForValue (op.getMemRef ()) << " ["
467
- << flattenIndexExpression (op.getMemRefType (), op.getIndices (),
468
- [&](Value value) {
469
- return variableNames->getNameForValue (value);
470
- })
471
- << " ];\n " ;
482
+ os << " &" << variableNames->getNameForValue (op.getMemRef ());
483
+ for (auto value : op.getIndices ()) {
484
+ os << " [" << variableNames->getNameForValue (value) << " ]" ;
485
+ }
486
+ os << " ;\n " ;
472
487
return success ();
473
488
}
474
489
@@ -586,6 +601,7 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) {
586
601
return std::string (prefix) + variableNames->getNameForValue (value) +
587
602
cloneStr;
588
603
}) << " ];\n " ;
604
+
589
605
return success ();
590
606
}
591
607
@@ -662,9 +678,12 @@ FailureOr<std::string> TfheRustHLEmitter::convertType(Type type) {
662
678
}
663
679
auto width = getRustIntegerType (type.getWidth ());
664
680
if (failed (width)) return failure ();
665
- return (type.isUnsigned () ? std::string (" u " ) : " " ) + " i " +
681
+ return (type.isSigned () ? std::string (" i " ) : std::string ( " u " )) +
666
682
std::to_string (width.value ());
667
683
})
684
+ .Case <IndexType>([&](IndexType type) -> FailureOr<std::string> {
685
+ return std::string (" usize" );
686
+ })
668
687
.Case <LookupTableType>(
669
688
[&](auto type) { return std::string (" LookupTableOwned" ); })
670
689
.Default ([&](Type &) { return failure (); });
0 commit comments