@@ -2818,10 +2818,15 @@ struct PadConverter : public ConvertHloOpToTensorRTPattern<stablehlo::PadOp> {
2818
2818
auto padLowHighSum = trtRewriter.checkAndCreate <tensorrt::ElementWiseOp>(
2819
2819
loc, targetTrtMajorVersion, shapeTensorType, padLowConst, padHighConst,
2820
2820
tensorrt::ElementWiseOperation::kSUM );
2821
+ if (!padLowHighSum)
2822
+ return failure ();
2821
2823
Value size = padLowHighSum.getResult ();
2822
- size = trtRewriter.checkAndCreate <tensorrt::ElementWiseOp>(
2824
+ auto sumWithResult = trtRewriter.checkAndCreate <tensorrt::ElementWiseOp>(
2823
2825
loc, targetTrtMajorVersion, shapeTensorType, size, shape.getResult (),
2824
2826
tensorrt::ElementWiseOperation::kSUM );
2827
+ if (!sumWithResult)
2828
+ return failure ();
2829
+ size = sumWithResult.getResult ();
2825
2830
2826
2831
SmallVector<int32_t > stride (inputType.getRank (), 1 );
2827
2832
return trtRewriter.checkAndReplaceOpWithNewOp <tensorrt::SliceOp>(
@@ -3858,7 +3863,7 @@ struct ConvertScatterToTensorRTScatterElements
3858
3863
if (!constOneTuple)
3859
3864
return failure ();
3860
3865
3861
- Value newIndices = trtRewriter.checkAndCreate <tensorrt::LinspaceOp>(
3866
+ auto newIndices = trtRewriter.checkAndCreate <tensorrt::LinspaceOp>(
3862
3867
op->getLoc (), targetTrtMajorVersion,
3863
3868
newUpdateType.clone (rewriter.getI32Type ()), Value (), startIndex,
3864
3869
constOneTuple, FloatAttr (), FloatAttr ());
@@ -3884,7 +3889,7 @@ struct ConvertScatterToTensorRTScatterElements
3884
3889
auto newOp = trtRewriter.checkAndCreate <tensorrt::ScatterElementsOp>(
3885
3890
op->getLoc (), targetTrtMajorVersion,
3886
3891
/* data*/ convertToI32 (adaptor.getInputs ().front ()),
3887
- /* indices*/ newIndices,
3892
+ /* indices*/ newIndices. getResult () ,
3888
3893
/* updates*/ convertToI32 (newUpdates),
3889
3894
/* axis*/ rewriter.getI64IntegerAttr (axis));
3890
3895
if (!newOp)
@@ -3894,7 +3899,8 @@ struct ConvertScatterToTensorRTScatterElements
3894
3899
auto newOp = trtRewriter.checkAndCreate <tensorrt::ScatterElementsOp>(
3895
3900
op->getLoc (), targetTrtMajorVersion,
3896
3901
/* data*/ adaptor.getInputs ().front (),
3897
- /* indices*/ newIndices, /* updates*/ newUpdates.getResult (),
3902
+ /* indices*/ newIndices.getResult (),
3903
+ /* updates*/ newUpdates.getResult (),
3898
3904
/* axis*/ rewriter.getI64IntegerAttr (axis));
3899
3905
if (!newOp)
3900
3906
return failure ();
@@ -4327,24 +4333,32 @@ struct DynamicUpdateSliceToConcatConverter
4327
4333
// start and shape to be the values appropriate for !hasNonZeroUpdateStart
4328
4334
// (static case). We will update them in the condition block.
4329
4335
// Calculate the slice start = update offset + update size.
4330
- TypedValue<RankedTensorType> concatDimOffset =
4331
- trtRewriter.checkAndCreate <tensorrt::ElementWiseOp>(
4332
- loc, targetTrtMajorVersion, updateStartOffset,
4333
- tensorrt::createConstShapeTensor (
4334
- rewriter, loc,
4335
- {static_cast <int32_t >(updateType.getDimSize (*concatAxis))}),
4336
- tensorrt::ElementWiseOperation::kSUM );
4336
+ auto sliceStart = trtRewriter.checkAndCreate <tensorrt::ElementWiseOp>(
4337
+ loc, targetTrtMajorVersion, updateStartOffset,
4338
+ tensorrt::createConstShapeTensor (
4339
+ rewriter, loc,
4340
+ {static_cast <int32_t >(updateType.getDimSize (*concatAxis))}),
4341
+ tensorrt::ElementWiseOperation::kSUM );
4342
+ if (!sliceStart)
4343
+ return failure ();
4344
+ TypedValue<RankedTensorType> concatDimOffset = sliceStart.getResult ();
4345
+
4337
4346
TypedValue<RankedTensorType> endOffset = tensorrt::scatterShapeTensor (
4338
4347
rewriter, loc, SmallVector<int64_t >(updateType.getRank (), 0 ),
4339
4348
*concatAxis, concatDimOffset);
4340
4349
// Calculate the slice size = result shape - update offset.
4341
- TypedValue<RankedTensorType> finalPartDimSize =
4350
+ auto finalPartDimSizeOp =
4342
4351
trtRewriter.checkAndCreate <tensorrt::ElementWiseOp>(
4343
4352
loc, targetTrtMajorVersion,
4344
4353
tensorrt::createConstShapeTensor (
4345
4354
rewriter, loc,
4346
4355
{static_cast <int32_t >(resultType.getDimSize (*concatAxis))}),
4347
4356
concatDimOffset, tensorrt::ElementWiseOperation::kSUB );
4357
+ if (!finalPartDimSizeOp)
4358
+ return failure ();
4359
+ TypedValue<RankedTensorType> finalPartDimSize =
4360
+ finalPartDimSizeOp.getResult ();
4361
+
4348
4362
TypedValue<RankedTensorType> endShape = tensorrt::scatterShapeTensor (
4349
4363
rewriter, loc, resultType.getShape (), *concatAxis, finalPartDimSize);
4350
4364
0 commit comments