
Commit e28be90

[feature] add deepseekv2 edp support
1 parent 1c2abf7 commit e28be90

File tree

13 files changed: +497 -53 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ def __init__(self, kvargs):
         self.quant_type = kvargs.get("quant_type", None)
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.mem_fraction = kvargs.get("mem_fraction", 0.9)
+        self.expert_parallel_mode = kvargs.get("expert_parallel_mode", "etp")
 
         self._init_datatype()
         self._init_config()
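The new kvarg defaults to "etp", so existing deployments keep the expert-tensor-parallel behavior; passing "edp" opts into the expert-data-parallel path this commit adds. A minimal init sketch (the other kvargs and the model class name are illustrative assumptions, not taken from this diff):

    # Hypothetical sketch: only expert_parallel_mode is new in this commit.
    kvargs = {
        "weight_dir": "/path/to/deepseek-v2-weights",  # illustrative path
        "mem_fraction": 0.9,
        "expert_parallel_mode": "edp",  # "etp" (default) or "edp"
    }
    model = Deepseek2TpPartModel(kvargs)  # assumed DeepSeek-V2 model class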

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 8 additions & 7 deletions
@@ -11,7 +11,7 @@
 class TransformerLayerInferTpl(TransformerLayerInfer):
     """ """
 
-    def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
+    def __init__(self, layer_num, tp_rank, world_size, network_config, mode, tp_split=True):
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         # need to set by subclass
         self.eps_ = 1e-5
@@ -21,6 +21,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
         self.tp_o_head_num_ = -1
         self.head_dim_ = -1
         self.embed_dim_ = -1
+        self.tp_split_ = tp_split
         return
 
     def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
@@ -79,7 +80,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_w
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -88,7 +89,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -102,7 +103,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_wei
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -111,7 +112,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -125,7 +126,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateIn
         o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -134,7 +135,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, l
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
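Every hunk in this file applies the same gate: the cross-rank all-reduce that follows attention and FFN now runs only when the layer's weights are actually split tensor-parallel. When tp_split=False, each rank already computes the complete output locally, so an all-reduce would sum identical copies and corrupt the result. A standalone sketch of the pattern, assuming an initialized torch.distributed process group:

    import torch
    import torch.distributed as dist

    def reduce_partial_output(out: torch.Tensor, world_size: int, tp_split: bool) -> torch.Tensor:
        # TP-split layers hold partial sums that must be combined across ranks;
        # an unsplit layer already has the full result, so reducing here
        # would double-count it.
        if world_size > 1 and tp_split:
            dist.all_reduce(out, op=dist.ReduceOp.SUM, async_op=False)
        return out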

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,10 @@
     MultiCOLMMWeight,
     ROWBMMWeight,
     COLBMMWeight,
+    MultiCOLMMWeightNoTp,
+    ROWBMMWeightNoTp,
+    COLBMMWeightNoTp,
+    COLMMWeightNoTp
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight import FusedMoeWeight
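These re-exports make the new NoTp variants importable from the meta_weights package alongside the existing TP classes, e.g. (a sketch):

    from lightllm.common.basemodel.layer_weights.meta_weights import (
        COLMMWeightNoTp,
        MultiCOLMMWeightNoTp,
        ROWBMMWeightNoTp,
        COLBMMWeightNoTp,
    )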

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 4 additions & 3 deletions
@@ -15,7 +15,7 @@
 
 class FusedMoeWeight(BaseWeight):
     def __init__(
-        self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type
+        self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type, expert_parallel_mode="etp"
     ):
         super().__init__()
         assert HAS_VLLM, "vllm is not installed, you can't use FusedMoeWeight"
@@ -33,6 +33,7 @@ def __init__(
         self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
+        self.expert_parallel_mode = expert_parallel_mode
         self.lock = threading.Lock()
 
     def set_quant_method(self, quant_method):
@@ -159,7 +160,7 @@ def _load_hf_weights_etp(self, weights):
             self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]
 
     def load_hf_weights(self, weights):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp":
             self._load_hf_weights_etp(weights)
         else:
             for i_experts in range(self.n_routed_experts):
@@ -190,7 +191,7 @@ def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
 
     def verify_load(self):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp":
             return True
         else:
             return self.w1 is not None and self.w2 is not None
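Replacing the process-wide ETP_MODE_ENABLED environment variable with a constructor argument makes the mode per-instance state, and "edp" deliberately reuses the ETP weight-loading path. A hedged construction sketch (every name and size below is an illustrative placeholder, not a value from this commit):

    import torch

    moe_weight = FusedMoeWeight(
        gate_proj_name="gate_proj",  # placeholder projection names
        down_proj_name="down_proj",
        up_proj_name="up_proj",
        weight_prefix="model.layers.1.mlp.experts",  # placeholder prefix
        n_routed_experts=64,
        split_inter_size=1408,
        data_type=torch.bfloat16,
        expert_parallel_mode="edp",  # new argument; replaces ETP_MODE_ENABLED
    )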

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 64 additions & 0 deletions
@@ -96,6 +96,24 @@ def load_hf_weights(self, weights):
         self._post_load_weights()
         return
 
+class COLMMWeightNoTp(MMWeight):
+    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+    def load_hf_weights(self, weights):
+        weight = None
+        if self.weight_name in weights:
+            weight = weights[self.weight_name].to(self.data_type_)
+            self.weight = weight[:, self.start : self.end]
+        if self.bias_name in weights:
+            bias = weights[self.bias_name]
+            self.bias = bias.to(self.data_type_).cuda(self.tp_rank_)
+        if weight is None:
+            return
+        self._post_load_weights()
+        return
 
 class MultiMMWeight(MMWeightTpl):
     def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]):
@@ -172,6 +190,21 @@ def load_hf_weights(self, weights):
         self._fuse()
         return
 
+class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
+        super().__init__(weight_names, data_type, split_n_embed, bias_names)
+
+    def load_hf_weights(self, weights):
+        weight = None
+        for i in range(len(self.weight_names)):
+            if self.weight_names[i] in weights:
+                weight = weights[self.weight_names[i]].to(self.data_type_)
+                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
+            if self.has_bias and self.bias_names[i] in weights:
+                bias = weights[self.bias_names[i]].to(self.data_type_)
+                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
+        self._fuse()
+        return
 
 class BMMWeightTpl(BaseWeightTpl):
     def __init__(self, data_type):
@@ -233,6 +266,19 @@
     ):
         super().__init__(weight_name, data_type, split_n_embed, bias_name)
 
+class ROWBMMWeightNoTp(BMMWeight):
+    load_hf_weights = ROWMMWeight.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
 
 class COLBMMWeight(BMMWeight):
     load_hf_weights = COLMMWeight.load_hf_weights
@@ -248,3 +294,21 @@ def __init__(
 
     def _post_load_weights(self):
         self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+
+class COLBMMWeightNoTp(BMMWeight):
+    load_hf_weights = COLMMWeightNoTp.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+    def _post_load_weights(self):
+        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+
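All four NoTp classes share one trick: they keep the parents' slice-based loading but pin the slice to start=0, end=split_n_embed, so when the caller passes the full dimension each rank loads an identical, unsplit copy instead of its TP shard. A toy illustration of why the slice then degenerates to a no-op (shapes are made up):

    import torch

    split_n_embed = 8  # full width here, not width // world_size as in TP
    weight = torch.randn(4, split_n_embed)

    start, end = 0, split_n_embed  # the NoTp setting
    assert torch.equal(weight[:, start:end], weight)  # whole tensor, every rank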
