Skip to content

Commit 278c207

Browse files
perGeorgeARM
andauthored
ArmBackend: Add support for comparison operators (#7885)
Add support for the following operators: * aten.eq.Tensor * aten.le.Tensor * aten.lt.Tensor * aten.ge.Tensor * aten.gt.Tensor Signed-off-by: Georgios Pinitas <[email protected]> Signed-off-by: Per Åstrand <[email protected]> Co-authored-by: Georgios Pinitas <[email protected]>
1 parent 0aa98c7 commit 278c207

13 files changed

+1017
-1
lines changed

backends/arm/operator_support/tosa_supported_operators.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -81,11 +81,16 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
8181
exir_ops.edge.aten.hardtanh.default,
8282
exir_ops.edge.aten.convolution.default,
8383
exir_ops.edge.aten.div.Tensor,
84+
exir_ops.edge.aten.eq.Tensor,
8485
exir_ops.edge.aten.exp.default,
8586
exir_ops.edge.aten.log.default,
8687
exir_ops.edge.aten.linear.default,
8788
exir_ops.edge.aten.split_with_sizes_copy.default,
8889
exir_ops.edge.aten.full.default,
90+
exir_ops.edge.aten.ge.Tensor,
91+
exir_ops.edge.aten.gt.Tensor,
92+
exir_ops.edge.aten.le.Tensor,
93+
exir_ops.edge.aten.lt.Tensor,
8994
exir_ops.edge.aten.mul.Tensor,
9095
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
9196
exir_ops.edge.aten.native_layer_norm.default,

backends/arm/operators/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313
op_bmm,
1414
op_cat,
1515
op_conv2d,
16+
op_eq,
1617
op_exp,
1718
op_full,
19+
op_ge,
1820
op_get_item,
21+
op_gt,
1922
op_hardtanh,
23+
op_le,
2024
op_log,
25+
op_lt,
2126
op_max,
2227
op_max_pool2d,
2328
op_min,

backends/arm/operators/op_eq.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class EqualVisitor(NodeVisitor):
25+
target = "aten.eq.Tensor"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].dtype == inputs[1].dtype
39+
), "EQ must have the same dtypes as input"
40+
41+
input_nodes = inputs
42+
# Handle quantization
43+
if inputs[0].dtype == ts.DType.INT8:
44+
# Rescale inputs to 32 bit
45+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
46+
tosa_graph, inputs, node
47+
)
48+
49+
# Update IO
50+
input_nodes = rescaled_inputs
51+
52+
# Do the equal comparison
53+
tosa_graph.addOperator(
54+
TosaOp.Op().EQUAL,
55+
[input_nodes[0].name, input_nodes[1].name],
56+
output.name,
57+
None,
58+
)

backends/arm/operators/op_ge.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class GreaterEqualVisitor(NodeVisitor):
25+
target = "aten.ge.Tensor"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].dtype == inputs[1].dtype
39+
), "GE must have the same dtypes as input"
40+
41+
input_nodes = inputs
42+
# Handle quantization
43+
if inputs[0].dtype == ts.DType.INT8:
44+
# Rescale inputs to 32 bit
45+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
46+
tosa_graph, inputs, node
47+
)
48+
49+
# Update IO
50+
input_nodes = rescaled_inputs
51+
52+
tosa_graph.addOperator(
53+
TosaOp.Op().GREATER_EQUAL,
54+
[input_nodes[0].name, input_nodes[1].name],
55+
[output.name],
56+
None,
57+
)

backends/arm/operators/op_gt.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class GreaterThanVisitor(NodeVisitor):
25+
target = "aten.gt.Tensor"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].dtype == inputs[1].dtype
39+
), "GT must have the same dtypes as input"
40+
41+
input_nodes = inputs
42+
# Handle quantization
43+
if inputs[0].dtype == ts.DType.INT8:
44+
# Rescale inputs to 32 bit
45+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
46+
tosa_graph, inputs, node
47+
)
48+
49+
# Update IO
50+
input_nodes = rescaled_inputs
51+
52+
tosa_graph.addOperator(
53+
TosaOp.Op().GREATER,
54+
[input_nodes[0].name, input_nodes[1].name],
55+
[output.name],
56+
None,
57+
)

backends/arm/operators/op_le.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class LessEqualVisitor(NodeVisitor):
25+
target = "aten.le.Tensor"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].dtype == inputs[1].dtype
39+
), "LE must have the same dtypes as input"
40+
41+
input_nodes = inputs
42+
# Handle quantization
43+
if inputs[0].dtype == ts.DType.INT8:
44+
# Rescale inputs to 32 bit
45+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
46+
tosa_graph, inputs, node
47+
)
48+
49+
# Update IO
50+
input_nodes = rescaled_inputs
51+
52+
tosa_graph.addOperator(
53+
TosaOp.Op().GREATER_EQUAL,
54+
[input_nodes[1].name, input_nodes[0].name],
55+
[output.name],
56+
None,
57+
)

backends/arm/operators/op_lt.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class LessThanVisitor(NodeVisitor):
25+
target = "aten.lt.Tensor"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].dtype == inputs[1].dtype
39+
), "LT must have the same dtypes as input"
40+
41+
input_nodes = inputs
42+
# Handle quantization
43+
if inputs[0].dtype == ts.DType.INT8:
44+
# Rescale inputs to 32 bit
45+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
46+
tosa_graph, inputs, node
47+
)
48+
49+
# Update IO
50+
input_nodes = rescaled_inputs
51+
52+
tosa_graph.addOperator(
53+
TosaOp.Op().GREATER,
54+
[input_nodes[1].name, input_nodes[0].name],
55+
[output.name],
56+
None,
57+
)

backends/arm/quantizer/quantization_annotator.py

+15
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,21 @@ def any_or_hardtanh_min_zero(n: Node):
304304
quant_properties.quant_output = _QuantProperty(
305305
0, SharedQuantizationSpec((node.args[0], node))
306306
)
307+
elif node.target in [
308+
torch.ops.aten.eq.Tensor,
309+
torch.ops.aten.ge.Tensor,
310+
torch.ops.aten.gt.Tensor,
311+
torch.ops.aten.le.Tensor,
312+
torch.ops.aten.lt.Tensor,
313+
]:
314+
shared_qspec = SharedQuantizationSpec((node.args[0], node))
315+
quant_properties.quant_inputs = [
316+
_QuantProperty(0, input_act_qspec),
317+
_QuantProperty(
318+
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
319+
),
320+
]
321+
quant_properties.quant_output = None
307322
elif node.target in _parent_shared_qspec:
308323
if not isinstance(node.args[0], Node):
309324
return None

0 commit comments

Comments
 (0)