forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path gated_mlp.py
More file actions
203 lines (178 loc) · 7.88 KB
/
gated_mlp.py
File metadata and controls
203 lines (178 loc) · 7.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .swiglu import swiglu
class GatedMLP(nn.Module):
    """Fused gated MLP block: gate/up projection -> gated activation -> down projection.

    The gate and up projections are fused into a single column-parallel
    ``Linear`` (``gate_up_proj``); the output projection (``down_proj``) is
    row-parallel and optionally performs the tensor-parallel all-reduce.
    LoRA adapters are supported on both the (fused or split) gate/up path and
    the down path.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        intermediate_size: int,
        bias: bool,
        activation: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        dtype: Optional[torch.dtype] = None,
        config: Optional[ModelConfig] = None,
        overridden_tp_size: Optional[int] = None,
        reduce_output: bool = True,
        layer_idx: Optional[int] = None,
        use_cute_dsl_blockscaling_mm: bool = False,
        disable_deep_gemm: bool = False,
        use_custom_cublas_mm: bool = False,
        is_shared_expert: bool = False,
    ):
        """Build the fused gate/up and down projections plus LoRA layers.

        Args:
            hidden_size: Model hidden dimension (input/output width).
            intermediate_size: Full (unsharded) MLP intermediate dimension.
            bias: Whether both linear layers carry a bias term.
            activation: Gating activation; ``F.silu`` enables the fused
                swiglu path in ``_apply_activation``.
            dtype: Parameter dtype forwarded to both ``Linear`` layers.
            config: Model configuration; a default ``ModelConfig`` is used
                when omitted.
            overridden_tp_size: If given, shard over this TP size instead of
                ``config.mapping.tp_size`` (must divide it); the remaining
                ranks are folded into pp_size so all-reduce happens within
                the smaller TP group.
            reduce_output: Whether ``down_proj`` performs its TP all-reduce.
            layer_idx: Layer index; required for the LoRA forward path.
            use_cute_dsl_blockscaling_mm: Forwarded to both ``Linear`` layers.
            disable_deep_gemm: Forwarded to both ``Linear`` layers.
            use_custom_cublas_mm: Forwarded to both ``Linear`` layers.
            is_shared_expert: Selects shared-expert LoRA module types instead
                of the plain MLP ones.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.activation = activation
        config = config or ModelConfig()
        self.mapping = config.mapping
        if overridden_tp_size is not None:
            assert config.mapping.tp_size % overridden_tp_size == 0
            tp_size = overridden_tp_size
            # "Misuse" pp_size here to perform all-reduce within smaller groups
            pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
            mapping = Mapping(
                world_size=tp_size * pp_size,
                rank=self.mapping.rank,
                gpus_per_node=self.mapping.gpus_per_node,
                tp_size=tp_size,
                pp_size=pp_size,
            )
        else:
            mapping = config.mapping

        # Calculate local intermediate size after tensor parallel sharding
        tp_size = mapping.tp_size
        local_intermediate_size = self.intermediate_size // tp_size
        # Per-sub-weight shard layout of the fused gate+up weight; entries look
        # like (offset, size) within the fused dimension — TODO confirm against
        # the Linear consumer of fused_weight_shard_indices_mapping.
        gateup_shard_indices_mapping = {
            'gate': (0, local_intermediate_size),
            'up': (local_intermediate_size, local_intermediate_size),
        }

        # Fused gate+up projection: column-parallel, output width 2x the
        # intermediate size (gate half + up half). No reduce on this side.
        self.gate_up_proj = Linear(
            self.hidden_size,
            self.intermediate_size * 2,
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.FUSED_GATE_UP_LINEAR),
            quant_config=config.get_quant_config(),
            reduce_output=False,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
            disable_deep_gemm=disable_deep_gemm,
            fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
            use_custom_cublas_mm=use_custom_cublas_mm,
        )

        # Shared-expert MoE layers use distinct LoRA module types from the
        # regular dense-MLP ones.
        if is_shared_expert:
            down_type = LoraModuleType.SHARED_EXPERT_4H_TO_H
            h_to_4h_type = LoraModuleType.SHARED_EXPERT_H_TO_4H
            gate_type = LoraModuleType.SHARED_EXPERT_GATE
        else:
            down_type = LoraModuleType.MLP_4H_TO_H
            h_to_4h_type = LoraModuleType.MLP_H_TO_4H
            gate_type = LoraModuleType.MLP_GATE

        self.down_lora = LoraLayer([down_type], [self.hidden_size])
        # Down projection: row-parallel; performs the final all-reduce when
        # reduce_output is True. LoRA is applied inside the Linear itself.
        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config(),
            reduce_output=reduce_output,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.down_lora,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
            disable_deep_gemm=disable_deep_gemm,
            use_custom_cublas_mm=use_custom_cublas_mm,
        )

        # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
        # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
        # handles them as a single fused operation.
        self.splitted_gate_up_lora = LoraLayer([h_to_4h_type, gate_type], [
            self.intermediate_size // mapping.tp_size,
            self.intermediate_size // mapping.tp_size
        ])
        self.fused_gate_up_lora = LoraLayer(
            [LoraModuleType.MLP_GATE_UP],
            [2 * self.intermediate_size // mapping.tp_size])

    def _apply_activation(self, x, *, has_lora: bool = False):
        """Apply the gated activation to the fused gate+up output ``x``.

        For the SiLU activation this uses the fused ``swiglu`` kernel and,
        when ``down_proj`` expects FP8 input, quantizes inside the kernel.
        The LoRA path skips FP8 output (see WAR note below).
        """
        if self.activation == F.silu:
            if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8:
                if has_lora:
                    # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
                    # TODO: Remove this path when LoRA grouped_gemm supports FP8
                    # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
                    logger.warning(
                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
                    )
                    return swiglu(x)
                else:
                    # Fuse FP8 quantization into the swiglu kernel using the
                    # down projection's static input scale.
                    return swiglu(x,
                                  quant_scale=self.down_proj.input_scale,
                                  quant_type=torch.float8_e4m3fn)
            else:
                return swiglu(x)
        elif callable(self.activation):
            # Non-SiLU activation: applied directly to the fused output.
            # NOTE(review): no explicit gating split happens here — presumably
            # the callable itself handles the gate/up halves; confirm.
            return self.activation(x)
        elif self.activation is None:
            # Identity: pass the projection output through unchanged.
            return x
        else:
            raise NotImplementedError(
                f"Activation {self.activation} not yet implemented for fused GatedMLP"
            )

    def forward(
        self,
        x: Union[torch.Tensor, Fp4QuantizedTensor],
        all_rank_num_tokens=None,
        final_all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the MLP; dispatches to ``forward_lora`` when LoRA params are set.

        Args:
            x: Input activations (possibly already FP4-quantized).
            all_rank_num_tokens: Unused on the non-LoRA path; forwarded to
                ``forward_lora``.
            final_all_reduce_params: Controls the all-reduce done by
                ``down_proj``.
            lora_params: Non-empty dict triggers the LoRA path.
        """
        if bool(lora_params):
            return self.forward_lora(x, all_rank_num_tokens,
                                     final_all_reduce_params, lora_params)

        h1 = self.gate_up_proj(x)
        h2 = self._apply_activation(h1)
        output = self.down_proj(h2,
                                all_reduce_params=final_all_reduce_params,
                                layer_idx=self.layer_idx)
        return output

    def forward_lora(
        self,
        x: Union[torch.Tensor, Fp4QuantizedTensor],
        all_rank_num_tokens=None,
        final_all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
    ) -> torch.Tensor:
        """MLP forward with LoRA deltas added to the gate/up projection.

        Both the split and fused gate/up LoRA layers are queried; each returns
        ``None`` when no matching adapter is active, so at most one contributes
        (they are mutually exclusive by construction). Down-projection LoRA is
        applied inside ``down_proj`` via ``lora_params``.
        """
        assert lora_params is not None
        assert self.layer_idx is not None, "layer_idx is required for lora"

        h1 = self.gate_up_proj(x)

        h1_lora = self.splitted_gate_up_lora(x, lora_params, self.layer_idx)
        if h1_lora is not None:
            h1 = h1 + h1_lora

        h1_lora = self.fused_gate_up_lora(x, lora_params, self.layer_idx)
        if h1_lora is not None:
            h1 = h1 + h1_lora

        h2 = self._apply_activation(h1, has_lora=True)
        output = self.down_proj(h2,
                                all_reduce_params=final_all_reduce_params,
                                lora_params=lora_params,
                                layer_idx=self.layer_idx)
        return output