Unneeded graph retracing in multi-chip scenarios #3302

@pilkicTT

Description

When running the Llama 70B benchmark on qb/lb, graph retracing is triggered on every iteration. This was observed while investigating #2897.

Specifically, during execution of optimized_mod in torch_xla, we enter extract_graph_helper() on every iteration of forward. The relevant snippet, from dynamo_bridge.py:

    # If input sharding has changed from the previous program, dynamo current can
    # not detect this. It will mistakenly believe the program is the same. We need
    # to retrace it here.
    if xr.is_spmd():
      # if the input sharding was the same for skip_checking_input_sharding_threashold times
      # we will skip checking the input sharding since it can be expensive.
      if skip_checking_input_sharding_threashold > 0:
        if torch_xla._XLAC._get_xla_sharding_specs(
            xla_args_tensor_only) != xla_args_sharding_spec:
          # update the xla_args with the input with new sharding and retrace
          xla_model.xla_args = args
          (xla_args_sharding_spec, args_and_out_copy, graph_hash,
           arg_index_to_need_update_index, none_remover, graph_input_matcher,
           special_return_handler, xla_args_need_update) = extract_graph_helper(
               xla_model, sym_constants_to_graph_vars)
          skip_checking_input_sharding_threashold = xu.getenv_as(
              'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
        else:
          skip_checking_input_sharding_threashold -= 1
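
For reference, the threshold in the snippet is configurable via an environment variable, but raising it will not avoid the retraces here. A minimal sketch of setting it before the first compiled call, assuming the default of 5 from the snippet above:

    import os

    # XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD controls how many consecutive
    # iterations with unchanged input sharding must be observed before the
    # bridge skips the check entirely. When the specs differ on every
    # iteration, the counter is reset each time and a retrace happens
    # regardless, so raising this value does not help in this scenario.
    os.environ["XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD"] = "5"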

Something causes the freshly computed input sharding specs to differ from the cached xla_args_sharding_spec, so the check above fails and the graph is retraced on every iteration.
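
One way to confirm which input is responsible is to log the sharding specs of the XLA input tensors across iterations and diff them, using the same private _get_xla_sharding_specs helper the bridge uses for its comparison. A minimal sketch (log_sharding_changes is a hypothetical debugging helper; its argument stands in for the actual list of XLA input tensors at each call):

    import torch_xla

    prev_specs = None

    def log_sharding_changes(xla_tensors):
        """Print which input tensor's sharding spec changed since the last call."""
        global prev_specs
        # Same private helper dynamo_bridge.py uses in the check above.
        specs = torch_xla._XLAC._get_xla_sharding_specs(xla_tensors)
        if prev_specs is not None and specs != prev_specs:
            for i, (old, new) in enumerate(zip(prev_specs, specs)):
                if old != new:
                    print(f"input {i}: sharding changed {old!r} -> {new!r}")
        prev_specs = specs

Calling this on each iteration, just before invoking the compiled module, should show which tensor's spec differs from the cached one.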

Labels: P1, multichip (Multichip issues), perf (Performance related)
