
Commit fa190de

Support torch.amax and torch.amin (#1797)
1 parent 9065fdc commit fa190de
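
torch.amax and torch.amin reduce a tensor over one or more dimensions and return only the reduced values, unlike torch.max/torch.min with a dim argument, which also return indices. That is why each of them can be mapped onto a single MIL reduce op in this commit. A purely illustrative sketch of the semantics the new handlers need to cover (standard PyTorch, not part of this change):

import torch

x = torch.rand(2, 5, 7, 3)

# torch.max with a dim returns (values, indices); torch.amax returns values only
values, indices = torch.max(x, dim=1)
assert torch.equal(values, torch.amax(x, dim=1))

# dim may also be a tuple of axes, and keepdim keeps the reduced axes with size 1
print(torch.amax(x, dim=(1, 2), keepdim=True).shape)  # torch.Size([2, 1, 1, 3])
print(torch.amin(x, dim=(0, 3)).shape)                # torch.Size([5, 7])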

File tree

2 files changed: +55, -0 lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 20 additions & 0 deletions
@@ -4307,6 +4307,26 @@ def max(context, node):
     context.add(values, torch_name=values_name)
     context.add(indices, torch_name=indices_name)

+def _add_amax_amin(context, node, reduce_op):
+    # mimic functionality from https://pytorch.org/docs/stable/generated/torch.amax.html
+    # mimic functionality from https://pytorch.org/docs/stable/generated/torch.amin.html
+    assert len(node.outputs) == 1
+
+    all_inputs = _get_inputs(context, node, expected=[2, 3])
+    _input = all_inputs[0]
+    dim = [all_inputs[1].val] if type(all_inputs[1].val) == int else [x for x in all_inputs[1].val]
+    keepdim = all_inputs[2] if len(all_inputs) == 3 else False
+
+    context.add(reduce_op(x=_input, axes=dim, keep_dims=keepdim), torch_name=node.outputs[0])
+
+@register_torch_op
+def amax(context, node):
+    _add_amax_amin(context, node, mb.reduce_max)
+
+@register_torch_op
+def amin(context, node):
+    _add_amax_amin(context, node, mb.reduce_min)
+

 @register_torch_op
 def argsort(context, node):
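
With these handlers registered, a traced model that calls torch.amax or torch.amin should convert like any other supported op. A minimal conversion sketch, assuming a coremltools build that includes this change; the module, tensor name, and shape below are made up for illustration:

import torch
import coremltools as ct

# hypothetical example model, not from the commit
class AmaxModel(torch.nn.Module):
    def forward(self, x):
        # routed through mb.reduce_max by the new amax handler
        return torch.amax(x, dim=(1, 2), keepdim=True)

example = torch.rand(2, 5, 7, 3)
traced = torch.jit.trace(AmaxModel().eval(), example)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=example.shape)],
    convert_to="mlprogram",
)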

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 35 additions & 0 deletions
@@ -2865,6 +2865,41 @@ def forward(self, x, y):
             input_shapes, model, backend=backend, compute_unit=compute_unit
         )

+class TestAMaxAMin(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, input_shapes, mode, reduce_dim, keepdim",
+        itertools.product(
+            compute_units,
+            backends,
+            [
+                [(2, 5, 7, 3)],
+                [(3, 2, 9)],
+                [(1,)],
+            ],
+            ["minimum", "maximum"],
+            [0, 1, 2, 3, [0, 1], [0, 1, 2], [0, 1, 2, 3]],
+            [True, False],
+        ),
+    )
+    def test_minimum_maximum(self, compute_unit, backend, input_shapes, mode, reduce_dim, keepdim):
+        class TestModel(torch.nn.Module):
+            def forward(self, input):
+                if type(reduce_dim) == int:
+                    reduce_dim_clamped = min(input.dim() - 1, reduce_dim)
+                else:
+                    reduce_dim_clamped = reduce_dim[:input.dim()]
+                if mode == "minimum":
+                    return torch.amin(input, reduce_dim_clamped, keepdim)
+                elif mode == "maximum":
+                    return torch.amax(input, reduce_dim_clamped, keepdim)
+                else:
+                    raise ValueError("Unsupported mode: {mode}".format(mode=mode))
+
+        model = TestModel()
+        self.run_compare_torch(
+            input_shapes, model, backend=backend, compute_unit=compute_unit
+        )
+

 class TestPoolSymbolicInput(TorchBaseTest):
     def test_max_pool(self):
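
The forward pass in the new test clamps reduce_dim so that every parametrized dim combination stays valid for every input rank: an int dim is capped at input.dim() - 1, and a list of dims is truncated to the tensor's rank. A standalone illustration of that clamping, using plain PyTorch with a shape taken from the test parameters:

import torch

x = torch.rand(3, 2, 9)

# an int dim larger than x.dim() - 1 is clamped down to the last axis
reduce_dim_clamped = min(x.dim() - 1, 3)              # -> 2
print(torch.amax(x, reduce_dim_clamped, True).shape)  # torch.Size([3, 2, 1])

# a list of dims is truncated to the tensor's rank
dims = [0, 1, 2, 3][:x.dim()]                         # -> [0, 1, 2]
print(torch.amin(x, dims, False).shape)               # torch.Size([])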
