fused_rms_norm_backward_strategy doesn't support different input placements #142

@fmassa

Description

It looks like fused_rms_norm_backward_strategy doesn't support inputs with different numbers of placements.

To reproduce, remove the rms_norm_backward decomposition as well: next to the existing

decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)

add the following line:

decomp_table.pop(torch.ops.aten._fused_rms_norm_backward.default)
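
For context, here is a minimal sketch of what the resulting decomposition-table setup looks like. The table construction and the None defaults are illustrative (AutoParallel builds its own table in autoparallel/api.py); only the two aten ops being popped come from the actual repro.

import torch
from torch._decomp import core_aten_decompositions

# Sketch only: core_aten_decompositions() stands in for however AutoParallel
# builds its decomposition table in autoparallel/api.py.
decomp_table = core_aten_decompositions()

# Already present: keep native_layer_norm_backward un-decomposed so DTensor's
# norm-backward sharding strategy handles it directly.
# (None default added here only so the sketch runs regardless of which
# decompositions are registered in this stand-in table.)
decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default, None)

# Line added for the repro: also keep _fused_rms_norm_backward un-decomposed,
# which routes it through fused_rms_norm_bwd_strategy and triggers the
# IndexError below.
decomp_table.pop(torch.ops.aten._fused_rms_norm_backward.default, None)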

The error I obtain is the following:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/examples/example_llama3.py", line 707, in <module>
[rank0]:     with AutoParallel(
[rank0]:          ^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/api.py", line 219, in __enter__
[rank0]:     sharding_optimizer = ShardingOptimizer(
[rank0]:                          ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/optimize_sharding.py", line 139, in __init__
[rank0]:     self.strats = self.build_sharding_metadata()
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/optimize_sharding.py", line 195, in build_sharding_metadata
[rank0]:     strat = get_placement_options(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/utils.py", line 176, in get_placement_options
[rank0]:     out_strat = get_op_strategy(op, op_schema)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/dtensor_util/utils.py", line 228, in get_op_strategy
[rank0]:     return propagator.op_strategy_funcs[op](op_schema)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 1141, in fused_rms_norm_bwd_strategy
[rank0]:     return _common_norm_backward_strategy(op_schema, rms_norm=True)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 1065, in _common_norm_backward_strategy
[rank0]:     weight_src_spec = _add_target_input_spec(weight_strategy)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 1055, in _add_target_input_spec
[rank0]:     src_spec = strategy.strategies[idx].output_spec
[rank0]:                ~~~~~~~~~~~~~~~~~~~^^^^^
[rank0]: IndexError: list index out of range
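
From the traceback, _common_norm_backward_strategy indexes every input's strategy list with the same index via _add_target_input_spec, so it presumably breaks as soon as the weight's strategy list has fewer entries than the input's, which matches the "different numbers of placements" situation described above.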
