@@ -17,7 +17,7 @@
 )
 
 import torch
-from torch._dynamo.utils import counters, optimus_scuba_log
+from torch._dynamo.utils import counters, optimus_scuba_log, realize_inputs
 from torch._utils_internal import upload_graph
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 
@@ -299,6 +299,31 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
 
 @register_fusion("group_linear", pre_grad=False)
 class GroupLinearFusion(GroupFusion):
+    def get_stride_type(self, node):
+        node_shape = node.meta["tensor_meta"].shape  # type: ignore[union-attr]
+
+        def col_major_stride():
+            return (
+                node.meta["tensor_meta"].stride[0] == 1
+                and node.meta["tensor_meta"].stride[1] > 1
+                and node.meta["tensor_meta"].stride[1] == node_shape[0]
+            )
+
+        def row_major_stride():
+            return (
+                node.meta["tensor_meta"].stride[1] == 1
+                and node.meta["tensor_meta"].stride[0] > 1
+                and node.meta["tensor_meta"].stride[0] == node_shape[1]
+            )
+
+        stride = None
+        if row_major_stride():
+            stride = "row"
+        if col_major_stride():
+            stride = "col"
+
+        return stride
+
     def _addmm_node_can_be_fused(self, node: torch.fx.Node):
         input_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
         weight_shape = node.args[2].meta["val"].shape  # type: ignore[union-attr]
@@ -331,15 +356,28 @@ def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
         if CallFunctionVarArgs(aten.mm.default).match(
             node
         ) and self._mm_node_can_be_fused(node):
-            group_key = ("group_linear", True)
+            # don't allow inductor lowering to change the stride for the nodes
+            realize_inputs([node.args[0], node.args[1]])  # type: ignore[list-item, possibly-undefined]
+            input_stride = self.get_stride_type(node.args[0])
+            weight_stride = self.get_stride_type(node.args[1])
+            group_key = ("group_linear", str(input_stride), str(weight_stride))
         elif CallFunctionVarArgs(aten.addmm.default).match(
             node
         ) and self._addmm_node_can_be_fused(node):
+            # don't allow inductor lowering to change the stride for the nodes
+            realize_inputs([node.args[0], node.args[1], node.args[2]])  # type: ignore[list-item, possibly-undefined]
+            input_stride = self.get_stride_type(node.args[1])
+            weight_stride = self.get_stride_type(node.args[2])
             bias = node.args[0]
-            group_key = ("group_linear", bias is None)
+            group_key = (
+                "group_linear",
+                bias is None,
+                str(input_stride),
+                str(weight_stride),
+            )  # type: ignore[assignment]
         else:
             group_key = None
-        return group_key
+        return group_key  # type: ignore[return-value]
 
     def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
         group_inputs = []
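As for why the extra key components matter: the batch-fusion pass gathers nodes whose `match()` keys are identical into one fusion group, so adding the stride strings (and keeping the bias flag for `addmm`) keeps mixed layouts out of the same grouped GEMM. The sketch below is a hypothetical illustration of that bucketing; the node names and the `defaultdict` grouping are assumptions, not the pass's actual code.

```python
from collections import defaultdict

# Hypothetical keys as match() would now return them.
matches = [
    ("mm_1", ("group_linear", "row", "col")),
    ("mm_2", ("group_linear", "row", "col")),            # same key -> grouped with mm_1
    ("mm_3", ("group_linear", "col", "col")),            # different input stride -> separate group
    ("addmm_1", ("group_linear", False, "row", "row")),  # addmm with a bias argument
]

groups = defaultdict(list)
for name, key in matches:
    groups[key].append(name)

for key, nodes in groups.items():
    print(key, nodes)
```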