diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 5c1551c9b94cd..c757046a7b1d5 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -1196,6 +1196,8 @@ def copy_attr(attr_name):
         copy_attr("do_model_average")
         copy_attr("need_clip")
         copy_attr("no_sync")
+        copy_attr("is_firstly_shared")
+
         assert param.name not in self._slice_params, (
             f"Duplicate param.name {param.name} appeared, which will caused precision gap and ckpt bug."
         )
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 183f3938ca386..bdf9dc6869b77 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -746,10 +746,9 @@ def _synchronize_shared_weights(self):
                     group=comm['group'],
                 )

-            for param in comm['layer'].parameters():
-                if param.name in comm[
-                    'weight_attr'
-                ] and self.global_rank != min(comm['ranks']):
+            if self.global_rank != min(comm['ranks']):
+                for weight_attr in comm['weight_attr']:
+                    param = getattr(comm['layer'], weight_attr)
                     param.is_firstly_shared = False

     def allreduce_shared_weight_gradients(self):
@@ -1013,11 +1012,11 @@ def flush_into_run_function():
                 flush_into_run_function()
                 if layer.layer_name not in self.shared_layers:
                     self.shared_layers[layer.layer_name] = layer.build_layer()
-                    for param in self.shared_layers[
-                        layer.layer_name
-                    ].parameters():
-                        if param.name in layer.shared_weight_attr:
-                            param.is_firstly_shared = True
+                    for weight_attr in layer.shared_weight_attr:
+                        param = getattr(
+                            self.shared_layers[layer.layer_name], weight_attr
+                        )
+                        param.is_firstly_shared = True

                 if layer.forward_func is None:
                     run_function.append(self.shared_layers[layer.layer_name])
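
Note on the change: the pp_layers.py hunks replace parameter-name matching with attribute lookup. A parameter's attribute name on the layer (e.g. `weight`, which is what `shared_weight_attr` holds) is distinct from its auto-generated global name (e.g. `linear_0.w_0`, which is what `param.name` returns), so the old `param.name in ... weight_attr` membership test could fail to mark the intended shared weights; `getattr(layer, weight_attr)` resolves the parameter directly. The dygraph_sharding_optimizer.py hunk then propagates the resulting `is_firstly_shared` flag when parameters are sliced. Below is a minimal, self-contained sketch of the lookup difference; `FakeParam` and `FakeSharedLayer` are hypothetical stand-ins for illustration, not Paddle's actual classes:

```python
# Sketch of why attribute lookup is used instead of scanning parameters():
# the attribute name ("weight") never equals the global name ("linear_0.w_0").

class FakeParam:
    def __init__(self, name):
        self.name = name                # auto-generated global name
        self.is_firstly_shared = True   # default flag at creation


class FakeSharedLayer:
    def __init__(self):
        # Attribute name is "weight"; parameter name is "linear_0.w_0".
        self.weight = FakeParam("linear_0.w_0")

    def parameters(self):
        return [self.weight]


layer = FakeSharedLayer()
shared_weight_attr = ["weight"]  # attribute names, as in the new code

# Old-style check: compares the global name against attribute names,
# so the body never runs and the flag is never updated.
for param in layer.parameters():
    if param.name in shared_weight_attr:
        param.is_firstly_shared = False  # unreachable

# New-style lookup: resolves the attribute directly and updates the flag.
for weight_attr in shared_weight_attr:
    param = getattr(layer, weight_attr)
    param.is_firstly_shared = False

assert layer.weight.is_firstly_shared is False
```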