Commit 0b57581
[pytorch] Disable fast path in MultiheadAttention in Export (pytorch#106824)
Summary:
We are seeing `aten._native_multi_head_attention` op (not in core Aten op set) is left in the exported graph and causes problems in the downstream at runtime.
Two proposed solutions:
1. Disable fast path while tracing to leverage the non-optimized path to get decomp, that way, the blamed op won't show up in the exported graph
2. Add a decomp rule for `aten._native_multi_head_attention`
After discussing with kimishpatel and bdhirsh, #1 is preferred and verified it could immediately unblock the critical model enablement work for PP.
Test Plan: CI
Differential Revision: D48169806
Pull Request resolved: pytorch#106824
Approved by: https://github.com/kimishpatel1 parent 7f9d1ca commit 0b57581
1 file changed
+10
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
895 | 895 | | |
896 | 896 | | |
897 | 897 | | |
| 898 | + | |
| 899 | + | |
| 900 | + | |
| 901 | + | |
| 902 | + | |
| 903 | + | |
| 904 | + | |
| 905 | + | |
898 | 906 | | |
899 | 907 | | |
900 | 908 | | |
| |||
1169 | 1177 | | |
1170 | 1178 | | |
1171 | 1179 | | |
| 1180 | + | |
| 1181 | + | |
1172 | 1182 | | |
1173 | 1183 | | |
1174 | 1184 | | |
| |||
0 commit comments