-
Notifications
You must be signed in to change notification settings - Fork 722
Expand file tree
/
Copy pathlinear.py
More file actions
1332 lines (1174 loc) · 54 KB
/
linear.py
File metadata and controls
1332 lines (1174 loc) · 54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import numpy as np
import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import (
decode_alltoall_transpose,
tensor_model_parallel_all_reduce,
)
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.utils import (
default_weight_loader,
h2d_copy,
process_weight_transpose,
set_weight_attrs,
slice_fn,
)
from fastdeploy.platforms import current_platform
from .utils import _set_var_distributed, divide, get_tensor, modules_to_convert
class UnquantizedLinearMethod(QuantMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        Create ``layer.weight`` (zero-initialised) and attach its loading attributes.

        extra_weight_attrs is a dictionary that may include parameters like:
        - split_axis: axis along which to split the tensor in a distributed environment
        - output_dim: determines whether the split is applied along the output dimension (rows) or
          input dimension (columns); a value of ``None`` is passed through unchanged
        - weight_loader: a callable or method responsible for loading the weight data
        - model_format: when equal to "torch", the parameter is created with ``layer.weight_shape`` reversed
        """
        self.model_format = extra_weight_attrs.get("model_format")
        is_torch_format = self.model_format == "torch"
        self.weight_shape = layer.weight_shape[::-1] if is_torch_format else layer.weight_shape

        layer.weight = layer.create_parameter(
            shape=self.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        # The parameter is created with the reversed shape for the "torch" format, so the
        # output_dim flag is inverted to match. Only real flags are inverted: callers may pass
        # output_dim=None, and `not None` would silently turn that into True.
        if is_torch_format and extra_weight_attrs.get("output_dim") is not None:
            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]

        set_weight_attrs(
            layer.weight,
            {
                **extra_weight_attrs,
                "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
            },
        )

    def process_weights_after_loading(self, layer):
        """Run ``process_weight_transpose`` on ``layer``'s "weight" when the model format is "torch"."""
        if self.model_format == "torch":
            process_weight_transpose(layer, "weight")

    def process_loaded_weights(self, layer, weights) -> None:
        """Copy ``weights`` into ``layer.weight``, casting to the parameter dtype when they differ."""
        # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
        if layer.weight.dtype != weights.dtype:
            weights = weights.cast(layer.weight.dtype)
        layer.weight.set_value(weights)

    def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
        """Return ``x @ layer.weight``, plus ``layer.bias`` when ``layer.with_bias`` is set."""
        linear_out = paddle.matmul(x, layer.weight)
        if layer.with_bias:
            linear_out = paddle.add(linear_out, layer.bias)
        return linear_out
class LinearBase(nn.Layer):
    """
    LinearBase Layer.

    Holds the checkpoint key names, weight shape/dtype and the quantization method
    used by the linear layers that derive from it.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a linear layer and provides additional parameters required for inference and quantization.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Dtype of the weight (and bias) parameter. When empty, the
                layer's default dtype is used. Defaults to "".
            weight_key (str): Suffix appended to ``prefix`` to form the checkpoint key of the
                weight. When empty, the key is derived from the quantization state. Defaults to "".

        Raises:
            NotImplementedError: Raised if the current platform is none of CUDA, XPU, Iluvatar,
                GCU, DCU, MACA or Intel HPU.
        """
        super().__init__()
        if (
            current_platform.is_cuda()
            or current_platform.is_xpu()
            or current_platform.is_iluvatar()
            or current_platform.is_gcu()
            or current_platform.is_dcu()
            or current_platform.is_maca()
            or current_platform.is_intel_hpu()
        ):
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError(
                "LinearBase only supports CUDA, XPU, Iluvatar, GCU, DCU, MACA and Intel HPU platforms."
            )

        self.fd_config = fd_config
        self.skip_quant = skip_quant
        self.input_size = input_size
        self.output_size = output_size
        self.with_bias = with_bias
        self.prefix = prefix
        # "mix_quant" without a dense quant type leaves dense layers unquantized.
        self.is_quantized = fd_config.model_config.is_quantized and not (
            fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.dense_quant_type is None
        )

        # key
        if weight_key:
            self.weight_key = f"{prefix}.{weight_key}"
        elif self.is_quantized and not skip_quant:
            self.weight_key = f"{prefix}.quant_weight"
            self.weight_scale_key = f"{prefix}.weight_scale"
            self.act_scale_key = f"{prefix}.activation_scale"
        else:
            self.weight_key = f"{prefix}.weight"
        self.bias_key = f"{prefix}.bias"
        self.shift_key = f"{prefix}.shift_bias"
        self.smooth_key = f"{prefix}.smooth_weight"
        self.out_scale_key = f"{prefix}.out_scale"

        self._dtype = self._helper.get_default_dtype()
        # An explicit weight_dtype wins; otherwise the default dtype is used whether or not
        # quantization is skipped.
        self.weight_dtype = weight_dtype if weight_dtype else self._dtype

        self.weight_shape = [
            self.input_size,
            self.output_size,
        ]

        if (
            fd_config.quant_config
            and not skip_quant
            and modules_to_convert(prefix, self.fd_config)
            and fd_config.quant_config.get_quant_method(self)
        ):
            self.quant_method = fd_config.quant_config.get_quant_method(self)
        else:
            self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
                shape=[self.output_size],
                dtype=self.weight_dtype,
                is_bias=True,
            )
            # Subclasses that define their own weight_loader get it attached to the bias too.
            setattr(
                self.bias,
                "weight_loader",
                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config),
            )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None

    def load_prequant_weight(self, state_dict: dict):
        """
        Load the prequantized weight from the state dictionary.

        Args:
            state_dict (dict): A dictionary containing the prequantized weights and scales.
        """
        if isinstance(self.quant_method, UnquantizedLinearMethod):
            # for gate
            self.load_weight(state_dict)
        else:
            self.quant_method.process_prequanted_weights(self, state_dict)

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.

        When the weight key contains "qkv_a_proj_with_mqa", the "q_a_proj" and
        "kv_a_proj_with_mqa" entries are popped and concatenated along the last axis.
        Loaded entries are removed from ``state_dict``.

        Args:
            state_dict (dict): A dictionary containing the weights
        """
        if "qkv_a_proj_with_mqa" in self.weight_key:
            self.weight_key_q = self.weight_key.replace("qkv_a_proj_with_mqa", "q_a_proj")
            self.weight_key_kv = self.weight_key.replace("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa")
            q_weight_tensor = get_tensor(state_dict.pop(self.weight_key_q))
            kv_weight_tensor = get_tensor(state_dict.pop(self.weight_key_kv))
            weight_tensor = paddle.concat([q_weight_tensor, kv_weight_tensor], axis=-1)
        else:
            weight_tensor = get_tensor(state_dict.pop(self.weight_key))
        self.quant_method.process_loaded_weights(self, weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        # NOTE(review): this instance attribute uses the same name as a method; confirm that
        # nothing relies on calling `layer.state_dict()` after loading.
        self.state_dict = state_dict
        assert self.weight_key is not None, "weight_key should not be None."
        if self.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            self.load_weight(state_dict)

        # bias
        if self.with_bias:
            bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
            self.bias.set_value(bias_tensor)

    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward function for Linear.

        The input is cast to float32 first when the weight dtype is "float32".

        Args:
            x (Tensor): Input tensor to the Linear.

        Returns:
            Tensor: Output tensor.
        """
        if self.weight_dtype == "float32":
            linear_out = self.quant_method.apply(self, x.cast("float32"))
        else:
            linear_out = self.quant_method.apply(self, x)

        return linear_out
class ReplicatedLinear(LinearBase):
    """
    Linear layer whose weight is created with the full, unsplit input and output sizes.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
        model_format: Optional[str] = None,
    ):
        """
        Build the base layer, then ask the quant method to create the weights.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Forwarded to LinearBase. Defaults to "".
            weight_key (str): Forwarded to LinearBase. Defaults to "".
            model_format (Optional[str]): Overrides ``fd_config.model_config.model_format``
                when not None. Defaults to None.
        """
        base_kwargs = dict(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        super().__init__(**base_kwargs)

        self.hidden_size = fd_config.model_config.hidden_size
        assert self.quant_method is not None

        # A subclass-defined weight_loader takes priority over the default loader.
        if hasattr(self, "weight_loader"):
            loader = self.weight_loader
        else:
            loader = default_weight_loader(self.fd_config)

        # The config's model format is consulted only when no explicit override was given.
        if model_format is None:
            resolved_format = fd_config.model_config.model_format
        else:
            resolved_format = model_format

        self.quant_method.create_weights(
            self,
            weight_loader=loader,
            model_format=resolved_format,
        )
class MergedReplicatedLinear(ReplicatedLinear):
    """
    Replicated linear layer whose output dimension is the sum of several partition sizes.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_sizes: list[int] = None,
        with_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Build a replicated layer with ``output_size = sum(output_sizes)``.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_sizes (list[int]): Number of output features list. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Forwarded to the parent layer. Defaults to "".
            weight_key (str): Forwarded to the parent layer. Defaults to "".
        """
        merged_output_size = sum(output_sizes)
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=merged_output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """
        Copy ``loaded_weight`` into ``param``, or into one of its two partitions.

        Args:
            param: Destination parameter.
            loaded_weight: Weight read from the checkpoint.
            loaded_shard_id (Optional[str]): None to fill the whole parameter; "q_a"/"gate" for
                the first partition, "kv_a"/"up" for the second.
        """
        if not param._is_initialized():
            param.initialize()

        # for xpu and other backend
        transpose_needed = getattr(param, "weight_need_transpose", False)

        if loaded_shard_id is None:
            # Whole-parameter load: only the tracked range needs to be worked out.
            if transpose_needed:
                loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
                track_axis = -1
            elif self.fd_config.model_config.model_format != "torch":
                track_axis = -1
            else:
                track_axis = 0
            if hasattr(param, "tensor_track"):
                param.tensor_track.mark(start=0, end=loaded_weight.shape[track_axis])
        else:
            assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"]
            if transpose_needed:
                loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
                slice_dim = True
            else:
                slice_dim = self.fd_config.model_config.model_format != "torch"

            if loaded_shard_id in ["q_a", "gate"]:
                part_start = 0
                part_len = self.output_sizes[0]
            elif loaded_shard_id in ["kv_a", "up"]:
                part_start = self.output_sizes[0]
                part_len = self.output_sizes[1]
            part_end = part_start + part_len

            if hasattr(param, "tensor_track"):
                param.tensor_track.mark(start=part_start, end=part_end)
            param = slice_fn(
                param,
                slice_dim,
                start=part_start,
                end=part_end,
            )

        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )

        # Ensure loaded weight dtype matches model param dtype
        if loaded_weight.dtype != param.dtype:
            int8_into_fp8 = loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn
            if int8_into_fp8:
                # Reinterpret instead of converting values.
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)

        # (bukejiyu) After this fix, the early H2D copy for non-GPU devices is no longer needed and can be safely removed.
        h2d_copy(param, loaded_weight)
class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
    ):
        """
        Build a linear layer whose output size is divided by the tensor-parallel size.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features before the tensor-parallel split.
                Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Forwarded to LinearBase. Defaults to "".
        """
        self.fd_config = fd_config
        self.tp_size = fd_config.parallel_config.tensor_parallel_size
        self.input_size = input_size
        # Split the output_size using TP inference.
        self.output_size = divide(output_size, self.tp_size)
        self.hidden_size = fd_config.model_config.hidden_size

        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=self.input_size,
            output_size=self.output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
        )

        assert self.quant_method is not None
        if hasattr(self, "weight_loader"):
            weight_loading_fn = self.weight_loader
        else:
            weight_loading_fn = default_weight_loader(self.fd_config)
        self.quant_method.create_weights(
            self,
            output_dim=True,
            weight_loader=weight_loading_fn,
            model_format=fd_config.model_config.model_format,
        )

        if self.tp_size > 0 and self.with_bias:
            # col parallel
            _set_var_distributed(self.bias, split_axis=1)
            if hasattr(self, "weight_loader"):
                bias_loading_fn = self.weight_loader
            else:
                bias_loading_fn = default_weight_loader(self.fd_config)
            bias_attrs = {
                "output_dim": True,
                "weight_loader": bias_loading_fn,
            }
            set_weight_attrs(self.bias, bias_attrs)
class MergedColumnParallelLinear(ColumnParallelLinear):
    """
    MergedColumnParallelLinear Layer.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str,
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        activation: str = "gelu",
        skip_quant: bool = False,
    ):
        """
        Initialize the fused up_gate_proj Linear layer with given parameters.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            activation (str): Activation function name stored on the layer. Defaults to "gelu".
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        parallel_cfg = fd_config.parallel_config
        self.activation = activation
        self.hidden_size = fd_config.model_config.hidden_size
        self.tp_size = parallel_cfg.tensor_parallel_size
        self.output_size = output_size
        self.local_rank = parallel_cfg.tensor_parallel_rank

        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """
        Load a fused gate/up weight, or a single "gate" or "up" shard, into ``param``.

        Args:
            param: Destination parameter; must carry an ``output_dim`` attribute.
            loaded_weight: Weight read from the checkpoint.
            loaded_shard_id (Optional[str]): None when the checkpoint weight is already fused
                (it is split in halves and each half loaded recursively), otherwise "gate" or "up".
        """
        # for xpu and other backend
        transpose_needed = getattr(param, "weight_need_transpose", False)
        output_dim = getattr(param, "output_dim", None)
        assert output_dim is not None
        shard_dim = -1 if output_dim else 0
        output_size = param.shape[shard_dim]

        if loaded_shard_id is None:
            if transpose_needed:
                loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
                # Avoid redundant transpose of fused weights when weight_loader is called iteratively
                param.weight_need_transpose = False

            # Loaded weight is already fused on disk.
            half_size = output_size * self.tp_size // 2
            # (shard_id, shard_offset, shard_size)
            fused_layout = [
                ("gate", 0, half_size),
                ("up", half_size, half_size),
            ]
            for part_id, part_offset, part_size in fused_layout:
                part_weight = slice_fn(loaded_weight, output_dim, start=part_offset, end=part_offset + part_size)
                self.weight_loader(param, part_weight, part_id)
            return

        # split gate up
        assert loaded_shard_id in ["gate", "up"]
        if transpose_needed:
            loaded_weight = get_tensor(loaded_weight).transpose([1, 0])

        # Tensor parallelism splits the weight along the output_dim
        if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                full_size = loaded_weight.shape[shard_dim]
            else:
                full_size = loaded_weight.get_shape()[shard_dim]
            rank_block = full_size // self.tp_size
            rank_start = self.local_rank * rank_block
            rank_end = (self.local_rank + 1) * rank_block
            loaded_weight = slice_fn(loaded_weight, output_dim, start=rank_start, end=rank_end)

        if not param._is_initialized():
            param.initialize()

        # "gate" fills the first half of the parameter, "up" the second half.
        half_len = output_size // 2
        dest_start = 0 if loaded_shard_id == "gate" else half_len
        dest_end = dest_start + half_len

        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=dest_start, end=dest_end)
        param = slice_fn(param, output_dim, start=dest_start, end=dest_end)

        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )

        # Ensure loaded weight dtype matches model param dtype
        if loaded_weight.dtype != param.dtype:
            int8_into_fp8 = loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn
            if int8_into_fp8:
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)

        h2d_copy(param, loaded_weight)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        When the fused weight key is absent, the "gate_proj" and "up_proj" weights are popped
        and concatenated along the last axis before delegating to the parent loader.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        assert self.weight_key is not None, "weight_key should not be None."

        fused_key = self.weight_key
        if fused_key in state_dict.keys():
            merged_weight = get_tensor(state_dict.pop(fused_key))
        else:
            gate_key = fused_key.replace("up_gate_proj", "gate_proj")
            up_key = fused_key.replace("up_gate_proj", "up_proj")
            gate_weight = get_tensor(state_dict.pop(gate_key))
            up_weight = get_tensor(state_dict.pop(up_key))
            merged_weight = paddle.concat([gate_weight, up_weight], axis=-1)

            if self.with_bias:
                # NOTE(review): only the gate_proj bias is read here; confirm whether an
                # up_proj bias should be merged as well.
                gate_bias_key = self.bias_key.replace("up_gate_proj", "gate_proj")
                gate_bias = get_tensor(state_dict.pop(gate_bias_key))
                state_dict[self.bias_key] = gate_bias.astype(paddle.get_default_dtype())

        # Put the (possibly merged) weight back under the fused key for the parent loader.
        state_dict[fused_key] = merged_weight
        super().load_state_dict(state_dict)
class QKVParallelLinear(ColumnParallelLinear):
    """
    Column-parallel linear layer holding the fused q, k and v projections.
    """

    def __init__(
        self,
        fd_config,
        prefix,
        with_bias=False,
        num_heads: Optional[int] = None,
        kv_num_heads: Optional[int] = None,
        hidden_size: Optional[int] = None,
        head_dim: Optional[int] = None,
        skip_quant: bool = False,
        weight_dtype: str = "",
    ):
        """
        Initialize the QKV Linear layer with given parameters.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            num_heads (Optional[int]): Number of attention heads; taken from the model config when None.
            kv_num_heads (Optional[int]): Number of key/value heads; taken from the model config when None.
            hidden_size (Optional[int]): Hidden dimension, used as the input size; taken from the
                model config when None.
            head_dim (Optional[int]): Size of each attention head; taken from the model config when None.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Forwarded to the parent layer. Defaults to "".
        """
        model_cfg = fd_config.model_config
        parallel_cfg = fd_config.parallel_config

        self.num_heads = model_cfg.num_attention_heads if num_heads is None else num_heads
        self.kv_num_heads = model_cfg.num_key_value_heads if kv_num_heads is None else kv_num_heads
        self.hidden_size = model_cfg.hidden_size if hidden_size is None else hidden_size
        self.head_dim = model_cfg.head_dim if head_dim is None else head_dim
        self.tp_size = parallel_cfg.tensor_parallel_size
        self.local_rank = parallel_cfg.tensor_parallel_rank
        self.num_heads_per_rank = divide(self.num_heads, self.tp_size)

        replicate_kv = self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0
        if replicate_kv:
            # Fewer kv heads than ranks: every rank keeps one kv head.
            self.kv_num_heads_per_rank = 1
            self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
            total_heads = self.num_heads + 2 * self.tp_size
        else:
            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
            self.num_kv_head_replicas = 1
            total_heads = self.num_heads + 2 * self.kv_num_heads
        output_size = total_heads * self.head_dim
        input_size = self.hidden_size

        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
        )

    def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
        """Return the per-rank width of the "q", "k" or "v" shard, or None for any other id."""
        q_width = self.num_heads_per_rank * head_dim
        kv_width = self.kv_num_heads_per_rank * head_dim
        widths = {"q": q_width, "k": kv_width, "v": kv_width}
        return widths.get(loaded_shard_id)

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """
        Load a fused qkv weight, or a single "q", "k" or "v" shard, into ``param``.

        Args:
            param: Destination parameter; must carry an ``output_dim`` attribute.
            loaded_weight: Weight read from the checkpoint.
            loaded_shard_id (Optional[str]): None when the checkpoint weight is already fused
                (it is split into q/k/v and each part loaded recursively), otherwise "q", "k" or "v".
        """
        output_dim = getattr(param, "output_dim", None)
        assert output_dim is not None
        dim = -1 if output_dim else 0
        heads_on_rank = self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank
        head_dim = param.shape[dim] // heads_on_rank
        transpose_needed = getattr(param, "weight_need_transpose", False)

        if loaded_shard_id is None:
            if transpose_needed:
                loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
                # Avoid redundant transpose of fused weights when weight_loader is called iteratively
                param.weight_need_transpose = False

            # Loaded weight is already fused on disk
            q_width = self.num_heads * head_dim
            kv_width = self.kv_num_heads * head_dim
            # (shard_id, shard_offset, shard_size)
            fused_layout = [
                ("q", 0, q_width),
                ("k", self.num_heads * head_dim, kv_width),
                ("v", (self.num_heads + self.kv_num_heads) * head_dim, kv_width),
            ]
            for part_id, part_offset, part_size in fused_layout:
                part_weight = slice_fn(loaded_weight, output_dim, start=part_offset, end=part_offset + part_size)
                self.weight_loader(param, part_weight, part_id)
            return

        # split q k v
        assert loaded_shard_id in ["q", "k", "v"]
        if transpose_needed:
            loaded_weight = get_tensor(loaded_weight).transpose([1, 0])

        # Tensor parallelism splits the weight along the output_dim
        if self.tp_size > 1 and not self.fd_config.load_config.is_pre_sharded:
            block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
            if loaded_shard_id == "q":
                block_index = self.local_rank
            else:
                # Ranks that share a replicated kv head read the same block.
                block_index = self.local_rank // self.num_kv_head_replicas
            block_start = block_index * block_size
            loaded_weight = slice_fn(loaded_weight, output_dim, start=block_start, end=block_start + block_size)

        if not param._is_initialized():
            param.initialize()

        # Destination layout inside the parameter is [q | k | v].
        if loaded_shard_id == "q":
            dest_start = 0
            dest_len = self.num_heads_per_rank * head_dim
        elif loaded_shard_id == "k":
            dest_start = self.num_heads_per_rank * head_dim
            dest_len = self.kv_num_heads_per_rank * head_dim
        else:
            # loaded_shard_id == "v"
            dest_start = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
            dest_len = self.kv_num_heads_per_rank * head_dim
        dest_end = dest_start + dest_len

        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=dest_start, end=dest_end)
        param = slice_fn(param, output_dim, start=dest_start, end=dest_end)

        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )

        # Ensure loaded weight dtype matches model param dtype
        if loaded_weight.dtype != param.dtype:
            int8_into_fp8 = loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn
            if int8_into_fp8:
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)

        h2d_copy(param, loaded_weight)

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.

        Uses the fused key when present; otherwise pops the separate "q_proj", "k_proj" and
        "v_proj" weights and assembles them into one tensor.

        Args:
            state_dict (dict): A dictionary containing the weights
        """
        fused_key = self.weight_key
        if fused_key in state_dict.keys():
            weight_tensor = get_tensor(state_dict.pop(fused_key))
        else:
            q_tensor = get_tensor(state_dict.pop(fused_key.replace("qkv_proj", "q_proj")))
            k_tensor = get_tensor(state_dict.pop(fused_key.replace("qkv_proj", "k_proj")))
            v_tensor = get_tensor(state_dict.pop(fused_key.replace("qkv_proj", "v_proj")))

            if self.kv_num_heads < self.tp_size:
                # Keep only the one kv head's columns selected for this rank.
                rank = self.fd_config.parallel_config.tensor_parallel_rank
                kv_head_index = (rank * self.kv_num_heads) // self.tp_size
                col_start = kv_head_index * self.head_dim
                col_end = col_start + self.head_dim
                k_tensor = k_tensor[:, col_start:col_end]
                v_tensor = v_tensor[:, col_start:col_end]

            fused = paddle.concat([q_tensor, k_tensor, v_tensor], axis=-1)
            fused = fused.transpose([1, 0])
            rows = (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * (self.head_dim)
            fused = fused.reshape([rows, self.hidden_size])
            weight_tensor = paddle.transpose(fused, perm=[1, 0])

        self.quant_method.process_loaded_weights(self, weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        assert self.weight_key is not None, "weight_key should not be None."
        # qkv fused in disk
        if self.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            self.load_weight(state_dict)

        # bias
        if not self.with_bias:
            return

        if self.bias_key in state_dict.keys():
            fused_bias = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
            self.bias.set_value(fused_bias)
        else:
            # No fused bias in the checkpoint: concatenate the separate q/k/v biases.
            q_bias = get_tensor(state_dict.pop(self.bias_key.replace("qkv_proj", "q_proj")))
            k_bias = get_tensor(state_dict.pop(self.bias_key.replace("qkv_proj", "k_proj")))
            v_bias = get_tensor(state_dict.pop(self.bias_key.replace("qkv_proj", "v_proj")))
            qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
            self.bias.set_value(qkv_bias)
class RowParallelLinear(LinearBase):
    """
    Row-parallel linear layer.

    Computes Y = XA + b, where A is partitioned along its first dimension
    and X along its second dimension:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        reduce_results: bool = True,
        skip_quant: bool = False,
        weight_dtype: str = "",
        layer_id: int = -1,
    ):
        """
        Build the layer, create its weights and record the TP settings.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Divided by the tensor
                parallel size unless tokens are split for this layer. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            reduce_results (bool): Whether ``forward_cuda`` all-reduces the output
                over the TP group. Forced off when tokens are split. Defaults to True.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Forwarded unchanged to the base class. Defaults to "".
            layer_id (int): Index of the layer; together with the sequence-parallel
                MoE settings it decides whether tokens are split. Defaults to -1.
        """
        parallel_cfg = fd_config.parallel_config
        model_cfg = fd_config.model_config

        self.fd_config = fd_config
        self.ep_size = parallel_cfg.expert_parallel_size
        self.tp_size = parallel_cfg.tensor_parallel_size
        self.tp_group = parallel_cfg.tp_group
        self.hidden_size = model_cfg.hidden_size
        self.head_dim = model_cfg.head_dim
        self.layer_id = layer_id

        # Tokens are split only for MoE layers when sequence-parallel MoE is on.
        self.split_token = (
            parallel_cfg.use_sequence_parallel_moe
            and layer_id >= model_cfg.moe_layer_start_index
            and layer_id < model_cfg.num_hidden_layers
        )

        # Split input_size when using TP inference (kept whole when tokens are split).
        self.input_size = input_size if self.split_token else divide(input_size, self.tp_size)
        self.output_size = output_size

        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=self.input_size,
            output_size=self.output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
        )
        assert self.quant_method is not None

        # Prefer a loader already attached to the layer; otherwise use the default one.
        if hasattr(self, "weight_loader"):
            loader = self.weight_loader
        else:
            loader = default_weight_loader(self.fd_config)

        create_weight_kwargs = {
            "layer": self,
            "output_dim": None if self.split_token else False,
            "weight_loader": loader,
            "model_format": model_cfg.model_format,
        }
        if self.tp_size > 1:
            create_weight_kwargs.update(split_axis=0, is_distributed=True)
        self.quant_method.create_weights(**create_weight_kwargs)

        self.reduce_results = reduce_results and not self.split_token

        # Tag the bias as a row-parallel bias when the output is all-reduced under TP.
        if self.with_bias and self.tp_size > 1 and self.reduce_results:
            set_weight_attrs(self.bias, {"tp_row_bias": True})

    def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Exchange token-split activations across the TP group.

        The token dimension is zero-padded up to a multiple of ``tp_size``
        before the exchange. The result has ``padded_tokens // tp_size`` rows.

        Args:
            x (paddle.Tensor): 2-D input indexed as ``[token, feature]``.

        Returns:
            paddle.Tensor: The exchanged tensor.
        """
        world = self.tp_size
        rows = x.shape[0]
        padded_rows = (rows + world - 1) // world * world
        needs_padding = padded_rows > rows

        on_decode_path = (
            self.fd_config.scheduler_config.splitwise_role == "decode" and not current_platform.is_xpu()
        )
        if on_decode_path:
            if needs_padding:
                x_padded = paddle.zeros([padded_rows, x.shape[1]], x.dtype)
                x_padded[:rows] = x
            else:
                x_padded = x
            # The dedicated kernel writes straight into a buffer of the final shape.
            out = paddle.zeros([padded_rows // world, x.shape[1] * world], x.dtype)
            decode_alltoall_transpose(x_padded, out)
            return out

        if needs_padding:
            padded = paddle.zeros([padded_rows, x.shape[1]], x.dtype)
            padded[:rows, :] = x
            x = padded
        out = paddle.zeros_like(x)
        paddle.distributed.alltoall(out, x, group=self.tp_group)
        # Regroup the received chunks as [rank, token, feature], swap rank and
        # token, then flatten rank and feature into a single axis.
        out.reshape_([world, -1, x.shape[1]])
        out = paddle.transpose(out, [1, 0, 2])
        out.reshape_([x.shape[0] // world, self.input_size])
        return out

    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Apply the linear layer, exchanging tokens first and all-reducing after as configured.

        Args:
            x (paddle.Tensor): Input activations.

        Returns:
            paddle.Tensor: Output of the quantization method's ``apply``,
            all-reduced over the TP group when ``reduce_results`` is set and
            ``tp_size > 1``.
        """
        if self.split_token:
            x = self.all2all_transpose(x)

        result = self.quant_method.apply(self, x)

        if self.reduce_results and self.tp_size > 1:
            result = tensor_model_parallel_all_reduce(result, self.tp_group)
        return result
class KVBatchLinear(nn.Layer):
"""
KVBatchLinear Layer for handling combined KV projections with bmm.
"""
def __init__(
    self,
    fd_config: FDConfig,
    kv_b_proj: nn.Layer,
    prefix: str = "",
    kv_lora_rank: int = None,
    num_attention_heads: int = None,
    qk_nope_head_dim: int = None,
    v_head_dim: int = None,
):
    """
    Set up a KV batch linear layer that internally splits into K and V projections.

    Args:
        fd_config (FDConfig): Inference-related parameters.
        kv_b_proj (nn.Layer): Combined KV projection layer. It is retained on
            the instance only when ``load_config.load_choices`` is
            ``"default_v1"``; otherwise ``self.kv_b_proj`` is set to None.
        prefix (str): Unique name of the layer, used to name internal attributes.
        kv_lora_rank (int): LoRA rank for KV projection. Defaults to None.
        num_attention_heads (int): Number of attention heads. Defaults to None.
        qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
        v_head_dim (int): Dimension for V projection. Defaults to None.
    """
    super().__init__()
    parallel_cfg = fd_config.parallel_config

    self.tp_size = parallel_cfg.tensor_parallel_size
    self.kv_lora_rank = kv_lora_rank
    self.num_attention_heads = num_attention_heads
    self.qk_nope_head_dim = qk_nope_head_dim
    self.v_head_dim = v_head_dim

    # Split num_attention_heads when using TP inference.
    self.num_heads_per_partition = divide(num_attention_heads, self.tp_size)
    self.local_rank = parallel_cfg.tensor_parallel_rank
    self.fd_config = fd_config

    # Hold on to the combined projection only for the "default_v1" loader.
    uses_v1_loader = self.fd_config.load_config.load_choices == "default_v1"
    self.kv_b_proj = kv_b_proj if uses_v1_loader else None

    self.weight_dtype = self._helper.get_default_dtype()

    # Override weight keys to use the combined kv_b_proj
    self.weight_key = f"{prefix}.weight"  # e.g., "kv_b_proj.weight"
def process_weights_after_loading(self):
if self.fd_config.load_config.dynamic_load_weight:
return
w = (
self.kv_b_proj.weight.transpose([1, 0])
if self.fd_config.model_config.model_format == "torch"
else self.kv_b_proj.weight