|
15 | 15 | */ |
16 | 16 |
|
17 | 17 | #include "tensorrt_llm/runtime/loraModule.h" |
| 18 | +#include "tensorrt_llm/common/assert.h" |
18 | 19 |
|
19 | 20 | namespace tensorrt_llm::runtime |
20 | 21 | { |
21 | 22 |
|
22 | 23 | std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames, |
23 | 24 | SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads, |
24 | 25 | SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize, |
25 | | - SizeType32 moeHiddenSize) |
| 26 | + SizeType32 moeHiddenSize, SizeType32 mambaInProjSize, SizeType32 mambaInnerSize, SizeType32 moeLatentSize) |
26 | 27 | { |
27 | 28 | auto const hidden = hiddenSize * tpSize; |
28 | 29 | auto const mlpHidden = mlpHiddenSize * tpSize; |
29 | 30 | auto const sharedExpertHidden = sharedExpertHiddenSize > 0 ? sharedExpertHiddenSize * tpSize : mlpHidden; |
30 | 31 | auto const moeHidden = moeHiddenSize > 0 ? moeHiddenSize * tpSize : mlpHidden; |
| 32 | + // Mamba dimensions: in_proj projects hidden -> d_in_proj; out_proj projects d_inner -> hidden |
| 33 | + TLLM_CHECK_WITH_INFO((mambaInProjSize > 0) == (mambaInnerSize > 0), |
| 34 | + "mambaInProjSize and mambaInnerSize must both be zero or both be non-zero"); |
| 35 | + auto const mambaInProj = mambaInProjSize * tpSize; |
| 36 | + auto const mambaInner = mambaInnerSize * tpSize; |
| 37 | + // MoE latent projections are replicated (not TP-sharded), so moeLatentSize |
| 38 | + // is the actual per-GPU dimension and should not be scaled by tpSize. |
| 39 | + auto const moeLatent = moeLatentSize; |
31 | 40 | auto const numHeads = numAttentionHeads * tpSize; |
32 | 41 | auto const numKvHeads = numKvAttentionHeads * tpSize; |
33 | 42 | auto const attnHeadSize = attentionHeadSize; |
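
Worth noting before the second hunk: every new dimension except `moeLatent` follows the existing convention of multiplying the per-rank size by `tpSize` to recover the full, un-sharded dimension. A standalone sketch of that arithmetic, with hypothetical values:

```cpp
// Standalone illustration of the dimension math above; the concrete values are hypothetical.
#include <cassert>

int main()
{
    int const tpSize = 2;
    int const mambaInnerSize = 2048; // per-rank slice of Mamba's d_inner
    int const moeLatentSize = 512;   // replicated, so already the full latent dim

    int const mambaInner = mambaInnerSize * tpSize; // gathered d_inner: 4096
    int const moeLatent = moeLatentSize;            // unchanged: no tpSize scaling

    assert(mambaInner == 4096);
    assert(moeLatent == 512);
    return 0;
}
```
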
@@ -74,6 +83,12 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c |
74 | 83 | case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break; |
75 | 84 | case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break; |
76 | 85 | case ModuleType::kMLP_GATE_UP: modules.emplace_back(t, hidden, 2 * mlpHidden, false, true, -1, 0); break; |
| 86 | + // Mamba modules: in_proj (hidden -> d_in_proj), out_proj (d_inner -> hidden) |
| 87 | + case ModuleType::kMAMBA_IN_PROJ: modules.emplace_back(t, hidden, mambaInProj, false, true, -1, 0); break; |
| 88 | + case ModuleType::kMAMBA_OUT_PROJ: modules.emplace_back(t, mambaInner, hidden, false, true, 1, -1); break; |
| 89 | + // MoE latent projections: replicated (not TP-sharded), no TP split dims |
| 90 | + case ModuleType::kMOE_LATENT_FC1: modules.emplace_back(t, hidden, moeLatent, false, true, -1, -1); break; |
| 91 | + case ModuleType::kMOE_LATENT_FC2: modules.emplace_back(t, moeLatent, hidden, false, true, -1, -1); break; |
77 | 92 | case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName); |
78 | 93 | } |
79 | 94 | } |
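
For reference, a minimal standalone sketch of how the extended factory might be invoked. The `"mamba_in_proj"`/`"mamba_out_proj"` name strings are an assumption about what this PR registers in the name-to-`ModuleType` mapping, the numeric sizes are illustrative, and the `name()`/`inDim()`/`outDim()` accessors are assumed from the existing `loraModule.h`:

```cpp
// Hypothetical driver for the extended createLoraModules(); module-name strings
// for the new Mamba entries are assumptions, not confirmed by this diff.
#include "tensorrt_llm/runtime/loraModule.h"

#include <iostream>
#include <string>
#include <vector>

using tensorrt_llm::runtime::LoraModule;

int main()
{
    std::vector<std::string> const names{"attn_qkv", "attn_dense", "mamba_in_proj", "mamba_out_proj"};

    // Illustrative per-rank sizes for a hybrid attention/Mamba model at tpSize = 2.
    // Zeros for the MoE arguments keep the fallbacks computed in the first hunk.
    auto const modules = LoraModule::createLoraModules(names,
        /*hiddenSize=*/2048, /*mlpHiddenSize=*/8192, /*numAttentionHeads=*/16,
        /*numKvAttentionHeads=*/16, /*attentionHeadSize=*/128, /*tpSize=*/2,
        /*numExperts=*/0, /*sharedExpertHiddenSize=*/0, /*moeHiddenSize=*/0,
        /*mambaInProjSize=*/4096, /*mambaInnerSize=*/2048, /*moeLatentSize=*/0);

    for (auto const& m : modules)
    {
        std::cout << m.name() << ": inDim=" << m.inDim() << ", outDim=" << m.outDim() << '\n';
    }
    return 0;
}
```

Passing 0 for `sharedExpertHiddenSize` and `moeHiddenSize` exercises the existing fallback to `mlpHidden`, and `moeLatentSize` can stay 0 whenever no latent modules are requested; the new `TLLM_CHECK_WITH_INFO` only requires that the two Mamba sizes are set (or unset) together.
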