
Commit b330e5a

Shared weight test (#86)
1 parent c936648 commit b330e5a

4 files changed: +312 -4 lines


src/paddlefleet/pipeline_parallel/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline_parallel import NoPipelineParallel
+from .pipeline_parallel import NoPipelineParallel, PipelineParallel
 from .pp_layers import LayerDesc, PipelineLayer, SharedLayerDesc
 
 __all__ = [
     "LayerDesc",
     "SharedLayerDesc",
     "PipelineLayer",
     "NoPipelineParallel",
+    "PipelineParallel",
 ]
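
With this re-export, the wrapper is importable from the package root, which is how the new test imports it:

from paddlefleet.pipeline_parallel import PipelineParallel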

src/paddlefleet/pipeline_parallel/pipeline_parallel.py

Lines changed: 9 additions & 2 deletions
@@ -15,9 +15,12 @@
 import paddle
 from paddle import nn
 
+from .pp_layers import PipelineLayer
+
 
 class NoPipelineParallel(nn.Layer):
     def __init__(self, layers, strategy):
+        assert isinstance(layers, PipelineLayer)
         super().__init__()
         self._layers = layers
         self._strategy = strategy
@@ -41,8 +44,12 @@ def forward_backward_pipeline(
 
 
 class PipelineParallel(nn.Layer):
-    def __init__(self, layer, hcg, strategy):
-        pass
+    def __init__(self, layers, hcg, strategy):
+        assert isinstance(layers, PipelineLayer)
+        super().__init__()
+        self._layers = layers
+        self._strategy = strategy
+        self._hcg = hcg
 
     def forward(self, data, scaler=None, return_micro_batch_loss=False):
         pass
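
For orientation, a minimal sketch of how the fleshed-out PipelineParallel constructor is exercised: it wraps an already built PipelineLayer together with the hybrid communicate group and the distributed strategy, mirroring the new test below. get_simple_net_spec is a helper defined in that test; the fleet calls are standard Paddle APIs.

# Sketch only - mirrors the wiring in the shared-weight test added by this commit.
from paddle.distributed import fleet

from paddlefleet.pipeline_parallel import PipelineParallel
from paddlefleet.spec_utils import build_layer

strategy = fleet.DistributedStrategy()  # hybrid/pipeline configs omitted here
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()

# get_simple_net_spec() is defined in the new test file below.
layers = build_layer(get_simple_net_spec(), topology=hcg.topology())  # a PipelineLayer
model = PipelineParallel(layers, hcg, strategy)  # the new assert rejects non-PipelineLayer inputs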
Lines changed: 300 additions & 0 deletions (new test file; its path is not shown on this page)
@@ -0,0 +1,300 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import random
+import unittest
+from dataclasses import dataclass
+
+import numpy as np
+
+os.environ["FLAGS_profile_optimizer_details_steps"] = "1"
+import paddle
+import paddle.distributed as dist
+from paddle import nn
+from paddle.distributed import fleet
+from paddle.nn import Layer
+
+from paddlefleet.pipeline_parallel import (
+    LayerDesc,
+    PipelineLayer,
+    PipelineParallel,
+    SharedLayerDesc,
+)
+from paddlefleet.spec_utils import LayerSpec, build_layer
+
+
+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
+
+
+batch_size = 8
+micro_batch_size = 2
+vocab_size = 128
+hidden_size = 16
+
+
+class SimpleNetBase(Layer):
+    def __init__(self):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
+
+        self.softmax_weight = self.create_parameter(
+            shape=[hidden_size, vocab_size]
+        )
+        self.softmax_bias = self.create_parameter(
+            shape=[vocab_size], is_bias=False
+        )
+
+    def forward(self, x1, x2, y1):
+        x_emb = self.word_embeddings(x1)
+        fc = paddle.matmul(x_emb, self.softmax_weight)
+        fc = paddle.add(fc, self.softmax_bias)
+        projection = paddle.reshape(fc, shape=[-1, vocab_size])
+
+        projection = paddle.matmul(projection, self.word_embeddings.weight)
+
+        loss = paddle.nn.functional.softmax_with_cross_entropy(
+            logits=projection, label=y1, soft_label=False
+        )
+        return loss.mean()
+
+
+class EmbeddingPipe(Layer):
+    def __init__(self):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
+
+    @property
+    def embedding_weight(self):
+        return self.word_embeddings.weight
+
+    def forward(self, args):
+        x1, x2 = args
+        x_emb = self.word_embeddings(x1)
+        return x_emb, x2
+
+
+class MatmulNet(Layer):
+    def __init__(self):
+        super().__init__()
+        self.softmax_weight = self.create_parameter(
+            shape=[hidden_size, vocab_size]
+        )
+
+    def forward(self, args):
+        x1, x2 = args
+        fc = paddle.matmul(x1, self.softmax_weight)
+
+        return fc, x2
+
+
+class BiasNet(Layer):
+    def __init__(self):
+        super().__init__()
+        self.softmax_bias = self.create_parameter(shape=[vocab_size])
+
+    def forward(self, args):
+        fc, x2 = args
+        fc = paddle.add(fc, self.softmax_bias)
+        projection = paddle.reshape(fc, shape=[-1, vocab_size])
+        return projection, x2
+
+
+class LossNet(Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, args, y1):
+        projection = args
+        loss = paddle.nn.functional.softmax_with_cross_entropy(
+            logits=projection, label=y1[0], soft_label=False
+        )
+        return loss.mean()
+
+
+@dataclass
+class SimpleNetSpec:
+    word_embeddings: LayerSpec
+    matmul_net: LayerSpec
+    bias_net: LayerSpec
+
+
+class SimpleNet(PipelineLayer):
+    def __init__(self, sublayers_spec: SimpleNetSpec, **kwargs):
+        self.layers = SimpleNet.get_layer_desc_list(sublayers_spec)
+
+        super().__init__(layers=self.layers, **kwargs)
+
+    @staticmethod
+    def get_layer_desc_list(spec: SimpleNetSpec):
+        def _logits_helper(embedding, output):
+            return paddle.matmul(output[0], embedding.embedding_weight)
+
+        layers = [
+            SharedLayerDesc(
+                "embed",
+                spec.word_embeddings,
+                shared_weight_attr="embedding_weight",
+            ),
+            LayerDesc(spec.matmul_net),
+            LayerDesc(spec.bias_net),
+            SharedLayerDesc(
+                "embed",
+                spec.word_embeddings,
+                forward_func=_logits_helper,
+                shared_weight_attr="embedding_weight",
+            ),
+        ]
+        return layers
+
+
+def get_simple_net_spec():
+    spec = LayerSpec(
+        layer=SimpleNet,
+        sublayers_spec=SimpleNetSpec(
+            word_embeddings=LayerSpec(layer=EmbeddingPipe),
+            matmul_net=LayerSpec(layer=MatmulNet),
+            bias_net=LayerSpec(layer=BiasNet),
+        ),
+        extra_kwargs={
+            "loss_fn": LossNet(),
+        },
+    )
+    return spec
+
+
+class TestDistEmbeddingTraining(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 1
+        self.data_parallel_size = 1
+        self.pipeline_parallel_size = 2
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": self.pipeline_parallel_size,
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": batch_size // micro_batch_size,
+            "micro_batch_size": micro_batch_size,
+        }
+        strategy.hybrid_configs["pp_configs"].clear_every_step_cache = True
+        self.strategy = strategy
+
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def test_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        # construct model a
+        model_a = SimpleNetBase()
+        scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True
+        )
+        optimizer_a = paddle.optimizer.SGD(
+            learning_rate=scheduler_a, parameters=model_a.parameters()
+        )
+
+        simple_net_spec = get_simple_net_spec()
+        model_b = build_layer(simple_net_spec, topology=hcg.topology())
+
+        scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True
+        )
+        optimizer_b = paddle.optimizer.SGD(
+            learning_rate=scheduler_b, parameters=model_b.parameters()
+        )
+        model_b = PipelineParallel(model_b, hcg, self.strategy)
+        optimizer_b = fleet.distributed_optimizer(optimizer_b)
+
+        parameters = []
+        for param in model_a.parameters():
+            parameters.append(param.numpy())
+
+        model_b_params = model_b.parameters()
+
+        if pp_id == 0:
+            model_b_params[0].set_value(parameters[2])
+            model_b_params[1].set_value(parameters[0])
+
+        else:
+            model_b_params[0].set_value(parameters[2])
+            model_b_params[1].set_value(parameters[1])
+
+        # enable this test when simple pp is ready
+        return
+
+        for step in range(5):
+            x1_data = np.random.randint(0, vocab_size, size=[batch_size, 1])
+            x2_data = np.random.randint(0, vocab_size, size=[batch_size, 1])
+            y1_data = np.random.randint(0, hidden_size, size=[batch_size, 1])
+
+            x1 = paddle.to_tensor(x1_data)
+            x2 = paddle.to_tensor(x2_data)
+            y1 = paddle.to_tensor(y1_data)
+
+            x1.stop_gradient = True
+            x2.stop_gradient = True
+            y1.stop_gradient = True
+
+            loss_a = model_a(x1, x2, y1)
+            loss_a.backward()
+
+            optimizer_a.step()
+            optimizer_a.clear_grad()
+            scheduler_a.step()
+
+            loss_b = model_b.train_batch(
+                [(x1, x2), (y1,)], optimizer_b, scheduler_b
+            )
+
+            print("loss", loss_a.numpy(), loss_b.numpy())
+            np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
+
+
+class TestDistEmbeddingTrainingWithSync(TestDistEmbeddingTraining):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 1
+        self.data_parallel_size = 1
+        self.pipeline_parallel_size = 2
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": self.pipeline_parallel_size,
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": batch_size // micro_batch_size,
+            "micro_batch_size": micro_batch_size,
+        }
+        strategy.hybrid_configs["pp_configs"].clear_every_step_cache = True
+        strategy.hybrid_configs["pp_configs"].sync_moment = True
+        strategy.hybrid_configs["pp_configs"].sync_param = True
+        self.strategy = strategy
+
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def test_pp_model(self):
+        super().test_pp_model()
+
+
+if __name__ == "__main__":
+    unittest.main()
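
The heart of this test is the shared weight: both SharedLayerDesc entries use the key "embed", so the embedding lookup on the first pipeline stage and the output projection on the last stage are backed by a single parameter, the same tying SimpleNetBase performs in one process by reusing word_embeddings.weight. A minimal single-process illustration of that tying in plain Paddle (a toy sketch, not the paddlefleet API):

# One embedding matrix is used both to embed tokens and to produce logits, so a
# single parameter accumulates gradients from both uses - this is what the
# pipeline test splits across the first and last stage via SharedLayerDesc("embed", ...).
import paddle
from paddle import nn

emb = nn.Embedding(128, 16)  # vocab_size=128, hidden_size=16, matching the test
tokens = paddle.randint(0, 128, shape=[8, 1])

hidden = emb(tokens)  # first use of emb.weight: lookup
logits = paddle.matmul(hidden, emb.weight, transpose_y=True)  # second use: projection

loss = logits.mean()
loss.backward()
# emb.weight.grad holds contributions from both the lookup and the projection,
# because there is only one underlying tensor.
print(emb.weight.grad.shape)  # [128, 16]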

tests/test_configs.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ tests:
   - test_case: [tests/multi_card_tests/tensor_parallel/*.py]
     products:
       - num_gpus: 4
-  - test_case: [tests/multi_card_tests/pipeline_parallel/test_pp_layer.py]
+  - test_case: [tests/multi_card_tests/pipeline_parallel/*.py]
     products:
       - num_gpus: 2
   - test_case: [tests/test/*.py]
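
This widens the two-GPU test product from the single test_pp_layer.py case to every script under tests/multi_card_tests/pipeline_parallel/, so the new shared-weight test is picked up on two ranks. For a local run, an invocation along these lines should work (the new test's filename is not shown on this page, and older Paddle releases spell the device flag --gpus instead of --devices):

python -m paddle.distributed.launch --devices 0,1 tests/multi_card_tests/pipeline_parallel/<new_test_file>.py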
