@@ -443,6 +443,16 @@ Value castSignedIntValueToType(PatternRewriter &rewriter, Location loc, Value v,
443443 return v;
444444}
445445
// Converts an integer scalar into a value usable wherever `targetTy` is
// expected (helper newly added by this diff).
//
//   rewriter / loc : builder and source location used for any ops created.
//   scalar         : the scalar integer value to convert.
//   targetTy       : destination type; either an integer type or a shaped
//                    type whose element type is an integer.
//
// Returns `scalar` converted to the element type of `targetTy`, and splatted
// to `targetTy` when that type is a ShapedType.
446+ Value castScalarIntToIntLike (PatternRewriter &rewriter, Location loc,
447+ Value scalar, Type targetTy) {
// cast<> (rather than dyn_cast<>) is used, so the element type of `targetTy`
// is required to be an IntegerType here.
448+ auto elemTy = cast<IntegerType>(getElementType (targetTy));
// Only emit a conversion when the scalar is not already of the element type.
// NOTE(review): castSignedIntValueToType is defined above this hunk;
// presumably it treats the value as signed when widening -- confirm.
449+ if (scalar.getType () != elemTy)
450+ scalar = castSignedIntValueToType (rewriter, loc, scalar, elemTy);
// Shaped target: broadcast the scalar to every element via tt::SplatOp.
451+ if (isa<ShapedType>(targetTy))
452+ return tt::SplatOp::create (rewriter, loc, targetTy, scalar);
// Non-shaped target: the (possibly converted) scalar is the result.
453+ return scalar;
454+ }
455+
446456Value selectUIntConstantOnSign (PatternRewriter &rewriter, Location loc,
447457 Value signSource, uint64_t signMaskValue,
448458 uint64_t nonNegativeValue,
@@ -674,45 +684,60 @@ Value fpsanSRem(PatternRewriter &rewriter, Location loc, Value num, Value den) {
674684
675685// Modular exponentiation in payload space; this preserves
676686// exp2(a + b) = exp2(a) * exp2(b) under the integer rewrite.
// This hunk renames fpsanExp2FromI32 to fpsanExp2FromInt. Lines marked `-`
// are the removed implementation, which emitted 32 copies of the loop body at
// rewrite time; lines marked `+` are the replacement, which emits a single
// scf.for whose trip count is the bit width of xI's integer type.
//
//   rewriter / loc : builder and source location used for the emitted ops.
//   xI             : integer(-like) exponent value.
//   floatTy        : type passed to unembedToFloat for the final result.
//
// Both versions build an MSB-first square-and-multiply chain: the accumulator
// starts at 1, is squared every step, and is then multiplied by the constant
// `c` when the current bit of xI is set (by 1 otherwise).
677- Value fpsanExp2FromI32 (PatternRewriter &rewriter, Location loc, Value xI,
687+ Value fpsanExp2FromInt (PatternRewriter &rewriter, Location loc, Value xI,
678688 Type floatTy) {
// Number of loop iterations in the new version (one per bit of xI's type).
689+ unsigned bitWidth = getIntBitwidth (xI.getType ());
// Constants built in xI's own type so they can be combined with xI directly.
679690 auto one = getIntConstantLike (rewriter, loc, xI.getType (), 1 );
680691 auto zero = getIntConstantLike (rewriter, loc, xI.getType (), 0 );
// `c` is the base that gets raised to the power xI.
// NOTE(review): 0xa343836d is a 32-bit literal; now that the width follows
// xI's type, confirm getIntConstantLike does the intended thing for integer
// types that are not 32 bits wide.
681692 auto c = getIntConstantLike (rewriter, loc, xI.getType (), 0xa343836d );
682693
// ---- removed: loop over 32 bits, unrolled while rewriting ----
683- Value y = one;
684- for (int i = 0 ; i < 32 ; ++i) {
685- y = arith::MulIOp::create (rewriter, loc, y, y);
686- auto bit = getIntConstantLike (rewriter, loc, xI.getType (),
687- int64_t (1ull << (31 - i)));
688- auto masked = arith::AndIOp::create (rewriter, loc, xI, bit);
689- auto isZero = arith::CmpIOp::create (rewriter, loc, arith::CmpIPredicate::eq,
690- masked, zero);
691- auto factor = arith::SelectOp::create (rewriter, loc, isZero, one, c);
692- y = arith::MulIOp::create (rewriter, loc, y, factor);
693- }
694-
695- return unembedToFloat (rewriter, loc, y, floatTy);
// ---- added: the same recurrence expressed as an scf.for ----
// Loop bounds are scalar i32 constants: i runs over [0, bitWidth) in steps of 1.
694+ auto lower =
695+ arith::ConstantOp::create (rewriter, loc, rewriter.getI32IntegerAttr (0 ));
696+ auto upper = arith::ConstantOp::create (rewriter, loc,
697+ rewriter.getI32IntegerAttr (bitWidth));
698+ auto step =
699+ arith::ConstantOp::create (rewriter, loc, rewriter.getI32IntegerAttr (1 ));
// Index of the most significant bit; used to walk the bits from high to low.
700+ auto topBit = arith::ConstantOp::create (
701+ rewriter, loc, rewriter.getI32IntegerAttr (bitWidth - 1 ));
// The accumulator is carried as the loop's single iter_arg, initialised to 1.
702+ auto loop = scf::ForOp::create (rewriter, loc, lower, upper, step, one);
// Everything created from here until the insertion point is moved back goes
// into the loop body.
703+ rewriter.setInsertionPointToStart (loop.getBody ());
704+
705+ Value i = loop.getInductionVar ();
706+ Value y = loop.getRegionIterArgs ()[0 ];
// Square step.
707+ y = arith::MulIOp::create (rewriter, loc, y, y);
// bitIndex = (bitWidth - 1) - i, i.e. the bit examined in this iteration.
708+ Value bitIndex =
709+ arith::SubIOp::create (rewriter, loc, rewriter.getI32Type (), topBit, i);
// Bring the i32 index into xI's type (splatting if needed) so it can be used
// as the shift amount, then form the single-bit mask 1 << bitIndex.
710+ Value shift = castScalarIntToIntLike (rewriter, loc, bitIndex, xI.getType ());
711+ Value bit = arith::ShLIOp::create (rewriter, loc, one, shift);
// Multiply step: the factor is 1 when the bit is clear and c when it is set.
712+ auto masked = arith::AndIOp::create (rewriter, loc, xI, bit);
713+ auto isZero = arith::CmpIOp::create (rewriter, loc, arith::CmpIPredicate::eq,
714+ masked, zero);
715+ auto factor = arith::SelectOp::create (rewriter, loc, isZero, one, c);
716+ y = arith::MulIOp::create (rewriter, loc, y, factor);
// Hand the updated accumulator to the next iteration / the loop result.
717+ scf::YieldOp::create (rewriter, loc, y);
// Leave the loop body so that later ops are emitted after the loop.
718+ rewriter.setInsertionPointAfter (loop);
719+
// The loop's only result is the final accumulator; convert it to floatTy.
720+ return unembedToFloat (rewriter, loc, loop.getResult (0 ), floatTy);
696721}
697722
// Builds the sanitizer replacement for exp2 of `input`.
//
// Returns a null Value when the element type of `input` is not a FloatType;
// otherwise returns the result of fpsanExp2FromInt applied to the integer
// embedding of `input`, typed like `input`.
//
// Diff note: the removed (`-`) lines additionally rejected float types whose
// width was not 32 and called the old fpsanExp2FromI32; the added (`+`) lines
// accept any FloatType and call fpsanExp2FromInt.
698723Value fpsanExp2 (PatternRewriter &rewriter, Location loc, Value input) {
699724 auto elemTy = dyn_cast<FloatType>(getElementType (input.getType ()));
700- if (!elemTy || elemTy. getWidth () != 32 )
725+ if (!elemTy)
// Null Value signals "not handled" to the caller.
// NOTE(review): how callers react to the null Value is not visible here.
701726 return Value ();
// embedToInt yields the integer payload that the exponentiation operates on.
702- return fpsanExp2FromI32 (rewriter, loc, embedToInt (rewriter, loc, input),
727+ return fpsanExp2FromInt (rewriter, loc, embedToInt (rewriter, loc, input),
703728 input.getType ());
704729}
705730
// Builds the sanitizer replacement for exp of `input` by scaling the integer
// payload and then reusing the exp2 construction.
//
// Returns a null Value when the element type of `input` is not a FloatType;
// otherwise returns fpsanExp2FromInt(payload * rcpLog2), typed like `input`.
//
// Diff note: the removed (`-`) lines additionally required a 32-bit float
// element type and called the old fpsanExp2FromI32.
706731Value fpsanExp (PatternRewriter &rewriter, Location loc, Value input) {
707732 auto elemTy = dyn_cast<FloatType>(getElementType (input.getType ()));
708- if (!elemTy || elemTy. getWidth () != 32 )
733+ if (!elemTy)
709734 return Value ();
710735
711736 auto inputI = embedToInt (rewriter, loc, input);
// Presumably the payload-space stand-in for 1/ln(2), so that
// exp(x) = exp2(x * rcpLog2) -- inferred from the name; confirm.
// NOTE(review): the helper is named getU32ConstantLike and the literal is
// 32 bits wide; with the 32-bit guard removed, confirm the behaviour when
// inputI is not a 32-bit integer type.
712737 auto rcpLog2 =
713738 getU32ConstantLike (rewriter, loc, inputI.getType (), 0x236ee9bfu );
// Integer multiply of the payload by the scaling constant.
714739 auto scaledI = arith::MulIOp::create (rewriter, loc, inputI, rcpLog2);
715- return fpsanExp2FromI32 (rewriter, loc, scaledI, input.getType ());
740+ return fpsanExp2FromInt (rewriter, loc, scaledI, input.getType ());
716741}
717742
718743struct FpSanCosSin {
@@ -735,32 +760,47 @@ FpSanCosSin fpsanCosSinPayload(PatternRewriter &rewriter, Location loc,
735760 auto a = getUIntConstantLike (rewriter, loc, intTy, aValue);
736761 auto b = getUIntConstantLike (rewriter, loc, intTy, bValue);
737762
738- Value c = one;
739- Value s = zero;
740- for (int bit = static_cast <int >(bitWidth) - 1 ; bit >= 0 ; --bit) {
741- Value cc = arith::MulIOp::create (rewriter, loc, c, c);
742- Value ss = arith::MulIOp::create (rewriter, loc, s, s);
743- Value cDouble = arith::SubIOp::create (rewriter, loc, cc, ss);
744- Value cs = arith::MulIOp::create (rewriter, loc, c, s);
745- Value sDouble = arith::MulIOp::create (rewriter, loc, two, cs);
746-
747- Value ac = arith::MulIOp::create (rewriter, loc, a, cDouble);
748- Value bs = arith::MulIOp::create (rewriter, loc, b, sDouble );
749- Value cInc = arith::SubIOp::create (rewriter, loc, ac, bs);
750- Value as = arith::MulIOp::create (rewriter, loc, a, sDouble );
751- Value bc = arith::MulIOp::create (rewriter, loc, b, cDouble);
752- Value sInc = arith::AddIOp::create (rewriter, loc, as, bc);
753-
754- auto bitMask =
755- getUIntConstantLike (rewriter, loc, intTy, uint64_t {1 } << bit);
756- auto masked = arith::AndIOp::create (rewriter, loc, xI, bitMask);
757- auto isZero = arith::CmpIOp::create (rewriter, loc, arith::CmpIPredicate::eq,
758- masked, zero);
759- c = arith::SelectOp::create (rewriter, loc, isZero, cDouble, cInc);
760- s = arith::SelectOp::create (rewriter, loc, isZero, sDouble , sInc );
761- }
762-
763- return {c, s};
763+ auto lower =
764+ arith::ConstantOp::create (rewriter, loc, rewriter.getI32IntegerAttr (0 ));
765+ auto upper = arith::ConstantOp::create (rewriter, loc,
766+ rewriter.getI32IntegerAttr (bitWidth));
767+ auto step =
768+ arith::ConstantOp::create (rewriter, loc, rewriter.getI32IntegerAttr (1 ));
769+ auto topBit = arith::ConstantOp::create (
770+ rewriter, loc, rewriter.getI32IntegerAttr (bitWidth - 1 ));
771+ SmallVector<Value> initArgs{one, zero};
772+ auto loop = scf::ForOp::create (rewriter, loc, lower, upper, step, initArgs);
773+ rewriter.setInsertionPointToStart (loop.getBody ());
774+
775+ Value bit = loop.getInductionVar ();
776+ Value c = loop.getRegionIterArgs ()[0 ];
777+ Value s = loop.getRegionIterArgs ()[1 ];
778+ Value cc = arith::MulIOp::create (rewriter, loc, c, c);
779+ Value ss = arith::MulIOp::create (rewriter, loc, s, s);
780+ Value cDouble = arith::SubIOp::create (rewriter, loc, cc, ss);
781+ Value cs = arith::MulIOp::create (rewriter, loc, c, s);
782+ Value sDouble = arith::MulIOp::create (rewriter, loc, two, cs);
783+
784+ Value ac = arith::MulIOp::create (rewriter, loc, a, cDouble);
785+ Value bs = arith::MulIOp::create (rewriter, loc, b, sDouble );
786+ Value cInc = arith::SubIOp::create (rewriter, loc, ac, bs);
787+ Value as = arith::MulIOp::create (rewriter, loc, a, sDouble );
788+ Value bc = arith::MulIOp::create (rewriter, loc, b, cDouble);
789+ Value sInc = arith::AddIOp::create (rewriter, loc, as, bc);
790+
791+ Value bitIndex =
792+ arith::SubIOp::create (rewriter, loc, rewriter.getI32Type (), topBit, bit);
793+ Value shift = castScalarIntToIntLike (rewriter, loc, bitIndex, intTy);
794+ Value bitMask = arith::ShLIOp::create (rewriter, loc, one, shift);
795+ auto masked = arith::AndIOp::create (rewriter, loc, xI, bitMask);
796+ auto isZero = arith::CmpIOp::create (rewriter, loc, arith::CmpIPredicate::eq,
797+ masked, zero);
798+ c = arith::SelectOp::create (rewriter, loc, isZero, cDouble, cInc);
799+ s = arith::SelectOp::create (rewriter, loc, isZero, sDouble , sInc );
800+ scf::YieldOp::create (rewriter, loc, ValueRange{c, s});
801+ rewriter.setInsertionPointAfter (loop);
802+
803+ return {loop.getResult (0 ), loop.getResult (1 )};
764804}
765805
766806Value fpsanCos (PatternRewriter &rewriter, Location loc, Value input) {
0 commit comments