Skip to content

Commit 4fa6b95

Browse files
authored
[moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)
1 parent 63314ce commit 4fa6b95

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

colossalai/shardformer/modeling/deepseek.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ def setup_process_groups(
109109
for p in self.experts.parameters():
110110
set_moe_tensor_ep_group(p, ep_group)
111111

112+
if self.config.n_shared_experts is not None:
113+
self.shared_experts.gate_proj = Linear1D_Col.from_native_module(
114+
self.shared_experts.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
115+
)
116+
117+
self.shared_experts.up_proj = Linear1D_Col.from_native_module(
118+
self.shared_experts.up_proj, self.tp_group, fp8_communication=self.fp8_communication
119+
)
120+
121+
self.shared_experts.down_proj = Linear1D_Row.from_native_module(
122+
self.shared_experts.down_proj, self.tp_group, fp8_communication=self.fp8_communication
123+
)
124+
112125
@staticmethod
113126
def from_native_module(
114127
module,

tests/test_shardformer/test_model/test_shard_deepseek.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
NUM_BATCH = 8
2121
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
2222
NUM_LAYERS = 4
23-
HIDDEN_SIZE_PER_HEAD = 4
23+
HIDDEN_SIZE_PER_HEAD = 8
2424
NUM_HEADS = 8
2525
TOP_K = 2
2626

2727

28-
def run_deepseek_commom(config: Tuple[int, ...]):
28+
def run_deepseek_commom(parallel_config: Tuple[int, ...]):
2929
Randomizer.reset_index()
30-
stage, ep_size, pp_size, tp_size, sp_size = config
30+
print(f"rank {dist.get_rank()} testing {parallel_config}")
31+
stage, ep_size, pp_size, tp_size, sp_size = parallel_config
3132
world_size = dist.get_world_size()
3233
rank = dist.get_rank()
3334
dtype, precision = torch.bfloat16, "bf16"
@@ -65,6 +66,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
6566
attn_implementation="flash_attention_2",
6667
torch_dtype="float16",
6768
n_routed_experts=NUM_EXPERTS,
69+
n_shared_experts=2,
6870
num_experts_per_tok=TOP_K,
6971
trust_remote_code=True,
7072
)
@@ -159,7 +161,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
159161
if rank == world_size - 1:
160162
shutil.rmtree(model_dir)
161163

162-
print(f"rank {dist.get_rank()} test passed")
164+
print(f"rank {dist.get_rank()} passed {parallel_config}")
163165

164166

165167
@parameterize(

0 commit comments

Comments
 (0)