@@ -4076,7 +4076,18 @@ class ExtractElemOpConversion
40764076 // create truncation op (and bitcast op)
40774077 if (llvm::isa<IntegerType>(resultType)) {
40784078 if (resultBitWidth < 32 ) {
4079- rewriter.replaceOpWithNewOp <LLVM::TruncOp>(op, resultType, extElemOp);
4079+ // Two-step truncation to avoid direct i32→i8 which the AIE2
4080+ // backend cannot legalize after SLP vectorization.
4081+ if (resultBitWidth < 16 ) {
4082+ auto i16Ty = rewriter.getI16Type ();
4083+ auto trunc16 =
4084+ LLVM::TruncOp::create (rewriter, loc, i16Ty, extElemOp);
4085+ rewriter.replaceOpWithNewOp <LLVM::TruncOp>(op, resultType,
4086+ trunc16.getResult ());
4087+ } else {
4088+ rewriter.replaceOpWithNewOp <LLVM::TruncOp>(op, resultType,
4089+ extElemOp);
4090+ }
40804091 } else {
40814092 rewriter.replaceOp (op, extElemOp);
40824093 }
@@ -4207,34 +4218,48 @@ class MatMulOpConversion
42074218 /* sub_mul=*/ 0 , /* sub_acc1=*/ 0 , /* sub_acc2=*/ 0 ,
42084219 /* sub_mask=*/ 0 )};
42094220
4221+ // Helper: return the source value behind any chain of vector.shape_cast ops.
4222+ // The VectorToAIEVec pass inserts shape_casts (via reshapeLeadingUnitDims)
4223+ // between the extension ops and the matmul, which hides the signedness.
4224+ auto lookThroughShapeCasts = [](Value v) -> Value {
4225+ while (auto castOp = v.getDefiningOp <vector::ShapeCastOp>())
4226+ v = castOp.getSource ();
4227+ return v;
4228+ };
4229+
42104230 int signX = 0 , signY = 0 ;
42114231 auto lhsVecTy = cast<VectorType>(lhs.getType ());
42124232 auto lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType ());
4213- if (auto extSIOp = lhs.getDefiningOp <arith::ExtSIOp>()) {
4214- lhs = extSIOp.getIn ();
4233+ Value lhsOrig = lookThroughShapeCasts (lhs);
4234+ if (auto extSIOp = lhsOrig.getDefiningOp <arith::ExtSIOp>()) {
4235+ lhs = lookThroughShapeCasts (extSIOp.getIn ());
42154236 lhsVecTy = cast<VectorType>(lhs.getType ());
42164237 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType ());
42174238 signX = 1 ;
4218- } else if (auto extUIOp = lhs .getDefiningOp <arith::ExtUIOp>()) {
4219- lhs = extUIOp.getIn ();
4239+ } else if (auto extUIOp = lhsOrig .getDefiningOp <arith::ExtUIOp>()) {
4240+ lhs = lookThroughShapeCasts ( extUIOp.getIn () );
42204241 lhsVecTy = cast<VectorType>(lhs.getType ());
42214242 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType ());
42224243 } else {
4223- // NOTE: We're choosing 'signed' by default
4224- if (!lhsScaTy.isUnsigned ())
4225- signX = 1 ;
4244+ // Default to unsigned for lhs (activation input is typically uint8, the
4245+ // common use case being uint8 activations × int8 weights). The
4246+ // VectorToAIEVec pass strips extsi/extui before creating aievec.matmul, so
4247+ // sign info is not available here. NOTE(review): signX is already 0 on this path, so the check below is a no-op.
4248+ if (lhsScaTy.isUnsigned ())
4249+ signX = 0 ;
42264250 }
42274251 auto lhsShape = lhsVecTy.getShape ();
42284252
42294253 auto rhsVecTy = cast<VectorType>(rhs.getType ());
42304254 auto rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType ());
4231- if (auto extSIOp = rhs.getDefiningOp <arith::ExtSIOp>()) {
4232- rhs = extSIOp.getIn ();
4255+ Value rhsOrig = lookThroughShapeCasts (rhs);
4256+ if (auto extSIOp = rhsOrig.getDefiningOp <arith::ExtSIOp>()) {
4257+ rhs = lookThroughShapeCasts (extSIOp.getIn ());
42334258 rhsVecTy = cast<VectorType>(rhs.getType ());
42344259 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType ());
42354260 signY = 1 ;
4236- } else if (auto extUIOp = rhs .getDefiningOp <arith::ExtUIOp>()) {
4237- rhs = extUIOp.getIn ();
4261+ } else if (auto extUIOp = rhsOrig .getDefiningOp <arith::ExtUIOp>()) {
4262+ rhs = lookThroughShapeCasts ( extUIOp.getIn () );
42384263 rhsVecTy = cast<VectorType>(rhs.getType ());
42394264 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType ());
42404265 } else {
0 commit comments