@@ -1501,9 +1501,19 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
15011501 getTypeConverter ()->convertType (op.getResult ().getType ()));
15021502 auto outputETy = resultTy.getElementType ();
15031503
1504+ auto ensureF32Input = [&](Value input) -> Value {
1505+ auto inputTy = cast<RankedTensorType>(input.getType ());
1506+ if (inputTy.getElementType ().isF32 ())
1507+ return input;
1508+ auto castTy =
1509+ RankedTensorType::get (inputTy.getShape (), rewriter.getF32Type ());
1510+ return tosa::CastOp::create (rewriter, op->getLoc (), castTy, input);
1511+ };
1512+
15041513 // Create a single instance of tosa.argmax.
15051514 // Multiple dims require chained construct.
15061515 auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
1516+ input = ensureF32Input (input);
15071517 auto inputTy = cast<RankedTensorType>(input.getType ());
15081518 auto inputShape = makeShapeTorchCompatible (inputTy.getShape ());
15091519 SmallVector<int64_t > outputShapeArr = {};
@@ -1523,7 +1533,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
15231533 makeShapeLLVMCompatible (ArrayRef<int64_t >(outputShapeArr)),
15241534 rewriter.getI32Type ());
15251535 auto reduceDimAttr =
1526- rewriter.getIntegerAttr (rewriter.getI64Type (), reduceDim);
1536+ rewriter.getIntegerAttr (rewriter.getI32Type (), reduceDim);
15271537
15281538 // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
15291539 return tosa::ArgMaxOp::create (
@@ -4692,26 +4702,37 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
46924702
46934703 // To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
46944704 // of the input tensor, which will return indices of input's min values
4705+ auto ensureF32Input = [&](Value input) -> Value {
4706+ auto inputTy = cast<RankedTensorType>(input.getType ());
4707+ if (inputTy.getElementType ().isF32 ())
4708+ return input;
4709+ auto castTy =
4710+ RankedTensorType::get (inputTy.getShape (), rewriter.getF32Type ());
4711+ return tosa::CastOp::create (rewriter, op->getLoc (), castTy, input);
4712+ };
4713+
46954714 Value argMaxOp;
46964715 if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
46974716 Value negateOp =
46984717 tosa::NegateOp::create (rewriter, op->getLoc (), selfType, self);
4718+ Value argInput = ensureF32Input (negateOp);
46994719
47004720 // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
47014721 argMaxOp = tosa::ArgMaxOp::create (
47024722 rewriter, op->getLoc (),
47034723 RankedTensorType::get (makeShapeLLVMCompatible (prunedShape),
47044724 indicesElemType),
4705- negateOp , dimAttr, /* nan_mode=*/
4725+ argInput , dimAttr, /* nan_mode=*/
47064726 tosa::NanPropagationModeAttr::get (
47074727 rewriter.getContext (), tosa::NanPropagationMode::PROPAGATE));
47084728 } else {
4729+ Value argInput = ensureF32Input (self);
47094730 // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
47104731 argMaxOp = tosa::ArgMaxOp::create (
47114732 rewriter, op->getLoc (),
47124733 RankedTensorType::get (makeShapeLLVMCompatible (prunedShape),
47134734 indicesElemType),
4714- self , dimAttr, /* nan_mode=*/
4735+ argInput , dimAttr, /* nan_mode=*/
47154736 tosa::NanPropagationModeAttr::get (
47164737 rewriter.getContext (), tosa::NanPropagationMode::PROPAGATE));
47174738 }
0 commit comments