|
11 | 11 |
|
12 | 12 | # pyre-unsafe
|
13 | 13 |
|
| 14 | +import copy |
14 | 15 | import math
|
15 | 16 | from operator import neg
|
16 | 17 | from typing import cast, Dict, Iterable, Sequence, Set, Tuple
|
@@ -1799,6 +1800,48 @@ def call_operator(
|
1799 | 1800 | )
|
1800 | 1801 |
|
1801 | 1802 |
|
| 1803 | +@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
| 1804 | +class ReplaceToDimOrderCopyWithToCopyPass(ExportPass): |
| 1805 | + """ |
| 1806 | + dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass. |
| 1807 | + If the dim order is sequential, we don't need the extra work with strides and |
| 1808 | + can just use to_copy. |
| 1809 | + """ |
| 1810 | + |
| 1811 | + def call_operator( |
| 1812 | + self, |
| 1813 | + op, |
| 1814 | + args: Tuple[Argument, ...], |
| 1815 | + kwargs: Dict[str, Argument], |
| 1816 | + meta: NodeMetadata, |
| 1817 | + ) -> ProxyValue: |
| 1818 | + if op != exir_ops.edge.dim_order_ops._to_dim_order_copy.default: |
| 1819 | + return super().call_operator(op, args, kwargs, meta) |
| 1820 | + |
| 1821 | + # new kwargs with dim_order, and no memory_format for the new op |
| 1822 | + nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable |
| 1823 | + |
| 1824 | + assert args[0] == range( |
| 1825 | + # pyre-ignore[16]: `None` has no attribute `to_tensor`. |
| 1826 | + args[0] |
| 1827 | + .to_tensor() |
| 1828 | + .dim() |
| 1829 | + ), "Only sequential dims supported" |
| 1830 | + |
| 1831 | + # remove dim_order from kwargs |
| 1832 | + nkwargs.pop("dim_order", None) |
| 1833 | + |
| 1834 | + # bring back memory format |
| 1835 | + nkwargs["memory_format"] = torch.contiguous_format |
| 1836 | + |
| 1837 | + return super().call_operator( |
| 1838 | + exir_ops.edge.aten._to_copy.default, |
| 1839 | + args, |
| 1840 | + nkwargs, |
| 1841 | + meta, |
| 1842 | + ) |
| 1843 | + |
| 1844 | + |
1802 | 1845 | @register_cadence_pass(CadencePassAttribute(opt_level=0))
|
1803 | 1846 | class ReplaceFullLikeWithFullPass(ExportPass):
|
1804 | 1847 | """
|
@@ -2108,4 +2151,5 @@ class CadenceReplaceOpsInGraph:
|
2108 | 2151 | ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
|
2109 | 2152 | ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
|
2110 | 2153 | ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
|
| 2154 | + ReplaceToDimOrderCopyWithToCopyPass, |
2111 | 2155 | ]
|
0 commit comments