Skip to content

Commit 8f22058

Browse files
erwei-xilinxclaude
andauthored
[aievec] Add scalar compound shift+clamp+truncate to SRS pattern (#2894)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2e7c91b commit 8f22058

9 files changed

Lines changed: 524 additions & 46 deletions

File tree

lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4076,7 +4076,18 @@ class ExtractElemOpConversion
40764076
// create truncation op (and bitcast op)
40774077
if (llvm::isa<IntegerType>(resultType)) {
40784078
if (resultBitWidth < 32) {
4079-
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
4079+
// Two-step truncation to avoid direct i32→i8 which the AIE2
4080+
// backend cannot legalize after SLP vectorization.
4081+
if (resultBitWidth < 16) {
4082+
auto i16Ty = rewriter.getI16Type();
4083+
auto trunc16 =
4084+
LLVM::TruncOp::create(rewriter, loc, i16Ty, extElemOp);
4085+
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType,
4086+
trunc16.getResult());
4087+
} else {
4088+
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType,
4089+
extElemOp);
4090+
}
40804091
} else {
40814092
rewriter.replaceOp(op, extElemOp);
40824093
}
@@ -4207,34 +4218,48 @@ class MatMulOpConversion
42074218
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
42084219
/*sub_mask=*/0)};
42094220

4221+
// Helper: look through vector.shape_cast ops to find the defining op.
4222+
// The VectorToAIEVec pass inserts shape_casts (via reshapeLeadingUnitDims)
4223+
// between the extension ops and the matmul, which hides the signedness.
4224+
auto lookThroughShapeCasts = [](Value v) -> Value {
4225+
while (auto castOp = v.getDefiningOp<vector::ShapeCastOp>())
4226+
v = castOp.getSource();
4227+
return v;
4228+
};
4229+
42104230
int signX = 0, signY = 0;
42114231
auto lhsVecTy = cast<VectorType>(lhs.getType());
42124232
auto lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
4213-
if (auto extSIOp = lhs.getDefiningOp<arith::ExtSIOp>()) {
4214-
lhs = extSIOp.getIn();
4233+
Value lhsOrig = lookThroughShapeCasts(lhs);
4234+
if (auto extSIOp = lhsOrig.getDefiningOp<arith::ExtSIOp>()) {
4235+
lhs = lookThroughShapeCasts(extSIOp.getIn());
42154236
lhsVecTy = cast<VectorType>(lhs.getType());
42164237
lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
42174238
signX = 1;
4218-
} else if (auto extUIOp = lhs.getDefiningOp<arith::ExtUIOp>()) {
4219-
lhs = extUIOp.getIn();
4239+
} else if (auto extUIOp = lhsOrig.getDefiningOp<arith::ExtUIOp>()) {
4240+
lhs = lookThroughShapeCasts(extUIOp.getIn());
42204241
lhsVecTy = cast<VectorType>(lhs.getType());
42214242
lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
42224243
} else {
4223-
// NOTE: We're choosing 'signed' by default
4224-
if (!lhsScaTy.isUnsigned())
4225-
signX = 1;
4244+
// Default to unsigned for lhs (activation input is typically uint8).
4245+
// The VectorToAIEVec pass strips extsi/extui before creating
4246+
// aievec.matmul, so sign info is not available here. Using unsigned
4247+
// for A matches the common use case of uint8 activations × int8 weights.
4248+
if (lhsScaTy.isUnsigned())
4249+
signX = 0;
42264250
}
42274251
auto lhsShape = lhsVecTy.getShape();
42284252

42294253
auto rhsVecTy = cast<VectorType>(rhs.getType());
42304254
auto rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
4231-
if (auto extSIOp = rhs.getDefiningOp<arith::ExtSIOp>()) {
4232-
rhs = extSIOp.getIn();
4255+
Value rhsOrig = lookThroughShapeCasts(rhs);
4256+
if (auto extSIOp = rhsOrig.getDefiningOp<arith::ExtSIOp>()) {
4257+
rhs = lookThroughShapeCasts(extSIOp.getIn());
42334258
rhsVecTy = cast<VectorType>(rhs.getType());
42344259
rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
42354260
signY = 1;
4236-
} else if (auto extUIOp = rhs.getDefiningOp<arith::ExtUIOp>()) {
4237-
rhs = extUIOp.getIn();
4261+
} else if (auto extUIOp = rhsOrig.getDefiningOp<arith::ExtUIOp>()) {
4262+
rhs = lookThroughShapeCasts(extUIOp.getIn());
42384263
rhsVecTy = cast<VectorType>(rhs.getType());
42394264
rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
42404265
} else {

lib/Dialect/AIE/Transforms/AIECoreToStandard.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ static auto getAIE2Intrinsics(OpBuilder &builder) {
109109
{"llvm.aie2.release",
110110
{int32Type, int32Type},
111111
{}}, //(%lock_id, %lock_val) -> ()
112+
{"llvm.aie2.set.ctrl.reg",
113+
{int32Type, int32Type},
114+
{}}, //(%reg_id, %value) -> ()
112115
};
113116
return functions;
114117
}
@@ -137,6 +140,9 @@ static auto getAIE2pIntrinsics(OpBuilder &builder) {
137140
{"llvm.aie2p.release",
138141
{int32Type, int32Type},
139142
{}}, //(%lock_id, %lock_val) -> ()
143+
{"llvm.aie2p.set.ctrl.reg",
144+
{int32Type, int32Type},
145+
{}}, //(%reg_id, %value) -> ()
140146
};
141147
return functions;
142148
}
@@ -542,6 +548,50 @@ struct AIECoreToStandardFunc : OpConversionPattern<CoreOp> {
542548
rewriter.cloneRegionBefore(op.getBody(), coreFunc.getBody(),
543549
coreFunc.getBody().begin(), mapper);
544550

551+
// Set saturation and rounding modes at core entry for AIE2/AIE2p, but
552+
// only if the core body contains aievec.srs ops (narrowing SRS needs
553+
// saturation mode enabled). Skip for cores with only lock/stream ops
554+
// to avoid breaking existing test SSA naming.
555+
bool hasSRS = false;
556+
coreFunc.walk([&](Operation *childOp) {
557+
if (childOp->getName().getStringRef() == "aievec.srs")
558+
hasSRS = true;
559+
});
560+
if (hasSRS) {
561+
auto device = op->getParentOfType<DeviceOp>();
562+
if (device) {
563+
AIEArch arch = device.getTargetModel().getTargetArch();
564+
if (arch == AIEArch::AIE2 || arch == AIEArch::AIE2p) {
565+
std::string ctrlRegFuncName = (arch == AIEArch::AIE2p)
566+
? "llvm.aie2p.set.ctrl.reg"
567+
: "llvm.aie2.set.ctrl.reg";
568+
auto ctrlRegFunc = module.lookupSymbol<func::FuncOp>(ctrlRegFuncName);
569+
if (ctrlRegFunc) {
570+
Block &entryBlock = coreFunc.getBody().front();
571+
rewriter.setInsertionPointToStart(&entryBlock);
572+
Location loc = op.getLoc();
573+
// saturation_mode::saturate (register 9 = 1)
574+
auto c9 = arith::ConstantOp::create(rewriter, loc,
575+
rewriter.getI32IntegerAttr(9));
576+
auto c1 = arith::ConstantOp::create(rewriter, loc,
577+
rewriter.getI32IntegerAttr(1));
578+
func::CallOp::create(rewriter, loc, ctrlRegFunc,
579+
ValueRange{c9, c1});
580+
// rounding_mode::floor (register 6 = 0)
581+
// Use floor (truncation) to avoid double-rounding when user
582+
// code already performs explicit rounding via arith.addi
583+
// before shrsi.
584+
auto c6 = arith::ConstantOp::create(rewriter, loc,
585+
rewriter.getI32IntegerAttr(6));
586+
auto c0 = arith::ConstantOp::create(rewriter, loc,
587+
rewriter.getI32IntegerAttr(0));
588+
func::CallOp::create(rewriter, loc, ctrlRegFunc,
589+
ValueRange{c6, c0});
590+
}
591+
}
592+
}
593+
}
594+
545595
// Rewrite the AIE.end() op
546596
coreFunc.getBody().walk([&](Operation *childOp) {
547597
rewriter.setInsertionPointAfter(childOp);

0 commit comments

Comments
 (0)