Skip to content

Commit 1b0ba09

Browse files
mengluy0125facebook-github-bot
authored andcommitted
[Inductor][Optimus] Fix group fusion stride layout
Summary: X-link: pytorch/benchmark#2442 context: https://fb.workplace.com/groups/1075192433118967/permalink/1401282167176657/ moving the changes to the group gemm op has compilation errors, see details in D55606636 Test Plan: # local reproduce ``` CUDA_LAUNCH_BLOCKING=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split-group --model_type "afoc" --flow_id 544109991 ``` Counter({'pattern_matcher_nodes': 1215, 'pattern_matcher_count': 1090, 'normalization_pass': 430, 'remove_split_with_size_one_pass': 416, 'batch_aten_mul': 13, 'scmerge_split_sections_removed': 11, 'scmerge_cat_removed': 5, 'scmerge_cat_added': 4, 'batch_linear_post_grad': 4, 'scmerge_split_removed': 3, 'batch_aten_sub': 2, 'batch_layernorm': 1, 'group_linear': 1}) ``` CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode group-batch-split --model_type "cmf_shrink" --flow_id 587303213 ``` P1551948670 Counter({'pattern_matcher_nodes': 2244, 'pattern_matcher_count': 1738, 'normalization_pass': 404, 'extern_calls': 370, 'benchmarking.TritonBenchmarker.benchmark_gpu': 293, 'remove_split_with_size_one_pass': 269, 'merge_splits_pass': 74, 'normalization_aten_pass': 56, 'batch_aten_mul': 11, 'fxgraph_cache_miss': 10, 'group_linear': 9, 'scmerge_split_sections_removed': 5, 'scmerge_split_removed': 4, 'scmerge_cat_removed': 4, 'unbind_stack_pass': 4, 'batch_sigmoid': 2, 'batch_linear': 2, 'move_reshape_out_of_split_stack_pass': 2, 'batch_aten_sub': 2, 'batch_aten_add': 2, 'batch_layernorm': 1, 'scmerge_split_added': 1, 'scmerge_cat_added': 1, 'split_stack_to_cats_pass': 1, 'split_cat_to_slices_pass': 1, 'benchmarking.TritonBenchmarker.triton_do_bench': 1, 'batch_relu': 1}) # e2e ### AFOC baseline: f545589474 proposal: f545589302 {F1474302182} ### cmf shrink ads_dper3:0e442d2994ad1421377489d53ef99593 training_platform:be4b7015f1582fb1760bd72cf83ff38d baseline f635512197 baseline + group_fusion f635975547 The group fusion can be enabled but has qps regression by using group fusion. Differential Revision: D61888433
1 parent 40de63b commit 1b0ba09

File tree

3 files changed

+49
-11
lines changed

3 files changed

+49
-11
lines changed

torch/_dynamo/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -3242,3 +3242,9 @@ def record(cls):
32423242
finally:
32433243
if config.record_compile_time_instruction_count:
32443244
cls.end()
3245+
3246+
3247+
def realize_inputs(inputs: List[torch.fx.Node]):
3248+
for inp in inputs:
3249+
if isinstance(inp, torch.fx.node.Node):
3250+
inp.meta["inductor_realize_to_strides"] = True

torch/_inductor/fx_passes/decompose_mem_bound_mm.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66
from torch import Tensor
7-
from torch._dynamo.utils import counters
7+
from torch._dynamo.utils import counters, realize_inputs
88

99
from .. import config
1010
from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
@@ -33,12 +33,6 @@ def check_device(a: Tensor, b: Tensor) -> bool:
3333
return a.is_cuda and b.is_cuda
3434

3535

36-
def realize_inputs(inputs: List[torch.fx.Node]):
37-
for inp in inputs:
38-
if isinstance(inp, torch.fx.node.Node):
39-
inp.meta["inductor_realize_to_strides"] = True
40-
41-
4236
def should_decompose_bmm(mat1, mat2) -> bool:
4337
if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
4438
mat1 = mat1.meta["val"]

torch/_inductor/fx_passes/group_batch_fusion.py

+42-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
import torch
20-
from torch._dynamo.utils import counters, optimus_scuba_log
20+
from torch._dynamo.utils import counters, optimus_scuba_log, realize_inputs
2121
from torch._utils_internal import upload_graph
2222
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
2323

@@ -299,6 +299,31 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
299299

300300
@register_fusion("group_linear", pre_grad=False)
301301
class GroupLinearFusion(GroupFusion):
302+
def get_stride_type(self, node):
303+
node_shape = node.meta["tensor_meta"].shape # type: ignore[union-attr]
304+
305+
def col_major_stride():
306+
return (
307+
node.meta["tensor_meta"].stride[0] == 1
308+
and node.meta["tensor_meta"].stride[1] > 1
309+
and node.meta["tensor_meta"].stride[1] == node_shape[0]
310+
)
311+
312+
def row_major_stride():
313+
return (
314+
node.meta["tensor_meta"].stride[1] == 1
315+
and node.meta["tensor_meta"].stride[0] > 1
316+
and node.meta["tensor_meta"].stride[0] == node_shape[1]
317+
)
318+
319+
stride = None
320+
if row_major_stride():
321+
stride = "row"
322+
if col_major_stride():
323+
stride = "col"
324+
325+
return stride
326+
302327
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
303328
input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr]
304329
weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr]
@@ -331,15 +356,28 @@ def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
331356
if CallFunctionVarArgs(aten.mm.default).match(
332357
node
333358
) and self._mm_node_can_be_fused(node):
334-
group_key = ("group_linear", True)
359+
# don't allow inductor lowering to change the stride for the nodes
360+
realize_inputs([node.args[0], node.args[1]]) # type: ignore[possibly-undefined]
361+
input_stride = self.get_stride_type(node.args[0])
362+
weight_stride = self.get_stride_type(node.args[1])
363+
group_key = ("group_linear", str(input_stride), str(weight_stride))
335364
elif CallFunctionVarArgs(aten.addmm.default).match(
336365
node
337366
) and self._addmm_node_can_be_fused(node):
367+
# don't allow inductor lowering to change the stride for the nodes
368+
realize_inputs([node.args[0], node.args[1], node.args[2]]) # type: ignore[possibly-undefined]
369+
input_stride = self.get_stride_type(node.args[1])
370+
weight_stride = self.get_stride_type(node.args[2])
338371
bias = node.args[0]
339-
group_key = ("group_linear", bias is None)
372+
group_key = (
373+
"group_linear",
374+
bias is None,
375+
str(input_stride),
376+
str(weight_stride),
377+
) # type: ignore[assignment]
340378
else:
341379
group_key = None
342-
return group_key
380+
return group_key # type: ignore[return-value]
343381

344382
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
345383
group_inputs = []

0 commit comments

Comments
 (0)