Fusing on by default + Multiply commute pattern rewrite (#3946)
This PR extends the multiply commute pattern to support commuting a multiply whose scale does not come directly from a block argument.
Previously, we only supported this pattern:
```
constant_argument conv2d
| |
| |
+-------multiply-----+
```
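For reference, the commute is legal because a per-output-channel multiply after conv2d is the same as running conv2d with the weights pre-multiplied by the scale. Below is a minimal numpy sketch of that identity (illustrative only, not the compiler code; the HWC/OHWI layouts and the naive conv loop are assumptions for the example):
```python
import numpy as np

def conv2d(x, w):
    # x: [H, W, Cin], w: [Cout, KH, KW, Cin], stride 1, no padding.
    H, W, _ = x.shape
    Cout, KH, KW, _ = w.shape
    out = np.zeros((H - KH + 1, W - KW + 1, Cout))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = x[i:i + KH, j:j + KW, :]  # [KH, KW, Cin]
            out[i, j, :] = np.tensordot(patch, w, axes=([0, 1, 2], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 3))
w = rng.normal(size=(4, 3, 3, 3))
scale = rng.normal(size=(4,))                  # one constant per output channel

lhs = conv2d(x, w) * scale                     # multiply after conv2d
rhs = conv2d(x, w * scale.reshape(4, 1, 1, 1)) # multiply commuted into the weights
assert np.allclose(lhs, rhs)
```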
In MobileNet there are a bunch of layers with this pattern:
```
constant_argument
|
|
|
transpose
|
|
|
transpose
|
|
|
broadcast conv2d
| |
| |
+-------multiply-----+
```
This PR adds a small extension to `Conv2dWithMultiply` that matches a scale coming either directly from a block argument or from a broadcast whose input subgraph is const-evaluable. For example, the graph above can be commuted since the input into the subgraph is a constant, but patterns like the two below cannot (the first because `add` takes a non-constant runtime input, the second because the scale comes from a transpose rather than from a block argument or a broadcast):
```
constant_argument input
| |
| |
+--------add---------+
|
|
|
broadcast conv2d
| |
| |
+-------multiply-----+
```
```
constant_argument
|
|
|
transpose
|
|
|
transpose conv2d
| |
| |
+-------multiply-----+
```
To check whether the subgraph is fusable, we start from the `scale` argument in `isCommutable`, construct a [use-define (UD) chain](https://en.wikipedia.org/wiki/Use-define_chain), and use it to check that all inputs into the subgraph are constants.
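The walk itself is a standard worklist traversal over defining ops. Here is a toy Python sketch of the idea, with a hypothetical `Node` standing in for an MLIR value's defining op (the real pass walks MLIR values and also checks that interior ops are const-evaluable, which is elided here):
```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str                          # e.g. "constant", "input", "transpose", "broadcast", "add"
    inputs: list = field(default_factory=list)

def is_const_eval(scale: Node) -> bool:
    worklist = [scale]
    while worklist:
        node = worklist.pop()
        if not node.inputs:          # leaf of the UD chain
            if node.op != "constant":
                return False         # fed by a runtime input -> not fusable
            continue
        worklist.extend(node.inputs) # keep walking the def chains upward
    return True

const, runtime = Node("constant"), Node("input")
# broadcast(transpose(transpose(constant))) -> fusable
ok = Node("broadcast", [Node("transpose", [Node("transpose", [const])])])
# broadcast(add(constant, input)) -> not fusable
bad = Node("broadcast", [Node("add", [const, runtime])])
assert is_const_eval(ok) and not is_const_eval(bad)
```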
Once we determine that the subgraph is const-evaluable, we commute the whole subgraph to before the conv2d and apply a reshape, as we did before, to align the channel dimension with the weight. The resulting graph after the commute becomes:
```
constant_argument
|
|
|
transpose
|
|
|
transpose
|
|
|
reshape
|
|
|
broadcast weight
| |
| |
| |
multiply--------+
|
|
|
conv2d
```
Or, in the case with no broadcast:
```
constant_argument
|
|
|
reshape weight
| |
| |
| |
multiply--------+
|
|
|
conv2d
```
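The reshape step exists because the scale that multiplied the conv2d output along its channel dimension must line up with the weight's output-channel dimension after the commute. A small sketch, assuming NHWC activations and an OHWI-style weight layout (both assumptions for illustration):
```python
import numpy as np

C = 4
scale = np.arange(1.0, C + 1)

out_side = scale.reshape(1, 1, 1, C)  # as applied to the conv2d output [N, H, W, C]
w_side = scale.reshape(C, 1, 1, 1)    # after reshape: aligned with weight [C_out, KH, KW, C_in]

w = np.ones((C, 3, 3, 2))
scaled = w * w_side                   # broadcast + multiply, as in the graphs above
assert np.allclose(scaled[2], scale[2] * w[2])  # channel-2 filter scaled by scale[2]
```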
In addition, this PR tags the clamp scalar op with the eltwise unary trait, which enables TMs to commute through it.
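The reason an eltwise unary op can commute with a TM such as transpose is that a pure data-movement op only permutes elements, so applying the elementwise op before or after it yields the same tensor. A quick numpy check of that property for clamp with scalar bounds:
```python
import numpy as np

x = np.random.default_rng(1).normal(size=(2, 3, 4))
clamp = lambda t: np.clip(t, -0.5, 0.5)  # clamp with scalar bounds

assert np.array_equal(np.transpose(clamp(x), (2, 0, 1)),
                      clamp(np.transpose(x, (2, 0, 1))))
```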
This addresses a common pattern throughout ResNet: a conv2d with a constant weight, followed by a multiply with a constant input, which is commuted through the conv2d. The add is then fused into conv2d as a bias, and lastly conv2d and relu are fused into conv2d with an activation.
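A numpy sketch of that end state, using a 1x1 conv written as a matmul for brevity (an assumption for the example): relu(conv2d(x, w) * s + b) collapses into a single conv2d with scaled weights, bias b, and a relu activation.
```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=(5, 5, 3))   # [H, W, Cin]
w = rng.normal(size=(3, 4))      # 1x1 conv as a matmul: [Cin, Cout]
s = rng.normal(size=(4,))
b = rng.normal(size=(4,))

original = np.maximum(x @ w * s + b, 0.0)  # conv2d -> multiply -> add -> relu
fused = np.maximum(x @ (w * s) + b, 0.0)   # conv2d(bias=b, activation=relu) with scaled w
assert np.allclose(original, fused)
```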