When running llama 70b benchmark on qb/lb, graph retracing is triggered on each iteration. This was observed during investigation of #2897.
Namely, during execution of optimized_mod in torch_xla we enter extract_graph_helper() on every forward iteration. Relevant snippet from dynamo_bridge.py:
```python
# If input sharding has changed from the previous program, dynamo current can
# not detect this. It will mistakenly believe the program is the same. We need
# to retrace it here.
if xr.is_spmd():
  # if the input sharding was the same for skip_checking_input_sharding_threashold times
  # we will skip checking the input sharding since it can be expensive.
  if skip_checking_input_sharding_threashold > 0:
    if torch_xla._XLAC._get_xla_sharding_specs(
        xla_args_tensor_only) != xla_args_sharding_spec:
      # update the xla_args with the input with new sharding and retrace
      xla_model.xla_args = args
      (xla_args_sharding_spec, args_and_out_copy, graph_hash,
       arg_index_to_need_update_index, none_remover, graph_input_matcher,
       special_return_handler, xla_args_need_update) = extract_graph_helper(
           xla_model, sym_constants_to_graph_vars)
      skip_checking_input_sharding_threashold = xu.getenv_as(
          'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
    else:
      skip_checking_input_sharding_threashold -= 1
```

Something triggers this comparison to report that the newly queried input sharding specs differ from the cached xla_args_sharding_spec, so extract_graph_helper is re-run on every iteration.