@@ -11,6 +11,7 @@
 
 # pyre-unsafe
 
+import copy
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -35,7 +36,12 @@
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.passes.dim_order_ops_registry import (
+    DimOrderOpsMap,
+    MemoryFormatOpsMap,
+)
 from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
@@ -1799,6 +1805,62 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
+    """
+    dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass.
+    If the dim order is sequential, we don't need the extra work with strides and
+    can just use to_copy.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in DimOrderOpsMap:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # new kwargs with dim_order, and no memory_format for the new op
+        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
+
+        ndim = None
+
+        # can always get the shape, assuming rank is specialized
+
+        # pyre-ignore[16]: `None` has no attribute `to_tensor`
+        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
+            # pyre-ignore[16]: `None` has no attribute `to_tensor`
+            ndim = args[0].to_tensor().dim()
+        elif isinstance(args[0], torch.Tensor):
+            # pyre-ignore[16]: `None` has no attribute `dim`
+            ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            # pyre-ignore[6]: Incompatible parameter type
+            ndim = len(args[0])
+        else:
+            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
+
+        # get the "to" memory format for the EdgeOp
+        default_dim_order = list(range(ndim))
+        dim_order = nkwargs.pop("dim_order", default_dim_order)
+
+        # bring back memory format
+        # pyre-ignore[6]: Incompatible parameter type
+        nkwargs["memory_format"] = get_memory_format(dim_order)
+
+        memory_format_op = MemoryFormatOpsMap[op]
+
+        return super().call_operator(
+            memory_format_op,
+            args,
+            nkwargs,
+            meta,
+        )
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceFullLikeWithFullPass(ExportPass):
     """
@@ -2108,4 +2170,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceToDimOrderCopyWithToCopyPass,
     ]
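For context, the rewrite boils down to translating the `dim_order` kwarg of a `to_dim_order_copy` node back into a `torch.memory_format` so the node can be emitted as a plain `to_copy`. Below is a minimal, self-contained sketch of that mapping under stated assumptions: `to_memory_format` is a hypothetical stand-in for `executorch.exir.dim_order_utils.get_memory_format`, which the pass actually calls and which handles more cases.

    from typing import List

    import torch

    def to_memory_format(dim_order: List[int]) -> torch.memory_format:
        # Hypothetical stand-in for executorch's get_memory_format, shown only
        # to illustrate the mapping the pass relies on.
        # A sequential dim order [0, 1, ..., n-1] is an ordinary contiguous layout.
        if dim_order == list(range(len(dim_order))):
            return torch.contiguous_format
        # The NHWC permutation of a 4-D tensor corresponds to channels_last.
        if dim_order == [0, 2, 3, 1]:
            return torch.channels_last
        raise ValueError(f"unsupported dim order: {dim_order}")

    # The pass pops `dim_order` from the kwargs and re-inserts the result as
    # `memory_format`, so a call like
    #     to_dim_order_copy(x, dtype=torch.float32, dim_order=[0, 1, 2, 3])
    # is rewritten to
    #     to_copy(x, dtype=torch.float32, memory_format=torch.contiguous_format)
    assert to_memory_format([0, 1, 2, 3]) == torch.contiguous_format
    assert to_memory_format([0, 2, 3, 1]) == torch.channels_last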