ci/h-test.sh (3 changes: 2 additions & 1 deletion)
@@ -165,7 +165,8 @@ concurrency_list="^test_fp8_deep_gemm$|\
 ^test_scaled_dot_product_attention$|\
 ^test_compat_scaled_dot_product_attention$|\
 ^test_flash_attention$|\
-^test_batched_gemm$"
+^test_batched_gemm$|\
+^test_parallel_dygraph_muon$"

 cd ${work_dir}/build
 tmp_dir=$(mktemp -d)
python/paddle/distributed/fleet/base/distributed_strategy.py (2 changes: 2 additions & 0 deletions)
@@ -337,6 +337,8 @@ def __init__(self) -> None:
         ]
         self.sync_param_name: list[str] = ["embedding", "layer_norm", ".b_"]

+        self.use_muon_sharding: bool = False
+
         self.__lock_attr = True
         logger.info("distributed strategy initialized")

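Note: a minimal sketch of how this flag could be enabled from user code. `use_muon_sharding` is defined before `__lock_attr` is set, so plain attribute assignment should pass the strategy's attribute lock. The model, optimizer, and `hybrid_configs` degrees below are illustrative, not part of this PR.

```python
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,  # illustrative degrees, not taken from this PR
}
strategy.use_muon_sharding = True  # the flag added in this diff

fleet.init(is_collective=True, strategy=strategy)
model = fleet.distributed_model(paddle.nn.Linear(1024, 1024))  # placeholder model
opt = paddle.optimizer.AdamW(parameters=model.parameters())
opt = fleet.distributed_optimizer(opt)  # wraps into HybridParallelOptimizer
```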
python/paddle/distributed/fleet/meta_optimizers/__init__.py (path inferred from the relative imports)
@@ -31,6 +31,7 @@
     AdaptiveLocalSGDOptimizer,
     LocalSGDOptimizer,
 )
+from .muon_sharding_optimizer import MuonShardingOptimizer  # noqa: F401
 from .pipeline_optimizer import PipelineOptimizer  # noqa: F401
 from .ps_optimizer import ParameterServerOptimizer  # noqa: F401
 from .qat_optimizer import QATOptimizer  # noqa: F401
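Note: the `# noqa: F401` re-export makes the class resolvable at package level. A one-line check of what this hunk enables, assuming the package path inferred above:

```python
from paddle.distributed.fleet.meta_optimizers import MuonShardingOptimizer
```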
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py (path inferred from the hunks below)
@@ -23,6 +23,9 @@
     DygraphShardingOptimizer,
     DygraphShardingOptimizerV2,
 )
+from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
+    MuonShardingOptimizer,
+)
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (
     obtain_optimizer_parameters_list,
 )
@@ -284,11 +287,13 @@ def __init__(self, optimizer, hcg, strategy):
             split_param = strategy.hybrid_configs[
                 'sharding_configs'
             ].split_param
-            ShardingOptimizer = (
-                DygraphShardingOptimizerV2
-                if split_param
-                else DygraphShardingOptimizer
-            )
+            use_muon_sharding = getattr(strategy, "use_muon_sharding", False)
+            if use_muon_sharding:
+                ShardingOptimizer = MuonShardingOptimizer
+            elif split_param:
+                ShardingOptimizer = DygraphShardingOptimizerV2
+            else:
+                ShardingOptimizer = DygraphShardingOptimizer
             optimizer = ShardingOptimizer(optimizer, hcg)

         self._enable_timer = strategy.hybrid_configs["enable_optimizer_timer"]
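Note: the rewritten branch gives the new flag precedence, so `split_param` is ignored whenever Muon sharding is requested. A condensed, hypothetical restatement of the selection for reference; the helper function and the first import path are assumptions, while the Muon import is taken from the diff:

```python
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
    DygraphShardingOptimizer,
    DygraphShardingOptimizerV2,
)
from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
    MuonShardingOptimizer,
)

def select_sharding_optimizer(use_muon_sharding: bool, split_param: bool):
    # Mirrors the branch above: the Muon flag wins over split_param.
    if use_muon_sharding:
        return MuonShardingOptimizer
    if split_param:
        return DygraphShardingOptimizerV2
    return DygraphShardingOptimizer
```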
@@ -335,6 +340,7 @@ def __init__(self, optimizer, hcg, strategy):
                 MixPrecisionOptimizer,
                 DygraphShardingOptimizer,
                 DygraphShardingOptimizerV2,
+                MuonShardingOptimizer,
             ),
         )

@@ -628,7 +634,11 @@ def _hybrid_sync_grad(self, parameter_list):
         if self._sharding_enable:
             assert isinstance(
                 self._inner_opt,
-                (DygraphShardingOptimizer, DygraphShardingOptimizerV2),
+                (
+                    DygraphShardingOptimizer,
+                    DygraphShardingOptimizerV2,
+                    MuonShardingOptimizer,
+                ),
             )
             self._inner_opt.reduce_gradients(parameter_list, self._hcg)
             dp_parameter_list = self._inner_opt.filter_parameters(
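Note: `MuonShardingOptimizer` must join this `isinstance` check because `_hybrid_sync_grad` then calls directly into the inner optimizer. A sketch of the duck-typed surface those call sites rely on; method names come from the diff, while the `filter_parameters` parameters are an assumption since its call is truncated above:

```python
class ShardingOptimizerLike:
    """Hypothetical interface implied by _hybrid_sync_grad (not in the PR)."""

    def reduce_gradients(self, parameter_list, hcg):
        """Reduce gradients across the sharding group (name from the diff)."""
        ...

    def filter_parameters(self, parameter_list, hcg):
        """Return the parameters owned by this rank (signature assumed)."""
        ...
```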