Commit c326859
[PERF] [llama8B] Add subgraph rewrite rule to fuse Matmul and cast (#1021)
Unfused Graph:
![image](https://github.com/user-attachments/assets/766e5b8d-3882-400d-a398-e513046e5a63)

Original Behaviour:
![image](https://github.com/user-attachments/assets/67c52eed-145b-4cdf-afb3-6ffa492a29d5)

New Behaviour:
![image](https://github.com/user-attachments/assets/8a3b0588-9548-45f2-a0e9-8329a39c8533)

Originally, the CastOp was left out of the fusion: after an epilogue fusion pass, everything after the matmul would be fused, leaving the CastOp behind as a prologue operator. Because prologue fusion is disabled on Matmul, the CastOp would never be fused. Fixed this issue by adding a MultiInputCompositeElementwiseOp, which combines two UnaryOps with different inputs and one BinaryElementwiseOp, allowing this shape to be fused with Matmul.
1 parent b4ecb5d commit c326859
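For concreteness, here is a minimal sketch of the shape this commit targets, assuming hidet's public `symbol` / `ops` / `trace_from` API; the shapes, dtypes, and the use of `astype` to produce the casts are illustrative, not taken from the llama8B graph itself:

```python
# Build the unfused shape from the commit message: each branch is
# matmul -> cast, joined by an elementwise add, i.e.
# binaryOp(unaryOp(x1), unaryOp(x2)).
import hidet

a = hidet.symbol([16, 16], dtype='float16', device='cuda')
b = hidet.symbol([16, 16], dtype='float16', device='cuda')

x1 = hidet.ops.matmul(a, b).astype('float32')  # matmul -> cast (branch 1)
x2 = hidet.ops.matmul(b, a).astype('float32')  # matmul -> cast (branch 2)
y = x1 + x2                                    # binary elementwise op on top

graph = hidet.trace_from(y, inputs=[a, b])
# Before this commit, epilogue fusion consumed everything after each matmul
# and the cast was stranded as a prologue of the next group; the new rewrite
# rule first collapses cast + cast + add into one
# MultiInputCompositeElementwiseOp, which can then fuse as a matmul epilogue.
```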

File tree

3 files changed: +107 −8 lines changed

python/hidet/graph/ops/arithmetic.py

Lines changed: 80 additions & 0 deletions
```diff
@@ -135,6 +135,53 @@ def __init__(self, name: str, args: List[TensorNode], op: Callable[[Any], Any]):
         super().__init__(name=name, inputs=list(args), outputs=[out], inverse_map=inverse_map)
 
 
+class MultiInputCompositeElementwiseTask(Task):
+    def __init__(
+        self,
+        name: str,
+        x1: TensorNode,
+        x2: TensorNode,
+        left_unary_op: Callable[[Any], Any],
+        right_unary_op: Callable[[Any], Any],
+        binary_op: Callable[[Any, Any], Any],
+        attrs=None,
+    ):
+        def composite_op(binary_op, left_unary_op, right_unary_op, x1, x2):
+            if left_unary_op is None:
+                left_unary_op = lambda x1: x1
+            if right_unary_op is None:
+                right_unary_op = lambda x2: x2
+            return binary_op(left_unary_op(x1), right_unary_op(x2))
+
+        z_shape = broadcast_shape(x1.shape, x2.shape)
+        z = compute(
+            name='z',
+            shape=z_shape,
+            fcompute=lambda *indices: composite_op(
+                binary_op, left_unary_op, right_unary_op, x1.__getitem__(indices), x2.__getitem__(indices)
+            ),
+        )
+
+        inverse_map = {}
+        for inp, inp_shape in zip([x1, x2], [x1.shape, x2.shape]):
+            if same_list(inp_shape, z_shape):
+                inverse_map[inp] = InverseMap.from_lambda(lambda *indices: indices, num_args=len(inp_shape))
+            elif prod(inp_shape) == prod(z_shape):
+                inverse_map[inp] = InverseMap.from_lambda(
+                    lambda *indices: [0 for _ in range(len(z_shape) - len(inp_shape))] + list(indices),
+                    num_args=len(inp_shape),
+                )
+        # layout := 1:0 means identity layout
+        # It's obvious that arithmetic operations don't apply layout
+        # modification on tensors, so the tile mapping function can be
+        # considered as identity
+        for _, v in inverse_map.items():
+            v.tile_mapping = TensorLayout(1)
+        super().__init__(
+            name=name, inputs=[x1, x2], outputs=[z], inverse_map=inverse_map, attributes={} if attrs is None else attrs
+        )
+
+
 class CompositeElementwiseTask(Task):
     def __init__(
         self,
@@ -354,6 +401,29 @@ def get_dtype(scalar: Expr):
     return inferred_type
 
 
+class MultiInputCompositeElementwiseOp(Operator):
+    def __init__(
+        self,
+        x1: Tensor,
+        x2: Tensor,
+        left_unary_op: UnaryElementwiseOperation,
+        right_unary_op: UnaryElementwiseOperation,
+        binary_op: BinaryElementwiseOperation,
+    ):
+        name = 'multi_input_composite'
+        for op in [left_unary_op, right_unary_op, binary_op]:
+            if op is not None:
+                name += '_' + op.name
+        attributes = {'left_unary_op': left_unary_op, 'right_unary_op': right_unary_op, 'binary_op': binary_op}
+        super().__init__(
+            inputs=[x1, x2],
+            attributes=attributes,
+            task=MultiInputCompositeElementwiseTask(
+                name, input_like(x1, 'x1'), input_like(x2, 'x2'), left_unary_op, right_unary_op, binary_op
+            ),
+        )
+
+
 class CompositeElementwiseOp(Operator):
     def __init__(
         self,
@@ -1202,6 +1272,16 @@ def roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence
     return RollOp(x, shifts, dims).outputs[0]
 
 
+def composite_multi_input_elementwise(
+    x1: Tensor,
+    x2: Tensor,
+    left_unary_op: UnaryElementwiseOperation,
+    right_unary_op: UnaryElementwiseOperation,
+    binary_op: BinaryElementwiseOperation,
+) -> Tensor:
+    return MultiInputCompositeElementwiseOp(x1, x2, left_unary_op, right_unary_op, binary_op).outputs[0]
+
+
 # out = binary_op(left_unary_op(x), right_unary_op(x)); This allows more fusion opportunity.
 def composite_elementwise(
     x: Tensor,
```
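In plain NumPy terms, the new `MultiInputCompositeElementwiseTask` computes the following. This is a reference sketch of the semantics only; the real task works on `TensorNode`s, and its `InverseMap` bookkeeping (which NumPy broadcasting stands in for here) is what lets downstream fusion passes map output indices back to each input:

```python
import numpy as np

def multi_input_composite_ref(x1, x2, left_unary_op, right_unary_op, binary_op):
    # None means identity, mirroring composite_op in the task above.
    left_unary_op = left_unary_op or (lambda v: v)
    right_unary_op = right_unary_op or (lambda v: v)
    # The output shape is broadcast_shape(x1.shape, x2.shape).
    return binary_op(left_unary_op(x1), right_unary_op(x2))

# Example: both unary ops are casts and the binary op is add -- exactly the
# cast/cast/add shape that appears after the matmuls in the llama8B graph.
out = multi_input_composite_ref(
    np.ones((2, 3), dtype=np.float16),
    np.ones((3,), dtype=np.float16),
    lambda v: v.astype(np.float32),
    lambda v: v.astype(np.float32),
    np.add,
)
assert out.shape == (2, 3) and out.dtype == np.float32
```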

python/hidet/graph/ops/transform.py

Lines changed: 3 additions & 8 deletions
```diff
@@ -10,6 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import List, Optional, Tuple, Union, Sequence
+from hidet.graph.ops.arithmetic import UnaryElementwiseOp
 from hidet.ir.type import DataType, data_type
 from hidet.ir.expr import Expr, Constant, if_then_else, convert, cast as ir_cast, is_constant, logical_or
 from hidet.ir.expr import Int
@@ -697,15 +698,9 @@ def run_torch(self):
         return result
 
 
-class CastOp(Operator):
+class CastOp(UnaryElementwiseOp):
     def __init__(self, x: Tensor, dtype: DataType):
-        from .arithmetic import UnaryElementwiseTask
-
-        super().__init__(
-            inputs=[x],
-            attributes={'dtype': dtype},
-            task=UnaryElementwiseTask('cast', input_like(x, 'x'), op=lambda v: ir_cast(v, dtype)),
-        )
+        super().__init__(x, op=lambda v: ir_cast(v, dtype), name='cast', attributes={'dtype': dtype})
 
     def run_torch(self):
         x = self.inputs[0]
```
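This change is what lets unary-op patterns see casts at all: `op_pattern(UnaryElementwiseOp, ...)` matches by operator class, so once `CastOp` derives from `UnaryElementwiseOp` it binds like any other unary elementwise op. A minimal check (import paths follow the diff above; the class-based matching is my reading of `op_pattern`):

```python
# Sketch: after this change, CastOp is matched by unary-elementwise patterns.
from hidet.graph.ops.arithmetic import UnaryElementwiseOp
from hidet.graph.ops.transform import CastOp

assert issubclass(CastOp, UnaryElementwiseOp)

# So op_pattern(UnaryElementwiseOp, [x]) now also binds CastOp nodes, which
# is what lets the new rewrite rule below capture matmul -> cast branches.
```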

python/hidet/graph/transforms/graph_patterns/transform_patterns.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -171,6 +171,29 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
         return [x]
 
 
+class MultiInputCompositeElementwiseLeftRightRewriteRule(SubgraphRewriteRule):
+    def __init__(self):
+        super().__init__('binaryOp(unaryOp(x), unaryOp(y)) => compositeOp(x, y)')
+        self.x1 = TensorPattern()
+        self.x2 = TensorPattern()
+        self.y1 = op_pattern(UnaryElementwiseOp, [self.x1])
+        self.y2 = op_pattern(UnaryElementwiseOp, [self.x2])
+        self.z = op_pattern(BinaryElementwiseOp, [self.y1, self.y2])
+
+    def source(self) -> List[TensorPattern]:
+        return [self.z]
+
+    def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
+        x1, x2, y1, y2, z = [matched[v] for v in [self.x1, self.x2, self.y1, self.y2, self.z]]
+        left_unary_op: UnaryElementwiseOperation = y1.op.op
+        right_unary_op: UnaryElementwiseOperation = y2.op.op
+        if left_unary_op.name != right_unary_op.name:
+            return None
+        binary_op: BinaryElementwiseOperation = z.op.op
+        out = ops.arithmetic.composite_multi_input_elementwise(x1, x2, left_unary_op, right_unary_op, binary_op)
+        return [out]
+
+
 class CompositeElementwiseLeftRightRewriteRule(SubgraphRewriteRule):
     def __init__(self):
         super().__init__('binaryOp(unaryOp_left(x), unaryOp_right(x)) => compositeOp(x)')
@@ -235,6 +258,7 @@ def transform_patterns():
     register_rewrite_rule(FanoutTwoCast())
     register_rewrite_rule(FanoutThreeCast())
     register_rewrite_rule(DoubleCast())
+    register_rewrite_rule(MultiInputCompositeElementwiseLeftRightRewriteRule())
     register_rewrite_rule(CompositeElementwiseLeftRightRewriteRule())
     register_rewrite_rule(CompositeElementwiseLeftRewriteRule())
     register_rewrite_rule(CompositeElementwiseRightRewriteRule())
```
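With the rule registered, one way to sanity-check that it fires is to optimize the graph from the first sketch and look for the composite operator. This is a hedged sketch: the `graph.nodes` iteration and the `task.name` prefix check are my assumptions about the `FlowGraph` / `Operator` API, not something this diff establishes:

```python
import hidet

a = hidet.symbol([16, 16], dtype='float16', device='cuda')
b = hidet.symbol([16, 16], dtype='float16', device='cuda')
y = hidet.ops.matmul(a, b).astype('float32') + hidet.ops.matmul(b, a).astype('float32')

graph = hidet.graph.optimize(hidet.trace_from(y, inputs=[a, b]))
# MultiInputCompositeElementwiseTask names itself 'multi_input_composite_<ops...>',
# so the two casts and the add should now show up as one such operator.
assert any(op.task.name.startswith('multi_input_composite') for op in graph.nodes)
```

Note the guard in `target`: the rule only rewrites when both unary ops have the same name (e.g. two casts), so mixed unary pairs are left for the existing `CompositeElementwise*` rules.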
