Merged
3 changes: 2 additions & 1 deletion src/paddlefleet/pipeline_parallel/__init__.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .pipeline_parallel import NoPipelineParallel
from .pipeline_parallel import NoPipelineParallel, PipelineParallel
from .pp_layers import LayerDesc, PipelineLayer, SharedLayerDesc

__all__ = [
"LayerDesc",
"SharedLayerDesc",
"PipelineLayer",
"NoPipelineParallel",
"PipelineParallel",
]
11 changes: 9 additions & 2 deletions src/paddlefleet/pipeline_parallel/pipeline_parallel.py
Expand Up @@ -15,9 +15,12 @@
import paddle
from paddle import nn

from .pp_layers import PipelineLayer


class NoPipelineParallel(nn.Layer):
def __init__(self, layers, strategy):
assert isinstance(layers, PipelineLayer)
super().__init__()
self._layers = layers
self._strategy = strategy
@@ -41,8 +44,12 @@ def forward_backward_pipeline(


class PipelineParallel(nn.Layer):
def __init__(self, layer, hcg, strategy):
pass
def __init__(self, layers, hcg, strategy):
assert isinstance(layers, PipelineLayer)
super().__init__()
self._layers = layers
self._strategy = strategy
self._hcg = hcg

def forward(self, data, scaler=None, return_micro_batch_loss=False):
pass
300 changes: 300 additions & 0 deletions tests/multi_card_tests/pipeline_parallel/test_pp_with_shared_weight.py
@@ -0,0 +1,300 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import unittest
from dataclasses import dataclass

import numpy as np

os.environ["FLAGS_profile_optimizer_details_steps"] = "1"
import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed import fleet
from paddle.nn import Layer

from paddlefleet.pipeline_parallel import (
LayerDesc,
PipelineLayer,
PipelineParallel,
SharedLayerDesc,
)
from paddlefleet.spec_utils import LayerSpec, build_layer


def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducibility."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)


batch_size = 8
micro_batch_size = 2
vocab_size = 128
hidden_size = 16


class SimpleNetBase(Layer):
def __init__(self):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

self.softmax_weight = self.create_parameter(
shape=[hidden_size, vocab_size]
)
self.softmax_bias = self.create_parameter(
shape=[vocab_size], is_bias=False
)

def forward(self, x1, x2, y1):
x_emb = self.word_embeddings(x1)
fc = paddle.matmul(x_emb, self.softmax_weight)
fc = paddle.add(fc, self.softmax_bias)
projection = paddle.reshape(fc, shape=[-1, vocab_size])

projection = paddle.matmul(projection, self.word_embeddings.weight)

loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=projection, label=y1, soft_label=False
)
return loss.mean()


class EmbeddingPipe(Layer):
def __init__(self):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

@property
def embedding_weight(self):
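# Exposed so SharedLayerDesc(shared_weight_attr="embedding_weight") can tie this weight across pipeline stages.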
return self.word_embeddings.weight

def forward(self, args):
x1, x2 = args
x_emb = self.word_embeddings(x1)
return x_emb, x2


class MatmulNet(Layer):
def __init__(self):
super().__init__()
self.softmax_weight = self.create_parameter(
shape=[hidden_size, vocab_size]
)

def forward(self, args):
x1, x2 = args
fc = paddle.matmul(x1, self.softmax_weight)

return fc, x2


class BiasNet(Layer):
def __init__(self):
super().__init__()
self.softmax_bias = self.create_parameter(shape=[vocab_size])

def forward(self, args):
fc, x2 = args
fc = paddle.add(fc, self.softmax_bias)
projection = paddle.reshape(fc, shape=[-1, vocab_size])
return projection, x2


class LossNet(Layer):
def __init__(self):
super().__init__()

def forward(self, args, y1):
projection = args
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=projection, label=y1[0], soft_label=False
)
return loss.mean()


@dataclass
class SimpleNetSpec:
word_embeddings: LayerSpec
matmul_net: LayerSpec
bias_net: LayerSpec


class SimpleNet(PipelineLayer):
def __init__(self, sublayers_spec: SimpleNetSpec, **kwargs):
self.layers = SimpleNet.get_layer_desc_list(sublayers_spec)

super().__init__(layers=self.layers, **kwargs)

@staticmethod
def get_layer_desc_list(spec: SimpleNetSpec):
def _logits_helper(embedding, output):
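# Weight tying: reuse the shared embedding weight as the output projection on the last stage.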
return paddle.matmul(output[0], embedding.embedding_weight)

layers = [
SharedLayerDesc(
"embed",
spec.word_embeddings,
shared_weight_attr="embedding_weight",
),
LayerDesc(spec.matmul_net),
LayerDesc(spec.bias_net),
SharedLayerDesc(
"embed",
spec.word_embeddings,
forward_func=_logits_helper,
shared_weight_attr="embedding_weight",
),
]
return layers


def get_simple_net_spec():
spec = LayerSpec(
layer=SimpleNet,
sublayers_spec=SimpleNetSpec(
word_embeddings=LayerSpec(layer=EmbeddingPipe),
matmul_net=LayerSpec(layer=MatmulNet),
bias_net=LayerSpec(layer=BiasNet),
),
extra_kwargs={
"loss_fn": LossNet(),
},
)
return spec


class TestDistEmbeddingTraining(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
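# Each optimizer step accumulates batch_size // micro_batch_size = 4 micro-batches through the pipeline.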
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size,
}
strategy.hybrid_configs["pp_configs"].clear_every_step_cache = True
self.strategy = strategy

fleet.init(is_collective=True, strategy=strategy)

def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)

# construct model_a, the non-pipelined reference model
model_a = SimpleNetBase()
scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True
)
optimizer_a = paddle.optimizer.SGD(
learning_rate=scheduler_a, parameters=model_a.parameters()
)

simple_net_spec = get_simple_net_spec()
model_b = build_layer(simple_net_spec, topology=hcg.topology())

scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True
)
optimizer_b = paddle.optimizer.SGD(
learning_rate=scheduler_b, parameters=model_b.parameters()
)
model_b = PipelineParallel(model_b, hcg, self.strategy)
optimizer_b = fleet.distributed_optimizer(optimizer_b)

parameters = []
for param in model_a.parameters():
parameters.append(param.numpy())

model_b_params = model_b.parameters()
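# Copy model_a's weights into model_b's pipeline stages so both models start from identical parameters on every rank.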

if pp_id == 0:
model_b_params[0].set_value(parameters[2])
model_b_params[1].set_value(parameters[0])

else:
model_b_params[0].set_value(parameters[2])
model_b_params[1].set_value(parameters[1])

# enable this test when simple pp is ready
return

for step in range(5):
x1_data = np.random.randint(0, vocab_size, size=[batch_size, 1])
x2_data = np.random.randint(0, vocab_size, size=[batch_size, 1])
y1_data = np.random.randint(0, hidden_size, size=[batch_size, 1])

x1 = paddle.to_tensor(x1_data)
x2 = paddle.to_tensor(x2_data)
y1 = paddle.to_tensor(y1_data)

x1.stop_gradient = True
x2.stop_gradient = True
y1.stop_gradient = True

loss_a = model_a(x1, x2, y1)
loss_a.backward()

optimizer_a.step()
optimizer_a.clear_grad()
scheduler_a.step()

loss_b = model_b.train_batch(
[(x1, x2), (y1,)], optimizer_b, scheduler_b
)

print("loss", loss_a.numpy(), loss_b.numpy())
np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())


class TestDistEmbeddingTrainingWithSync(TestDistEmbeddingTraining):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size,
}
strategy.hybrid_configs["pp_configs"].clear_every_step_cache = True
strategy.hybrid_configs["pp_configs"].sync_moment = True
strategy.hybrid_configs["pp_configs"].sync_param = True
self.strategy = strategy

fleet.init(is_collective=True, strategy=strategy)

def test_pp_model(self):
super().test_pp_model()


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_configs.yaml
@@ -4,7 +4,7 @@ tests:
- test_case: [tests/multi_card_tests/tensor_parallel/*.py]
products:
- num_gpus: 4
- test_case: [tests/multi_card_tests/pipeline_parallel/test_pp_layer.py]
- test_case: [tests/multi_card_tests/pipeline_parallel/*.py]
products:
- num_gpus: 2
- test_case: [tests/test/*.py]