Support (un)squeeze in XNNPACK delegate via conversion to view #7961

Open · wants to merge 2 commits into main
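For context, here is a minimal sketch of the kind of model this change targets, assuming the standard ExecuTorch export flow; the model class, input shape, and lowering call are illustrative and not taken from this PR:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


# Hypothetical example model: its (un)squeeze ops can now be delegated to
# XNNPACK as static_reshape instead of falling back outside the delegate.
class SqueezeModel(torch.nn.Module):
    def forward(self, x):
        y = torch.squeeze(x, dim=-1)       # (N, C, H, 1) -> (N, C, H)
        return torch.unsqueeze(y, dim=-1)  # (N, C, H) -> (N, C, H, 1)


exported = torch.export.export(SqueezeModel(), (torch.randn(1, 8, 16, 1),))
lowered = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])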
4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -12,6 +12,9 @@
from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
Conv1dUnsqueezePass,
)
from executorch.backends.xnnpack._passes.convert_squeeze_to_view_pass import (
ConvertSqueezeToViewPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
@@ -67,6 +70,7 @@ def __init__(
DecomposeConcatenate,
RemoveGetItemPass,
Conv1dUnsqueezePass,
ConvertSqueezeToViewPass,
PReLUReshapePass,
ChannelsLastTaggedReshapePass,
TagImplicitQDqPass,
24 changes: 23 additions & 1 deletion backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -8,7 +8,7 @@

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult

@@ -77,6 +77,21 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
# is done
PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"

def is_view_dim_order_invariant(self, node: torch.fx.Node) -> bool:
# View must be done in NCHW dim order if channel or batch is changed,
# or if rank is not 4.
in_shape = get_input_node(node, 0).meta["val"].shape
out_shape = node.meta["val"].shape

if len(in_shape) != 4 or len(out_shape) != 4:
return False

# Are batch and channel modified? If so, return false.
if in_shape[0] != out_shape[0] or in_shape[1] != out_shape[1]:
return False

return True

def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True

@@ -93,6 +108,13 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
return node.target in self.memory_sensitive_ops_nhwc

def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
# Views require NCHW inputs unless they are dim-order invariant
# (rank-4 with batch and channel unchanged).
if (
node.target == exir_ops.edge.aten.view_copy.default
and not self.is_view_dim_order_invariant(node)
):
return True

return node.target in self.memory_sensitive_ops_nchw

def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
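To illustrate is_view_dim_order_invariant above: a view can stay in NHWC only when both shapes are rank 4 and the batch and channel dimensions are unchanged. A small standalone sketch of the same rule (plain Python, not the pass itself):

# Standalone illustration of the invariance rule checked by the pass.
def view_is_dim_order_invariant(in_shape, out_shape):
    # Both shapes must be rank 4 and leave batch (dim 0) and channel (dim 1) alone.
    if len(in_shape) != 4 or len(out_shape) != 4:
        return False
    return in_shape[0] == out_shape[0] and in_shape[1] == out_shape[1]

assert view_is_dim_order_invariant((1, 8, 4, 4), (1, 8, 16, 1))     # only H/W reshaped
assert not view_is_dim_order_invariant((1, 8, 4, 4), (1, 4, 8, 4))  # channel changed
assert not view_is_dim_order_invariant((1, 8, 16), (1, 8, 4, 4))    # rank != 4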
69 changes: 69 additions & 0 deletions backends/xnnpack/_passes/convert_squeeze_to_view_pass.py
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import check_or_raise
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.fx.experimental.symbolic_shapes import has_free_symbols


class ConvertSqueezeToViewPass(XNNPACKPass):
"""
This pass is used to convert squeeze and unsqueeze nodes into view_copy.
This allows them to be subsequently lowered as static_reshape ops.
"""

SUPPORTED_OPS = [
exir_ops.edge.aten.squeeze_copy.dim,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
]

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
node_list = list(graph.nodes)
for node in node_list:
if node.op == "call_function":
if node.target in self.SUPPORTED_OPS:
out_shape = node.meta["val"].shape

# Replace up to one dynamic dimension with -1 (inferred dim).
new_shape = []
dynamic_dim_count = 0
for d in out_shape:
if has_free_symbols(d):
new_shape.append(-1)
dynamic_dim_count += 1
else:
new_shape.append(d)

# This constraint should be enforced by the partitioner.
check_or_raise(
dynamic_dim_count <= 1,
"XNN supports only one dynamic dimension",
)

with graph_module.graph.inserting_after(node):
view_node = graph_module.graph.create_node(
"call_function",
target=exir_ops.edge.aten.view_copy.default,
args=(node.args[0], new_shape),
kwargs=node.kwargs,
)

node.replace_all_uses_with(view_node)
graph_module.graph.erase_node(node)

graph_module.recompile()
# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
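The pass above rewrites shapes only; the sketch below (plain PyTorch, illustrative) shows the equivalence it relies on: a squeeze produces the same tensor as a reshape to the squeezed shape, with at most one dynamic dimension expressed as -1:

import torch

x = torch.randn(2, 1, 8, 8)
squeezed = torch.squeeze(x, dim=1)  # shape (2, 8, 8)
as_view = x.reshape(2, 8, 8)        # what the pass emits as view_copy
assert torch.equal(squeezed, as_view)

# With a dynamic batch dimension, the pass emits -1 for that entry and lets
# the runtime infer it.
assert torch.equal(squeezed, x.reshape(-1, 8, 8))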
1 change: 1 addition & 0 deletions backends/xnnpack/operators/__init__.py
@@ -49,4 +49,5 @@
op_static_resize_bilinear_2d,
op_sub,
op_to_copy,
op_view_copy,
)
10 changes: 0 additions & 10 deletions backends/xnnpack/operators/op_skip_ops.py
@@ -77,16 +77,6 @@ class OpTCopyDefault(OpSkipOps):
target = "aten.t_copy.default"


@register_node_visitor
class OpViewCopyDefault(OpSkipOps):
"""
currently, do nothing if node is view_copy.default
need to handle this later on, currently view it as one of skip ops
"""

target = "aten.view_copy.default"


@register_node_visitor
class OpSymSizeInt(OpSkipOps):
"""
96 changes: 96 additions & 0 deletions backends/xnnpack/operators/op_view_copy.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNGraph,
XNNStaticReshape,
XNode,
)
from executorch.backends.xnnpack.utils.utils import (
check_or_raise,
get_input_node,
PERM_NCHW_TO_NHWC,
)


@register_node_visitor
class ViewCopyVisitor(NodeVisitor):
target = "aten.view_copy.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

input_node = get_input_node(node, 0)

# input
input_id = vals_to_ids[input_node]

# output
output_id = vals_to_ids[node]

# input shape
check_or_raise(
"val" in input_node.meta,
"Missing val in tensor metadata for input when serializing XNNStaticReshape",
)

# output shape
check_or_raise(
"val" in node.meta,
"Missing val in tensor metadata for input when serializing XNNStaticReshape",
)

new_shape = node.args[1]
check_or_raise(
all(isinstance(d, int) for d in new_shape),
"Symbolic reshape parameter is not supported in XNNStaticReshape",
)

# PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
new_shape = tuple(d if d != -1 else 0 for d in new_shape)

# Handle dim order - if this node is tagged NHWC, the view shape (given in
# NCHW order) must be permuted to NHWC to match the physical layout.
if "XNN_NHWC_NODE" in node.meta:
check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]

num_dynamic_dims = sum(1 for d in new_shape if d == 0)

check_or_raise(
num_dynamic_dims <= 1,
"XNNPACK reshape only supports 1 dynamic dimension.",
)

ser_node = XNode(
xnode_union=XNNStaticReshape(
num_dims=len(new_shape),
new_shape=new_shape,
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
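The shape handling in define_node can be shown in isolation. A minimal sketch of the two steps, mapping PyTorch's -1 inferred dim to XNNPACK's 0 and permuting a rank-4 shape when the node is tagged NHWC; the [0, 2, 3, 1] permutation is an assumption standing in for PERM_NCHW_TO_NHWC:

# Illustration only; the real visitor uses PERM_NCHW_TO_NHWC from the XNNPACK utils.
ASSUMED_PERM_NCHW_TO_NHWC = [0, 2, 3, 1]

def to_xnnpack_reshape_shape(new_shape, is_nhwc_node):
    # XNNPACK encodes the inferred dimension as 0 rather than PyTorch's -1.
    shape = [0 if d == -1 else d for d in new_shape]
    # If the node runs in NHWC, the NCHW view shape must be permuted to match.
    if is_nhwc_node:
        assert len(shape) == 4, "NHWC handling assumes a rank-4 shape"
        shape = [shape[ASSUMED_PERM_NCHW_TO_NHWC[i]] for i in range(4)]
    return shape

print(to_xnnpack_reshape_shape([-1, 8, 4, 4], is_nhwc_node=False))  # [0, 8, 4, 4]
print(to_xnnpack_reshape_shape([1, 8, 4, 4], is_nhwc_node=True))    # [1, 4, 4, 8]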
6 changes: 6 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -45,8 +45,11 @@
SliceCopyConfig,
SoftmaxConfig,
SquareRootConfig,
SqueezeCopyConfig,
SubConfig,
UnsqueezeCopyConfig,
UpsampleBilinear2dConfig,
ViewCopyConfig,
)
from executorch.backends.xnnpack.partition.config.node_configs import (
BatchNormConfig,
@@ -98,8 +101,11 @@
SliceCopyConfig,
SoftmaxConfig,
SquareRootConfig,
SqueezeCopyConfig,
SubConfig,
UnsqueezeCopyConfig,
UpsampleBilinear2dConfig,
ViewCopyConfig,
# Quant/Dequant Op Configs
QuantizedPerTensorConfig,
DeQuantizedPerTensorConfig,
79 changes: 79 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -21,6 +21,7 @@
)
from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
from torch.export import ExportedProgram
from torch.fx.experimental.symbolic_shapes import has_free_symbols

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)
@@ -314,6 +315,31 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
return torch.ops.aten.max_pool2d.default


class SqueezeCopyConfig(GenericNodePartitionerConfig):
target_name = "squeeze_copy.dims"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
return torch.ops.aten.squeeze_copy.default

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
XNNPACK's static_reshape only supports 1 dynamic dimension
"""
if not self.check_common_constraints(node, ep):
return False

new_shape = node.meta["val"].shape
dynamic_dim_count = sum(1 for d in new_shape if has_free_symbols(d))
if dynamic_dim_count > 1:
why(node, reason="only a single dynamic dimension is supported")
return False

return True


class UpsampleBilinear2dConfig(GenericNodePartitionerConfig):
target_name = "upsample_bilinear2d.vec"

@@ -336,6 +362,59 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
return torch.ops.aten.upsample_bilinear2d.vec


class UnsqueezeCopyConfig(GenericNodePartitionerConfig):
target_name = "unsqueeze_copy.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
return torch.ops.aten.unsqueeze_copy.default

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
XNNPACK's static_reshape only supports 1 dynamic dimension
"""
if not self.check_common_constraints(node, ep):
return False

new_shape = node.meta["val"].shape
dynamic_dim_count = sum(
1 for d in new_shape if not isinstance(d, int) and has_free_symbols(d)
)
if dynamic_dim_count > 1:
why(node, reason="only a single dynamic dimension is supported")
return False

return True


class ViewCopyConfig(GenericNodePartitionerConfig):
target_name = "view_copy.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
XNNPACK's static_reshape only supports 1 dynamic dimension
"""
if not self.check_common_constraints(node, ep):
return False

new_shape = node.args[1]
if not all(isinstance(n, int) for n in new_shape):
why(node, reason="symbolic reshape is not supported")
return False

dynamic_dim_count = sum(1 for d in new_shape if d == -1)
if dynamic_dim_count > 1:
why(node, reason="only a single dynamic dimension is supported")
return False

return True


class FloorConfig(GenericNodePartitionerConfig):
target_name = "floor.default"

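The constraint shared by the three new configs can be exercised on its own; a small sketch (plain Python, not the config classes) of the "at most one dynamic dimension" check applied to a view_copy target shape:

# Illustration of the ViewCopyConfig constraint on node.args[1].
def view_shape_is_partitionable(new_shape):
    # Symbolic (non-int) entries cannot be serialized into a static reshape.
    if not all(isinstance(d, int) for d in new_shape):
        return False
    # XNNPACK's static_reshape allows at most one inferred (-1) dimension.
    return sum(1 for d in new_shape if d == -1) <= 1

assert view_shape_is_partitionable([2, 8, 8])
assert view_shape_is_partitionable([-1, 8, 8])
assert not view_shape_is_partitionable([-1, -1, 8])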