cpp/include/tensorrt_llm/runtime/loraModule.h (14 additions, 1 deletion)
@@ -50,6 +50,9 @@ class LoraModule
kMOE_ROUTER = 16,
kMLP_ROUTER = 17,
kMLP_GATE_UP = 18,
+    kSHARED_EXPERT_H_TO_4H = 19,
+    kSHARED_EXPERT_4H_TO_H = 20,
+    kSHARED_EXPERT_GATE = 21,
};

explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
@@ -192,7 +195,8 @@ class LoraModule

static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
-        SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts);
+        SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize = 0,
+        SizeType32 moeHiddenSize = 0);

static ModuleType constexpr toModuleType(std::string_view const& name)
{
@@ -234,6 +238,12 @@ class LoraModule
return ModuleType::kMLP_ROUTER;
else if (name == "mlp_gate_up")
return ModuleType::kMLP_GATE_UP;
+        else if (name == "shared_expert_h_to_4h")
+            return ModuleType::kSHARED_EXPERT_H_TO_4H;
+        else if (name == "shared_expert_4h_to_h")
+            return ModuleType::kSHARED_EXPERT_4H_TO_H;
+        else if (name == "shared_expert_gate")
+            return ModuleType::kSHARED_EXPERT_GATE;
else
return ModuleType::kINVALID;
}
@@ -261,6 +271,9 @@ class LoraModule
case ModuleType::kMOE_ROUTER: return "moe_router";
case ModuleType::kMLP_ROUTER: return "mlp_router";
case ModuleType::kMLP_GATE_UP: return "mlp_gate_up";
+        case ModuleType::kSHARED_EXPERT_H_TO_4H: return "shared_expert_h_to_4h";
+        case ModuleType::kSHARED_EXPERT_4H_TO_H: return "shared_expert_4h_to_h";
+        case ModuleType::kSHARED_EXPERT_GATE: return "shared_expert_gate";
case ModuleType::kINVALID: return "INVALID";
}
return "INVALID";
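The three enum values added above must stay in sync with the Python-side LoraModuleType IntEnum (tensorrt_llm/_torch/peft/lora/layer.py, changed later in this diff), and the lower-case strings are the ones toModuleType()/toModuleName() round-trip. A minimal sync check, assuming an installed tensorrt_llm build:

    from tensorrt_llm._torch.peft.lora.layer import LoraModuleType

    # The upper-case enum names mirror the config strings one-to-one.
    for name, value in [("shared_expert_h_to_4h", 19),
                        ("shared_expert_4h_to_h", 20),
                        ("shared_expert_gate", 21)]:
        assert int(LoraModuleType[name.upper()]) == value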
cpp/tensorrt_llm/nanobind/bindings.cpp (5 additions, 2 deletions)
@@ -217,7 +217,10 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.value("MOE_GATE", tr::LoraModule::ModuleType::kMOE_GATE)
.value("MOE_ROUTER", tr::LoraModule::ModuleType::kMOE_ROUTER)
.value("MLP_ROUTER", tr::LoraModule::ModuleType::kMLP_ROUTER)
-    .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP);
+    .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP)
+    .value("SHARED_EXPERT_H_TO_4H", tr::LoraModule::ModuleType::kSHARED_EXPERT_H_TO_4H)
+    .value("SHARED_EXPERT_4H_TO_H", tr::LoraModule::ModuleType::kSHARED_EXPERT_4H_TO_H)
+    .value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE);

nb::class_<tr::LoraModule>(m, "LoraModule")
.def(nb::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
@@ -233,7 +236,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"),
nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"),
nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1,
-        nb::arg("num_experts") = 0);
+        nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0);

nb::class_<tc::QuantMode>(m, "QuantMode")
.def_static("none", &tc::QuantMode::none)
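A hedged usage sketch of the updated create_lora_modules binding; the model dimensions below are hypothetical, and the import path of the compiled module can vary by build:

    from tensorrt_llm.bindings import LoraModule  # import path assumed

    modules = LoraModule.create_lora_modules(
        lora_module_names=["shared_expert_h_to_4h", "shared_expert_4h_to_h",
                           "shared_expert_gate"],
        hidden_size=2048,            # hypothetical Qwen-MoE-like sizes
        mlp_hidden_size=5632,
        num_attention_heads=16,
        num_kv_attention_heads=16,
        attention_head_size=128,
        tp_size=1,
        num_experts=60,
        shared_expert_hidden_size=5632,  # new argument
        moe_hidden_size=1408,            # new argument
    )

Both new arguments default to 0, so existing Python callers keep working unchanged.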
cpp/tensorrt_llm/runtime/loraModule.cpp (13 additions, 4 deletions)
@@ -21,10 +21,13 @@ namespace tensorrt_llm::runtime

std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
-    SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts)
+    SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize,
+    SizeType32 moeHiddenSize)
{
auto const hidden = hiddenSize * tpSize;
auto const mlpHidden = mlpHiddenSize * tpSize;
+    auto const sharedExpertHidden = sharedExpertHiddenSize > 0 ? sharedExpertHiddenSize * tpSize : mlpHidden;
+    auto const moeHidden = moeHiddenSize > 0 ? moeHiddenSize * tpSize : mlpHidden;
auto const numHeads = numAttentionHeads * tpSize;
auto const numKvHeads = numKvAttentionHeads * tpSize;
auto const attnHeadSize = attentionHeadSize;
@@ -54,13 +57,19 @@ std::vector<LoraModule> LoraModule::createLoraModules
case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHidden, hidden, false, true, 1, -1); break;
-        // TODO(TRTLLM-379): Support MOE LoRA weights
+        case ModuleType::kSHARED_EXPERT_H_TO_4H:
+        case ModuleType::kSHARED_EXPERT_GATE:
+            modules.emplace_back(t, hidden, sharedExpertHidden, false, true, -1, 0);
+            break;
+        case ModuleType::kSHARED_EXPERT_4H_TO_H:
+            modules.emplace_back(t, sharedExpertHidden, hidden, false, true, 1, -1);
+            break;
case ModuleType::kMOE_H_TO_4H:
case ModuleType::kMOE_GATE:
-            modules.emplace_back(t, hidden * numExperts, mlpHidden * numExperts, false, true, -1, 0);
+            modules.emplace_back(t, hidden * numExperts, moeHidden * numExperts, false, true, -1, 0);
break;
case ModuleType::kMOE_4H_TO_H:
-            modules.emplace_back(t, mlpHidden * numExperts, hidden * numExperts, false, true, 1, -1);
+            modules.emplace_back(t, moeHidden * numExperts, hidden * numExperts, false, true, 1, -1);
break;
case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
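The two ternaries above carry the backward-compatibility story: with the default value 0, the shared-expert and MoE hidden sizes both fall back to the MLP hidden size, so configurations that never set the new fields get the same dimensions as before. A pure-Python mirror of that arithmetic (a sketch of the logic, not the C++ itself):

    def resolve_hidden_sizes(mlp_hidden_size, shared_expert_hidden_size,
                             moe_hidden_size, tp_size):
        # Mirrors the ternaries in createLoraModules above.
        mlp_hidden = mlp_hidden_size * tp_size
        shared = (shared_expert_hidden_size * tp_size
                  if shared_expert_hidden_size > 0 else mlp_hidden)
        moe = moe_hidden_size * tp_size if moe_hidden_size > 0 else mlp_hidden
        return shared, moe

    # Defaults: both sizes fall back to the MLP hidden size.
    assert resolve_hidden_sizes(5632, 0, 0, 2) == (11264, 11264)
    # Explicit sizes: shared and routed expert dimensions diverge.
    assert resolve_hidden_sizes(5632, 2816, 1408, 2) == (5632, 2816)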
cpp/tensorrt_llm/runtime/loraUtils.cpp (1 addition, 1 deletion)
@@ -106,7 +106,7 @@ void loraValidateRequestTensors(std::optional<std::uint64_t> const& optTaskId,
std::string moduleName(LoraModule::toModuleName(modId));
TLLM_CHECK_WITH_INFO(it != loraModules.end(), "lora module " + moduleName + " not enabled for this model");
TLLM_CHECK_WITH_INFO(it->flattenedInOutSize(adapterSize, isDora) <= weights->getShape().d[2],
-    "lora_weights has to few values for " + moduleName);
+    "lora_weights has too few values for " + moduleName);
TLLM_CHECK_WITH_INFO(adapterSize <= maxAdapterSize,
"Invalid low_rank (" + std::to_string(adapterSize) + "). low_rank must be smaller than mMaxLowRank ("
+ std::to_string(maxAdapterSize) + ")");
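For context, this check compares each module's flattened LoRA weight count against the last dimension of the weights tensor. A sketch of the packing it implies, under the assumption that flattenedInOutSize() counts adapter_size rows of in- and out-weights plus a per-output-channel DoRA magnitude vector (an assumption about its contract, not a copy of its source):

    def flattened_in_out_size(adapter_size, in_dim, out_dim, is_dora=False):
        # Assumed packing: [adapter_size x in_dim] then [adapter_size x out_dim],
        # plus out_dim magnitude values when DoRA is enabled.
        size = adapter_size * (in_dim + out_dim)
        return size + out_dim if is_dora else size

    assert flattened_in_out_size(8, 2048, 5632) == 61440
    assert flattened_in_out_size(8, 2048, 5632, is_dora=True) == 67072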
tensorrt_llm/_torch/models/modeling_qwen3_next.py (15 additions, 2 deletions)
@@ -171,6 +171,8 @@ def __init__(
dtype=config.torch_dtype,
config=model_config,
reduce_output=False,
+    layer_idx=layer_idx,
+    is_shared_expert=True,
)

self.shared_expert_gate = Linear(self.hidden_dim,
@@ -190,6 +192,7 @@ def forward(
attn_metadata: AttentionMetadata,
all_reduce_params: Optional[AllReduceParams] = None,
do_finalize: Optional[bool] = True,
+    lora_params: Optional[dict] = None,
) -> torch.Tensor:
assert hidden_states.shape[-1] == self.hidden_dim
orig_shape = hidden_states.shape
@@ -221,7 +224,10 @@ def _compute_routed_output():
return final_hidden_states

def _compute_shared_output():
-    shared_expert_output = self.shared_expert(hidden_states)
+    shared_expert_output = self.shared_expert(
+        hidden_states,
+        lora_params=lora_params,
+    )
shared_expert_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_expert_output
return shared_expert_output
@@ -902,6 +908,7 @@ def forward(
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
spec_metadata: Optional[SpecMetadata] = None,
+    lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
@@ -950,6 +957,7 @@ def forward(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
do_finalize=do_finalize,
+    lora_params=lora_params,
)
if self.fusion_config.POST_MOE_FUSION:
if do_finalize:
@@ -1061,6 +1069,7 @@ def forward(
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
spec_metadata: Optional[SpecMetadata] = None,
+    lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:

@@ -1078,6 +1087,7 @@
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not self.disable_attn_allreduce),
+    lora_params=lora_params,
**kwargs,
)

@@ -1109,6 +1119,7 @@
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
do_finalize=do_finalize,
+    lora_params=lora_params,
)

if self.fusion_config.POST_MOE_FUSION:
@@ -1213,6 +1224,7 @@ def forward(
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
+    lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -1237,7 +1249,8 @@
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
-    mamba_metadata=mamba_metadata)
+    mamba_metadata=mamba_metadata,
+    lora_params=lora_params)
return hidden_states


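The shared-expert path above multiplies the expert output by a sigmoid gate computed from the raw hidden states. A self-contained sketch of that computation with hypothetical shapes (the real shared_expert is a GatedMLP, so a random tensor stands in for its output; out_features=1 follows the usual Qwen-MoE gate, whose exact signature is truncated in this diff):

    import torch

    tokens, hidden_dim = 8, 64  # hypothetical
    hidden_states = torch.randn(tokens, hidden_dim)
    shared_expert_output = torch.randn(tokens, hidden_dim)  # GatedMLP stand-in
    shared_expert_gate = torch.nn.Linear(hidden_dim, 1, bias=False)

    # One sigmoid logit per token scales that token's shared-expert output.
    gated = torch.sigmoid(shared_expert_gate(hidden_states)) * shared_expert_output
    assert gated.shape == (tokens, hidden_dim)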
tensorrt_llm/_torch/models/modeling_qwen_moe.py (17 additions, 3 deletions)
@@ -66,6 +66,8 @@ def __init__(
bias=config.mlp_bias if hasattr(config, 'mlp_bias') else False,
dtype=config.torch_dtype,
config=model_config,
+    layer_idx=layer_idx,
+    is_shared_expert=True,
)

self.shared_expert_gate = Linear(self.hidden_dim,
@@ -78,6 +80,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
+    lora_params: Optional[dict] = None,
) -> torch.Tensor:
assert hidden_states.shape[-1] == self.hidden_dim
orig_shape = hidden_states.shape
@@ -91,7 +94,10 @@ def forward(
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)

-shared_expert_output = self.shared_expert(hidden_states)
+shared_expert_output = self.shared_expert(
+    hidden_states,
+    lora_params=lora_params,
+)
shared_expert_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_expert_output

@@ -161,6 +167,7 @@ def forward(
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
+    lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
@@ -175,13 +182,18 @@
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
+    lora_params=lora_params,
**kwargs,
)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
-hidden_states = self.mlp(hidden_states, attn_metadata)
+hidden_states = self.mlp(
+    hidden_states,
+    attn_metadata,
+    lora_params=lora_params,
+)
return hidden_states, residual


@@ -217,6 +229,7 @@ def forward(
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
+    lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -234,7 +247,8 @@
hidden_states, residual = decoder_layer(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
-    residual=residual)
+    residual=residual,
+    lora_params=lora_params)

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
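The rest of this file's changes are plumbing: lora_params is threaded as an optional keyword with a None default at every hop (model to decoder layer to MLP block to shared expert), so existing call sites that never pass it are unaffected, and only the LoRA-aware module at the bottom consumes it. A condensed stand-in for the pattern (simplified classes, not the real ones):

    from typing import Optional

    class SharedExpert:
        def __call__(self, x, lora_params: Optional[dict] = None):
            # The real GatedMLP forwards lora_params to its LoraLayers here.
            return x

    class MoEBlock:
        def __init__(self):
            self.shared_expert = SharedExpert()

        def forward(self, x, lora_params: Optional[dict] = None):
            return self.shared_expert(x, lora_params=lora_params)

    assert MoEBlock().forward(3.0) == 3.0                 # legacy call site
    assert MoEBlock().forward(3.0, lora_params={}) == 3.0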
tensorrt_llm/_torch/modules/gated_mlp.py (15 additions, 7 deletions)
@@ -33,6 +33,7 @@ def __init__(
use_cute_dsl_blockscaling_mm: bool = False,
disable_deep_gemm: bool = False,
use_custom_cublas_mm: bool = False,
+    is_shared_expert: bool = False,
):

super().__init__()
@@ -87,8 +88,16 @@ def __init__(
use_custom_cublas_mm=use_custom_cublas_mm,
)

-self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
-                           [self.hidden_size])
+if is_shared_expert:
+    down_type = LoraModuleType.SHARED_EXPERT_4H_TO_H
+    h_to_4h_type = LoraModuleType.SHARED_EXPERT_H_TO_4H
+    gate_type = LoraModuleType.SHARED_EXPERT_GATE
+else:
+    down_type = LoraModuleType.MLP_4H_TO_H
+    h_to_4h_type = LoraModuleType.MLP_H_TO_4H
+    gate_type = LoraModuleType.MLP_GATE
+
+self.down_lora = LoraLayer([down_type], [self.hidden_size])

self.down_proj = Linear(
self.intermediate_size,
@@ -111,11 +120,10 @@ def __init__(
# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
# handles them as a single fused operation.
-self.splitted_gate_up_lora = LoraLayer(
-    [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE], [
-        self.intermediate_size // mapping.tp_size,
-        self.intermediate_size // mapping.tp_size
-    ])
+self.splitted_gate_up_lora = LoraLayer([h_to_4h_type, gate_type], [
+    self.intermediate_size // mapping.tp_size,
+    self.intermediate_size // mapping.tp_size
+])
self.fused_gate_up_lora = LoraLayer(
[LoraModuleType.MLP_GATE_UP],
[2 * self.intermediate_size // mapping.tp_size])
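The if/else above is the crux of this file's change: a GatedMLP constructed with is_shared_expert=True registers its LoRA layers under the SHARED_EXPERT_* module IDs instead of the MLP_* ones, so shared-expert adapters and ordinary MLP adapters can coexist in one checkpoint. A condensed sketch of the selection:

    from tensorrt_llm._torch.peft.lora.layer import LoraModuleType

    def select_lora_types(is_shared_expert: bool):
        # Returns (down, h_to_4h, gate) module types, as assigned above.
        if is_shared_expert:
            return (LoraModuleType.SHARED_EXPERT_4H_TO_H,
                    LoraModuleType.SHARED_EXPERT_H_TO_4H,
                    LoraModuleType.SHARED_EXPERT_GATE)
        return (LoraModuleType.MLP_4H_TO_H,
                LoraModuleType.MLP_H_TO_4H,
                LoraModuleType.MLP_GATE)

    down, h_to_4h, gate = select_lora_types(is_shared_expert=True)
    assert down is LoraModuleType.SHARED_EXPERT_4H_TO_H

Note that fused_gate_up_lora still uses MLP_GATE_UP in both cases, so the fused gate+up path is not specialized for shared experts in this diff.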
tensorrt_llm/_torch/peft/lora/layer.py (4 additions, 0 deletions)
@@ -74,6 +74,10 @@ class LoraModuleType(IntEnum):
MLP_ROUTER = 17 # MLP router
MLP_GATE_UP = 18 # Combined gate and up projections

+    SHARED_EXPERT_H_TO_4H = 19  # Shared expert first projection
+    SHARED_EXPERT_4H_TO_H = 20  # Shared expert second projection
+    SHARED_EXPERT_GATE = 21  # Shared expert gate projection

def __str__(self):
"""Return the name of the enum value."""
return self.name
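A usage sketch mirroring how gated_mlp.py wires the new types into a LoraLayer; the constructor call shape is copied from that file, the hidden size is hypothetical, and LoraLayer is assumed to live alongside the enum in this module:

    from tensorrt_llm._torch.peft.lora.layer import LoraLayer, LoraModuleType

    hidden_size = 2048  # hypothetical
    down_lora = LoraLayer([LoraModuleType.SHARED_EXPERT_4H_TO_H], [hidden_size])

    # __str__ above returns the enum member's name.
    assert str(LoraModuleType.SHARED_EXPERT_GATE) == "SHARED_EXPERT_GATE"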