
Commit be2879c

fix SharedLayerDesc losing attr is_firstly_shared
1 parent 0442435 commit be2879c

File tree

2 files changed (+10, -9 lines)

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 2 additions & 0 deletions
@@ -1196,6 +1196,8 @@ def copy_attr(attr_name):
         copy_attr("do_model_average")
         copy_attr("need_clip")
         copy_attr("no_sync")
+        copy_attr("is_firstly_shared")
+
         assert param.name not in self._slice_params, (
             f"Duplicate param.name {param.name} appeared, which will caused precision gap and ckpt bug."
         )
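The sharding-optimizer change extends the set of bookkeeping attributes that are copied from the original parameter onto its slice via the copy_attr helper; is_firstly_shared is now part of that set, so slices of pipeline-shared weights keep the flag. A minimal sketch of the pattern, assuming a helper of this shape (only the call sites appear in the diff; the helper body and the _copy_param_attrs wrapper are illustrative assumptions):

    def _copy_param_attrs(param, slice_param):
        # Hypothetical wrapper mirroring the diff: propagate bookkeeping
        # attributes from the full parameter onto its slice.
        def copy_attr(attr_name):
            if hasattr(param, attr_name):
                setattr(slice_param, attr_name, getattr(param, attr_name))

        copy_attr("do_model_average")
        copy_attr("need_clip")
        copy_attr("no_sync")
        copy_attr("is_firstly_shared")  # newly propagated by this commit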

python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py

Lines changed: 8 additions & 9 deletions
@@ -746,10 +746,9 @@ def _synchronize_shared_weights(self):
                     group=comm['group'],
                 )

-            for param in comm['layer'].parameters():
-                if param.name in comm[
-                    'weight_attr'
-                ] and self.global_rank != min(comm['ranks']):
+            if self.global_rank != min(comm['ranks']):
+                for weight_attr in comm['weight_attr']:
+                    param = getattr(comm['layer'], weight_attr)
                     param.is_firstly_shared = False

     def allreduce_shared_weight_gradients(self):

@@ -1013,11 +1012,11 @@ def flush_into_run_function():
             flush_into_run_function()
             if layer.layer_name not in self.shared_layers:
                 self.shared_layers[layer.layer_name] = layer.build_layer()
-                for param in self.shared_layers[
-                    layer.layer_name
-                ].parameters():
-                    if param.name in layer.shared_weight_attr:
-                        param.is_firstly_shared = True
+                for weight_attr in layer.shared_weight_attr:
+                    param = getattr(
+                        self.shared_layers[layer.layer_name], weight_attr
+                    )
+                    param.is_firstly_shared = True

             if layer.forward_func is None:
                 run_function.append(self.shared_layers[layer.layer_name])
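Both pp_layers.py hunks replace name matching with direct attribute lookup: instead of scanning layer.parameters() and testing param.name against comm['weight_attr'] / layer.shared_weight_attr (these hold attribute names such as "weight", which generally do not equal framework-generated parameter names, so the old check could silently skip the shared weights), the code now fetches each shared parameter with getattr and sets is_firstly_shared on it directly. A minimal standalone sketch of the lookup pattern (function and argument names are hypothetical):

    def mark_shared_weights(layer, shared_weight_attrs, is_first_owner):
        # shared_weight_attrs holds attribute names such as "weight";
        # getattr returns the parameter object itself, so the flag lands
        # on the real shared weight instead of relying on a name match.
        for weight_attr in shared_weight_attrs:
            param = getattr(layer, weight_attr)
            param.is_firstly_shared = is_first_owner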
