Skip to content

Commit 561ac46

Browse files
authored
changes from apache#8131 (#228)
1 parent 5203272 commit 561ac46

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

python/tvm/relay/op/contrib/tensorrt.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -733,8 +733,12 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
733733
if attrs.pad_mode != "constant":
734734
logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
735735
return False
736-
if float(attrs.pad_value) != 0.0:
737-
logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value))
736+
if (
737+
not isinstance(args[1], relay.Constant)
738+
or len(args[1].checked_type.shape) != 0
739+
or args[1].data.numpy().item() != 0.0
740+
):
741+
logger.info("nn.pad: pad value is %s but must be 0.0.", args[1])
738742
return False
739743
if len(attrs.pad_width) not in [4, 5]:
740744
logger.info("nn.pad: can only pad 4D or 5D inputs")

src/runtime/contrib/tensorrt/tensorrt_ops.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ class ReshapeOpConverter : public TensorRTOpConverter {
10571057

10581058
class PadOpConverter : public TensorRTOpConverter {
10591059
public:
1060-
PadOpConverter() : TensorRTOpConverter({kTensor}) {}
1060+
PadOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
10611061

10621062
void Convert(TensorRTOpConverterParams* params) const {
10631063
auto input = params->inputs.at(0).tensor;

0 commit comments

Comments
 (0)