Commit 2831c3a

update blockedfp8 scale name (#3532)
* update blockedfp8 scale name
* fix lint
* fix deepseek _load_weight_attention
1 parent 421b113 commit 2831c3a

6 files changed: +44 −54

lmdeploy/pytorch/models/deepseek_v2.py  (+4 −7)

@@ -1273,8 +1273,8 @@ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
             if name.endswith('.weight'):
                 weight_name = name
                 scale_name = name.replace('.weight', '.scale')
-            elif name.endswith('.scale'):
-                weight_name = name.replace('.scale', '.weight')
+            elif name.endswith('.weight_scale_inv'):
+                weight_name = name.replace('.weight_scale_inv', '.weight')
                 scale_name = name
             self._load_buffers[name] = loaded_weight
             if (weight_name in self._load_buffers and scale_name in self._load_buffers):
@@ -1288,7 +1288,7 @@ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
         for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:
             if mod_name not in name:
                 continue
-            if name.endswith('.scale'):
+            if name.endswith('.weight_scale_inv'):
                 weight = loaded_weight
             else:
                 loaded_weight = loaded_weight.to(device)
@@ -1328,8 +1328,6 @@ def __skip_nextn(name, nextn_keys):
             ('.gate_up_proj', '.up_proj', 1),
         ]

-        scale_suffix = '.weight_scale_inv'
-
         config = self.config

         update_pe_mapping = []
@@ -1375,8 +1373,7 @@ def __skip_nextn(name, nextn_keys):
                 continue
             if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                 continue
-            if name.endswith(scale_suffix):
-                name = name[:-len(scale_suffix)] + '.scale'
+
             if '.experts' in name:
                 self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
             elif '.self_attn' in name and getattr(config, 'use_mla', True):

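The first hunk above buffers each blocked-FP8 weight until its matching scale tensor arrives (and vice versa), keyed purely by the `.weight` / `.weight_scale_inv` suffixes. Below is a minimal standalone sketch of that pairing pattern; `pair_blocked_fp8` and the toy tensor names are illustrative, not part of lmdeploy.

import torch

def pair_blocked_fp8(weights):
    """Yield (weight_name, weight, scale) once both halves of a blocked-FP8
    pair ('.weight' and '.weight_scale_inv') have been seen, in any order."""
    buffers = {}
    for name, tensor in weights:
        if name.endswith('.weight'):
            weight_name = name
            scale_name = name.replace('.weight', '.weight_scale_inv')
        elif name.endswith('.weight_scale_inv'):
            weight_name = name.replace('.weight_scale_inv', '.weight')
            scale_name = name
        else:
            continue
        buffers[name] = tensor
        if weight_name in buffers and scale_name in buffers:
            yield weight_name, buffers.pop(weight_name), buffers.pop(scale_name)

# Toy usage: the scale arrives before its weight, as checkpoint iteration may do.
ckpt = [
    ('layers.0.self_attn.kv_b_proj.weight_scale_inv', torch.ones(4, 4)),
    ('layers.0.self_attn.kv_b_proj.weight', torch.zeros(512, 512)),
]
for wname, w, s in pair_blocked_fp8(ckpt):
    print(wname, tuple(w.shape), tuple(s.shape))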
lmdeploy/pytorch/models/internlm3.py  (−3)

@@ -403,7 +403,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ('.gate_up_proj', '.up_proj', 1),
         ]

-        scale_suffix = '.weight_scale_inv'
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if 'rotary_emb.inv_freq' in name:
@@ -412,8 +411,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 continue
             if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                 continue
-            if name.endswith(scale_suffix):
-                name = name[:-len(scale_suffix)] + '.scale'

             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:

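The lines removed here were a shim that rewrote checkpoint keys ending in `.weight_scale_inv` to the module's former `.scale` parameter name. With the parameter now registered as `weight_scale_inv` (see `lmdeploy/pytorch/nn/linear.py` below), checkpoint keys match `named_parameters()` directly. A toy sketch of that direct match; `BlockedFP8Linear` is a stand-in module, not lmdeploy's class, and real blocked-FP8 weights would use an fp8 dtype rather than the float32 used here to stay runnable anywhere.

import torch

class BlockedFP8Linear(torch.nn.Module):
    """Stand-in module whose parameter names mirror blocked-FP8 checkpoint keys."""

    def __init__(self, out_features=256, in_features=256, block_size=128):
        super().__init__()
        # float32 placeholders; the layout is what matters for name matching
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        self.weight_scale_inv = torch.nn.Parameter(
            torch.empty(out_features // block_size, in_features // block_size), requires_grad=False)

mod = BlockedFP8Linear()
params_dict = dict(mod.named_parameters())

# Blocked-FP8 checkpoints name their tensors '<prefix>.weight' and
# '<prefix>.weight_scale_inv'; both now hit params_dict without any rename.
for key in ('weight', 'weight_scale_inv'):
    assert key in params_dict
print(sorted(params_dict))  # ['weight', 'weight_scale_inv']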
lmdeploy/pytorch/models/qwen3.py  (−3)

@@ -403,7 +403,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ('.gate_up_proj', '.up_proj', 1),
         ]

-        scale_suffix = '.weight_scale_inv'
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if 'rotary_emb.inv_freq' in name:
@@ -412,8 +411,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 continue
             if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                 continue
-            if name.endswith(scale_suffix):
-                name = name[:-len(scale_suffix)] + '.scale'

             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:

lmdeploy/pytorch/models/qwen3_moe.py  (−3)

@@ -495,7 +495,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
             expert_params_mapping += [gate_param, up_param, down_param]

-        scale_suffix = '.weight_scale_inv'
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if 'rotary_emb.inv_freq' in name:
@@ -504,8 +503,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 continue
             if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                 continue
-            if name.endswith(scale_suffix):
-                name = name[:-len(scale_suffix)] + '.scale'

             if '.experts' in name:
                 self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)

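The context above builds `expert_params_mapping` tuples such as `('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')`. Below is a hedged sketch of how such tuples can route per-expert checkpoint keys onto a fused parameter key; `route_expert_key` and the gate/up entry names are assumptions for illustration, not lmdeploy's `_load_weight_experts`.

def route_expert_key(name, expert_params_mapping):
    """Map a per-expert checkpoint key to (fused_param_key, expert_id, shard_id),
    or return None when the key is not an expert weight."""
    for fused_name, ckpt_name, expert_id, shard_id in expert_params_mapping:
        if ckpt_name in name:
            return name.replace(ckpt_name, fused_name), expert_id, shard_id
    return None

num_experts = 2
expert_params_mapping = []
for exp_id in range(num_experts):
    expert_params_mapping += [
        ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate'),
        ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up'),
        ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down'),
    ]

print(route_expert_key('model.layers.0.mlp.experts.1.down_proj.weight', expert_params_mapping))
# ('model.layers.0.mlp.experts.down.weight', 1, 'down')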
lmdeploy/pytorch/nn/linear.py  (+18 −18)

@@ -241,16 +241,16 @@ def __init__(
         self.impl = impl_builder.build(in_features, out_features, block_size=128, bias=bias is not None, dtype=dtype)
         self.block_size = 128
         self.fp8_dtype = fp8_dtype
-        weight, scale, bias = self.create_weights(in_features, out_features, bias, dtype, device)
+        weight, weight_scale_inv, bias = self.create_weights(in_features, out_features, bias, dtype, device)
         weight = torch.nn.Parameter(weight, requires_grad=False)
         weight.weight_loader = self.weight_loader
-        scale = torch.nn.Parameter(scale, requires_grad=False)
-        scale.weight_loader = self.weight_loader
+        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
+        weight_scale_inv.weight_loader = self.weight_loader
         if bias is not None:
             bias = torch.nn.Parameter(bias, requires_grad=False)
             bias.weight_loader = self.weight_loader
         self.register_parameter('weight', weight)
-        self.register_parameter('scale', scale)
+        self.register_parameter('weight_scale_inv', weight_scale_inv)
         self.register_parameter('bias', bias)

         self.in_features = in_features
@@ -302,27 +302,27 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
     def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device):
         """create weights."""
         weight = torch.empty((out_features, in_features), dtype=self.fp8_dtype, device=device)
-        scale = torch.empty((div_up(out_features, self.block_size), div_up(in_features, self.block_size)),
-                            dtype=torch.float32,
-                            device=device)
+        weight_scale_inv = torch.empty((div_up(out_features, self.block_size), div_up(in_features, self.block_size)),
+                                       dtype=torch.float32,
+                                       device=device)
         if bias:
             bias = torch.empty((out_features, ), dtype=dtype, device=device)
         else:
             bias = None
-        return weight, scale, bias
+        return weight, weight_scale_inv, bias

     def update_weights(self):
         """update weights."""
-        weight, scale, bias = self.impl.update_weights(self.weight, self.scale, self.bias)
+        weight, weight_scale_inv, bias = self.impl.update_weights(self.weight, self.weight_scale_inv, self.bias)
         weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight.weight_loader = self.weight_loader
-        scale = torch.nn.Parameter(scale, requires_grad=False)
-        self.scale.weight_loader = self.weight_loader
+        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
+        self.weight_scale_inv.weight_loader = self.weight_loader
         if bias is not None:
             bias = torch.nn.Parameter(bias, requires_grad=False)
             self.bias.weight_loader = self.weight_loader
         self.register_parameter('weight', weight)
-        self.register_parameter('scale', scale)
+        self.register_parameter('weight_scale_inv', weight_scale_inv)
         self.register_parameter('bias', bias)

     def forward(self, x):
@@ -340,11 +340,11 @@ def forward(self, x):
         if len(self.lora_adapters) == 0:
             if self.dp_scatter:
                 _, rank = get_tp_world_rank()
-                return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce, rank, tp_sizes)
+                return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce, rank, tp_sizes)
             else:
-                return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce)
+                return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce)

-        out = self.impl.forward(x, self.weight, self.scale, self.bias, False)
+        out = self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, False)
         for lora_adapter in self.lora_adapters.values():
             out = lora_adapter(x, out)
         if all_reduce:
@@ -394,10 +394,10 @@ def __init__(self,
                          dp_gather=dp_gather)
         self.weight.weight_loader = self.weight_loader
         self.weight._weight_type = 'qweight'
-        self.scale.weight_loader = self.weight_loader
-        self.scale._weight_type = 'scales'
+        self.weight_scale_inv.weight_loader = self.weight_loader
+        self.weight_scale_inv._weight_type = 'scales'
         self.weight.weight_spliter = self.weight_spliter
-        self.scale.weight_spliter = self.weight_spliter
+        self.weight_scale_inv.weight_spliter = self.weight_spliter
         if self.bias is not None:
             self.bias.weight_loader = self.weight_loader
             self.bias.weight_spliter = self.weight_spliter

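The renamed `weight_scale_inv` in `create_weights` keeps the same blocked layout: one float32 scale per block_size x block_size tile, shape `(div_up(out_features, 128), div_up(in_features, 128))`. Below is a reference sketch of that shape and of a plain-PyTorch blocked dequantization, assuming the DeepSeek-style convention that the stored tensor is the dequantization multiplier; the sizes are examples and `self.impl`'s kernels are not involved.

import torch

def div_up(a: int, b: int) -> int:
    """Ceiling division, as used for the blocked scale shape."""
    return (a + b - 1) // b

out_features, in_features, block_size = 4096, 11008, 128   # example sizes only
scale_shape = (div_up(out_features, block_size), div_up(in_features, block_size))
print(scale_shape)  # (32, 86)

# Reference dequantization: broadcast each tile's scale over its 128x128 block.
weight_q = torch.randn(out_features, in_features)           # stand-in for the fp8 payload
weight_scale_inv = torch.rand(*scale_shape, dtype=torch.float32) + 0.5
scales = weight_scale_inv.repeat_interleave(block_size, dim=0)[:out_features]
scales = scales.repeat_interleave(block_size, dim=1)[:, :in_features]
weight_fp32 = weight_q * scales
print(weight_fp32.shape)  # torch.Size([4096, 11008])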
lmdeploy/pytorch/nn/moe.py  (+22 −20)

@@ -421,25 +421,25 @@ def __init__(self,
                          ep=ep,
                          )
         self.block_size = block_size
-        scale = torch.empty((num_experts, div_up(out_features, block_size), div_up(in_features, block_size)),
-                            dtype=torch.float32,
-                            device=device)
-        scale = torch.nn.Parameter(scale, requires_grad=False)
-        self.register_parameter('scale', scale)
+        weight_scale_inv = torch.empty((num_experts, div_up(out_features, block_size), div_up(in_features, block_size)),
+                                       dtype=torch.float32,
+                                       device=device)
+        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
+        self.register_parameter('weight_scale_inv', weight_scale_inv)

         if self.ep:
             self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list))
-            self.scale.weight_loader = self.weight_loader_scale_ep
+            self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep
         else:
-            self.scale.weight_loader = self.weight_loader_scale_tp
+            self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp

-    def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):
+    def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor):
         """update weight."""
         super().update_weight(weight=weight)
-        weight_loader = self.scale.weight_loader
-        scale = torch.nn.Parameter(scale, requires_grad=False)
-        scale.weight_loader = weight_loader
-        self.register_parameter('scale', scale)
+        weight_loader = self.weight_scale_inv.weight_loader
+        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
+        weight_scale_inv.weight_loader = weight_loader
+        self.register_parameter('weight_scale_inv', weight_scale_inv)

     def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                                shard_id: str):
@@ -545,8 +545,8 @@ def __init__(self,
     def update_weights(self):
         """update weights."""
         (gate_up_weights, down_weights, gate_up_scale,
-         down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.scale,
-                                                self.down.scale)
+         down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.weight_scale_inv,
+                                                self.down.weight_scale_inv)
         self.gate_up.update_weight(gate_up_weights, gate_up_scale)
         self.down.update_weight(down_weights, down_scale)

@@ -628,8 +628,9 @@ def gemm(self, state: Dict):
         if moe_type == MoeType.DSAsyncPrefill:
             if state['recv_hidden_states'].shape[0] > 0:
                 state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,
-                                                                                 self.gate_up.scale, self.down.weight,
-                                                                                 self.down.scale)
+                                                                                 self.gate_up.weight_scale_inv,
+                                                                                 self.down.weight,
+                                                                                 self.down.weight_scale_inv)
             gemm_state = {
                 'fusedmoe': state['fusedmoe'],
                 'hidden_states': state['recv_hidden_states'],
@@ -638,8 +639,9 @@ def gemm(self, state: Dict):
             }
         elif moe_type == MoeType.DSAsyncDecode:
             state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,
-                                                                             self.gate_up.scale, self.down.weight,
-                                                                             self.down.scale)
+                                                                             self.gate_up.weight_scale_inv,
+                                                                             self.down.weight,
+                                                                             self.down.weight_scale_inv)
             gemm_state = {
                 'fusedmoe': state['fusedmoe'],
                 'hidden_states': state['recv_hidden_states'],
@@ -650,8 +652,8 @@ def gemm(self, state: Dict):
             }
         else:  # MoeType.Default
             hidden_states = self.impl.forward(state['hidden_states'], state['topk_weights'], state['topk_idx'],
-                                              self.gate_up.weight, self.gate_up.scale, self.down.weight,
-                                              self.down.scale, self.expert_list)
+                                              self.gate_up.weight, self.gate_up.weight_scale_inv, self.down.weight,
+                                              self.down.weight_scale_inv, self.expert_list)
             gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']}
         return gemm_state

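`update_weight` above wraps the new scale in a fresh `torch.nn.Parameter`, carries over the previous `weight_loader` hook, and re-registers it under the same `weight_scale_inv` name. Below is a minimal sketch of that re-registration pattern on a toy module; `ToyMoeScale` and its loader hook are illustrative only, not lmdeploy's classes.

import torch

class ToyMoeScale(torch.nn.Module):
    """Toy module that registers a blocked scale the way the diff above does."""

    def __init__(self):
        super().__init__()
        p = torch.nn.Parameter(torch.zeros(2, 2), requires_grad=False)
        p.weight_loader = lambda param, loaded: param.data.copy_(loaded)  # illustrative hook
        self.register_parameter('weight_scale_inv', p)

    def update_weight(self, new_scale: torch.Tensor):
        """Re-register the scale under the same name, keeping the loader hook."""
        weight_loader = self.weight_scale_inv.weight_loader
        weight_scale_inv = torch.nn.Parameter(new_scale, requires_grad=False)
        weight_scale_inv.weight_loader = weight_loader
        self.register_parameter('weight_scale_inv', weight_scale_inv)

mod = ToyMoeScale()
mod.update_weight(torch.ones(2, 2))
mod.weight_scale_inv.weight_loader(mod.weight_scale_inv, torch.full((2, 2), 3.0))
print(mod.weight_scale_inv.sum())  # tensor(12.)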