It looks like `fused_rms_norm_backward_strategy` doesn't support inputs with different numbers of placements.
To reproduce, remove the `rms_norm_backward` decomposition from `autoparallel/autoparallel/api.py` (line 62 in 273f54c, `decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)`) by adding the following line:

```python
decomp_table.pop(torch.ops.aten._fused_rms_norm_backward.default)
```

The error I obtain is the following:
```
[rank0]: Traceback (most recent call last):
[rank0]: File "/storage/home/fmassa/work/projects/autoparallel/examples/example_llama3.py", line 707, in <module>
[rank0]: with AutoParallel(
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/api.py", line 219, in __enter__
[rank0]: sharding_optimizer = ShardingOptimizer(
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/optimize_sharding.py", line 139, in __init__
[rank0]: self.strats = self.build_sharding_metadata()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/optimize_sharding.py", line 195, in build_sharding_metadata
[rank0]: strat = get_placement_options(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/utils.py", line 176, in get_placement_options
[rank0]: out_strat = get_op_strategy(op, op_schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/dtensor_util/utils.py", line 228, in get_op_strategy
[rank0]: return propagator.op_strategy_funcs[op](op_schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 1141, in fused_rms_norm_bwd_strategy
[rank0]: return _common_norm_backward_strategy(op_schema, rms_norm=True)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 1065, in _common_norm_backward_strategy
[rank0]: weight_src_spec = _add_target_input_spec(weight_strategy)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 1055, in _add_target_input_spec
[rank0]: src_spec = strategy.strategies[idx].output_spec
[rank0]: ~~~~~~~~~~~~~~~~~~~^^^^^
[rank0]: IndexError: list index out of range
```
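For reference, the failure point is `_add_target_input_spec` reading `strategy.strategies[idx].output_spec` with an index that is out of range for the weight's strategy list. Below is a minimal toy sketch of that indexing mismatch; the placement lists and the `idx` value are made up for illustration and are not DTensor's actual data structures:

```python
# Toy illustration of the IndexError above (hypothetical data, not the real
# OpStrategy objects): an index derived from one operand's strategy list is
# used to index another operand's shorter strategy list.
input_strategies = ["Shard(0)", "Shard(1)", "Replicate()"]  # 3 candidate placements
weight_strategies = ["Replicate()"]                          # 1 candidate placement

idx = 2                        # valid for input_strategies
print(input_strategies[idx])   # ok: "Replicate()"
print(weight_strategies[idx])  # IndexError: list index out of range
```

Presumably a fix would need `_common_norm_backward_strategy` to handle operands whose strategy lists have different lengths instead of assuming a shared index.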