Skip to content

Commit 7d0f8ac

Browse files
committed
Add edge padding mode
1 parent 85a8770 commit 7d0f8ac

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,8 +1881,8 @@ def _impl_v2(cls, bb, inputs, attr, params):
18811881
elif pad_mode == "reflect":
18821882
return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT")
18831883
else:
1884-
# TODO(gigiblender) Support edge mode.
1885-
raise NotImplementedError("Pad mode {} not implemented".format(pad_mode))
1884+
# edge mode - replicate border values
1885+
return bb.emit_te(topi.nn.replicate_pad, inputs[0], pad_before, pad_after)
18861886

18871887
@classmethod
18881888
def _impl_v11(cls, bb, inputs, attr, params):
@@ -1911,8 +1911,8 @@ def _impl_v11(cls, bb, inputs, attr, params):
19111911
elif pad_mode == "reflect":
19121912
return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT")
19131913
else:
1914-
# TODO(gigiblender) Support edge mode.
1915-
raise NotImplementedError("Pad mode {} not implemented".format(pad_mode))
1914+
# edge mode - replicate border values
1915+
return bb.emit_te(topi.nn.replicate_pad, inputs[0], pad_before, pad_after)
19161916

19171917

19181918
class Tile(OnnxOpConverter):

tests/python/relax/test_frontend_onnx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2440,6 +2440,8 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0):
24402440
verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0)
24412441
verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0)
24422442
verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect")
2443+
verify_pad((2, 3), [1, 1, 1, 1], "edge")
2444+
verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "edge")
24432445

24442446

24452447
@pytest.mark.parametrize("dynamic", [True, False])
@@ -2496,6 +2498,8 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0):
24962498
verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0)
24972499
verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0)
24982500
verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect")
2501+
verify_pad((2, 3), [1, 1, 1, 1], "edge")
2502+
verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "edge")
24992503

25002504

25012505
@pytest.mark.parametrize("fp_arith", [np.float16, np.float32])

0 commit comments

Comments
 (0)