
Commit

fix dynamo inplace copy (#7933)
zpcore authored Sep 3, 2024
1 parent d69f3a5 commit 989ac69
Showing 2 changed files with 18 additions and 20 deletions.
test/pjrt/test_collective_ops_tpu.py (9 changes: 1 addition & 8 deletions)
@@ -142,11 +142,6 @@ def test_all_to_all(self, pin_layout):
 # Test for collective ops from torch.distributed
 class TestDistCollectiveOpsTpu(parameterized.TestCase):
 
-  # TODO(zpcore): fix the openxla dynamo issue for inplace copy
-  @staticmethod
-  def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
-    return gm.forward
-
   @staticmethod
   def _all_reduce(use_dynamo: bool):
     met.clear_all()
@@ -161,9 +156,7 @@ def callable(input):
         dtype=torch.float,
         device=device)
 
-    f = torch.compile(
-        callable, backend=TestDistCollectiveOpsTpu.my_compiler
-    ) if use_dynamo else callable
+    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
     f(input)
     torch_xla.sync()
     if not use_dynamo:
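For context, a minimal sketch (not part of this commit) of the pattern the test now exercises: compiling a callable that mutates its input in place with the stock 'openxla' dynamo backend, instead of routing around it through a pass-through compiler. The function name and tensor shape below are illustrative assumptions, not taken from the test.

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def add_one_(t):
  # In-place update; the removed my_compiler workaround existed because the
  # dynamo bridge previously mishandled in-place copies like this.
  t.add_(1)
  return t


device = xm.xla_device()
x = torch.zeros(4, device=device)
compiled = torch.compile(add_one_, backend='openxla')  # hypothetical usage
compiled(x)
torch_xla.sync()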
torch_xla/_dynamo/dynamo_bridge.py (29 changes: 17 additions & 12 deletions)
@@ -723,6 +723,18 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     return extract_compiled_graph_helper(xla_model, xla_args)
 
 
+def _clear_pending_irs_on_args(args_tensor_only, cloned_args):
+  # If a tensor in args_tensor_only has pending IR, an in-place operation
+  # happened on it. We don't want to execute that operation yet, so we replace
+  # the pending IR with the cloned arg.
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      args_tensor_only)
+
+  for i, need_update in enumerate(args_need_update_bool):
+    if need_update and isinstance(args_tensor_only[i], torch.Tensor):
+      args_tensor_only[i].copy_(cloned_args[i])
+
+
 def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
                                         all_xla_args_tensor_only):
   # below logic will try to partition the fx graph based on the fallback ops.
@@ -739,18 +751,8 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
     print('Dynamo fallback ops are' + str(unsupported_nodes) +
           '. Please open a GitHub issue with the above op lowering requests.')
 
-  # This logic, needed for supporting in-place operations, is a duplicate of
-  # the one in the main `extract_internal` function above. We need to do this
-  # check for fetching fallback ops as well.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      all_xla_args_tensor_only)
-
-  # Again, same logic in the `extract_internal` above to support in-place operations.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  for i, need_update in enumerate(args_need_update_bool):
-    if need_update and isinstance(all_xla_args_tensor_only[i], torch.Tensor):
-      all_xla_args_tensor_only[i].copy_(cloned_args[i])
+  # UnsupportedNodesCollector might trigger in-place ops; clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
 
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
@@ -775,6 +777,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)
 
+  # InputCollector might trigger in-place ops; clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
+
   # compile each submodule and replace it with a call
   for node in partitioned_graph.graph.nodes:
     if node.op == "call_module" and "fused_" in node.name:
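As a standalone illustration (not part of the commit) of what the new _clear_pending_irs_on_args helper does: arguments are cloned before tracing, and if a later pass (such as UnsupportedNodesCollector or InputCollector) triggers an in-place op on them, the pending IR is detected and the argument is restored from its clone rather than executed. The tensor values and the simulated in-place op below are assumptions for demonstration only.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
args = [torch.ones(2, device=device)]
# Clone up front, as the dynamo bridge does, so args can be restored later.
cloned_args = [a.clone() for a in args]

# Simulate a collector pass that mutates an argument in place.
args[0].add_(1)

# Tensors carrying pending IR report that they need materialization.
needs_update = torch_xla._XLAC._check_tensor_need_materialization(args)
for i, need_update in enumerate(needs_update):
  if need_update and isinstance(args[i], torch.Tensor):
    # Replace the pending IR with the cloned value instead of executing it.
    args[i].copy_(cloned_args[i])

# Finally drop any remaining pending IR for the device, as the bridge does.
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))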
