Skip to content

Commit 3a699fe

Browse files
committed
[TRTLLM-10232][feat] Support LoRA adapter for nemotron-h models
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 20fc52c commit 3a699fe

File tree

16 files changed

+473
-48
lines changed

16 files changed

+473
-48
lines changed

cpp/include/tensorrt_llm/runtime/loraModule.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class LoraModule
5353
kSHARED_EXPERT_H_TO_4H = 19,
5454
kSHARED_EXPERT_4H_TO_H = 20,
5555
kSHARED_EXPERT_GATE = 21,
56+
kMAMBA_IN_PROJ = 22,
57+
kMAMBA_OUT_PROJ = 23,
58+
kMOE_LATENT_UP = 24,
59+
kMOE_LATENT_DOWN = 25,
5660
};
5761

5862
explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
@@ -196,7 +200,7 @@ class LoraModule
196200
static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
197201
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
198202
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize = 0,
199-
SizeType32 moeHiddenSize = 0);
203+
SizeType32 moeHiddenSize = 0, SizeType32 mambaInProjSize = 0, SizeType32 mambaInnerSize = 0);
200204

201205
static ModuleType constexpr toModuleType(std::string_view const& name)
202206
{
@@ -244,6 +248,14 @@ class LoraModule
244248
return ModuleType::kSHARED_EXPERT_4H_TO_H;
245249
else if (name == "shared_expert_gate")
246250
return ModuleType::kSHARED_EXPERT_GATE;
251+
else if (name == "mamba_in_proj")
252+
return ModuleType::kMAMBA_IN_PROJ;
253+
else if (name == "mamba_out_proj")
254+
return ModuleType::kMAMBA_OUT_PROJ;
255+
else if (name == "moe_latent_up")
256+
return ModuleType::kMOE_LATENT_UP;
257+
else if (name == "moe_latent_down")
258+
return ModuleType::kMOE_LATENT_DOWN;
247259
else
248260
return ModuleType::kINVALID;
249261
}
@@ -274,6 +286,10 @@ class LoraModule
274286
case ModuleType::kSHARED_EXPERT_H_TO_4H: return "shared_expert_h_to_4h";
275287
case ModuleType::kSHARED_EXPERT_4H_TO_H: return "shared_expert_4h_to_h";
276288
case ModuleType::kSHARED_EXPERT_GATE: return "shared_expert_gate";
289+
case ModuleType::kMAMBA_IN_PROJ: return "mamba_in_proj";
290+
case ModuleType::kMAMBA_OUT_PROJ: return "mamba_out_proj";
291+
case ModuleType::kMOE_LATENT_UP: return "moe_latent_up";
292+
case ModuleType::kMOE_LATENT_DOWN: return "moe_latent_down";
277293
case ModuleType::kINVALID: return "INVALID";
278294
}
279295
return "INVALID";

cpp/include/tensorrt_llm/runtime/modelConfig.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,47 @@ class ModelConfig
228228
return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank);
229229
}
230230

231+
// Get the first LoRA layer index for a given PP rank.
232+
// Distributes extra layers to lower ranks when num_lora_layers is not evenly divisible by PP size.
233+
[[nodiscard]] SizeType32 getFirstLoraLayer(
234+
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
235+
{
236+
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
237+
if (mNbLoraLayers > 0)
238+
{
239+
auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
240+
auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
241+
// If num_lora_layers % pp_size == n and n != 0, the first n ranks each get one extra layer
242+
return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
243+
}
244+
// Fall back to attention layer distribution
245+
return countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
246+
}
247+
248+
// Get number of layers that can have LoRA applied for the given PP rank.
249+
// For hybrid models (e.g., Nemotron-H with Mamba + Attention), this may differ from num_attention_layers
250+
// because LoRA can be applied to non-attention layers (e.g., Mamba in_proj/out_proj).
251+
// Handles uneven PP splits by distributing extra layers to lower ranks.
252+
[[nodiscard]] SizeType32 getNbLoraLayers(
253+
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
254+
{
255+
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
256+
// If mNbLoraLayers is set (non-zero), use it with proper PP distribution
257+
if (mNbLoraLayers > 0)
258+
{
259+
auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
260+
auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
261+
// If num_lora_layers % pp_size == n and n != 0, the first n ranks each get one extra layer
262+
return numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
263+
}
264+
return getNbAttentionLayers(pipelineParallelism, pipelineParallelismRank);
265+
}
266+
267+
void setNbLoraLayers(SizeType32 nbLoraLayers)
268+
{
269+
mNbLoraLayers = nbLoraLayers;
270+
}
271+
231272
[[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept
232273
{
233274
return mNbHeads;
@@ -922,6 +963,8 @@ class ModelConfig
922963
std::vector<LoraModule> mLoraModules;
923964
SizeType32 mMlpHiddenSize;
924965
SizeType32 mMaxLoraRank;
966+
// Number of layers that can have LoRA applied (for hybrid models this may be > num_attention_layers)
967+
SizeType32 mNbLoraLayers{0};
925968

926969
std::optional<RnnConfig> mRnnConfig;
927970

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
220220
.value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP)
221221
.value("SHARED_EXPERT_H_TO_4H", tr::LoraModule::ModuleType::kSHARED_EXPERT_H_TO_4H)
222222
.value("SHARED_EXPERT_4H_TO_H", tr::LoraModule::ModuleType::kSHARED_EXPERT_4H_TO_H)
223-
.value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE);
223+
.value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE)
224+
.value("MAMBA_IN_PROJ", tr::LoraModule::ModuleType::kMAMBA_IN_PROJ)
225+
.value("MAMBA_OUT_PROJ", tr::LoraModule::ModuleType::kMAMBA_OUT_PROJ)
226+
.value("MOE_LATENT_UP", tr::LoraModule::ModuleType::kMOE_LATENT_UP)
227+
.value("MOE_LATENT_DOWN", tr::LoraModule::ModuleType::kMOE_LATENT_DOWN);
224228

225229
nb::class_<tr::LoraModule>(m, "LoraModule")
226230
.def(nb::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
@@ -236,7 +240,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
236240
.def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"),
237241
nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"),
238242
nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1,
239-
nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0);
243+
nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0,
244+
nb::arg("mamba_in_proj_size") = 0, nb::arg("mamba_inner_size") = 0);
240245

241246
nb::class_<tc::QuantMode>(m, "QuantMode")
242247
.def_static("none", &tc::QuantMode::none)
@@ -342,6 +347,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
342347
.def_prop_rw("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
343348
.def_prop_rw("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
344349
.def_prop_rw("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
350+
.def("num_lora_layers", &tr::ModelConfig::getNbLoraLayers, nb::arg("pipeline_parallelism") = 1,
351+
nb::arg("pipeline_parallelism_rank") = 0)
352+
.def("first_lora_layer", &tr::ModelConfig::getFirstLoraLayer, nb::arg("pipeline_parallelism") = 1,
353+
nb::arg("pipeline_parallelism_rank") = 0)
354+
.def("set_num_lora_layers", &tr::ModelConfig::setNbLoraLayers, nb::arg("num_lora_layers"))
345355
.def_prop_rw("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);
346356

347357
nb::class_<tr::WorldConfig>(m, "WorldConfig")

cpp/tensorrt_llm/runtime/loraCache.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,10 @@ SizeType32 LoraCache::determineNumPages(TaskIdType taskId) const
454454
SizeType32 LoraCache::determineNumPages(TensorPtr loraConfig) const
455455
{
456456
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
457-
auto const localNumLayers = mModelConfig.getNbAttentionLayers(
458-
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
459-
auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers;
457+
auto const localNumLayers
458+
= mModelConfig.getNbLoraLayers(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
459+
auto const firstLayerId
460+
= mModelConfig.getFirstLoraLayer(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
460461
auto const lastLayerId = firstLayerId + localNumLayers;
461462

462463
SizeType32 currPage = 0;
@@ -579,9 +580,8 @@ std::vector<LoraCache::TaskLayerModuleConfig> LoraCache::copyToPages(TensorPtr s
579580
auto const tpRank = worldConfig.getTensorParallelRank();
580581
auto const ppSize = worldConfig.getPipelineParallelism();
581582
auto const ppRank = worldConfig.getPipelineParallelRank();
582-
// TODO(oargov): why *attention* layers?
583-
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
584-
auto const firstLayerId = ppRank * localNumLayers;
583+
auto const localNumLayers = modelConfig.getNbLoraLayers(ppSize, ppRank);
584+
auto const firstLayerId = modelConfig.getFirstLoraLayer(ppSize, ppRank);
585585
auto const lastLayerId = firstLayerId + localNumLayers;
586586

587587
SizeType32 currPage = 0;

cpp/tensorrt_llm/runtime/loraManager.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
7272

7373
auto const ppSize = worldConfig.getPipelineParallelism();
7474
auto const ppRank = worldConfig.getPipelineParallelRank();
75-
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
76-
auto const firstLayerId = ppRank * localNumLayers;
75+
auto const localNumLayers = modelConfig.getNbLoraLayers(ppSize, ppRank);
76+
auto const firstLayerId = modelConfig.getFirstLoraLayer(ppSize, ppRank);
7777

7878
auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
7979
auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes);
@@ -124,8 +124,9 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
124124
{
125125
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
126126
auto localNbLayers
127-
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
128-
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
127+
= modelConfig.getNbLoraLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
128+
auto firstLayerId
129+
= modelConfig.getFirstLoraLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
129130

130131
for (auto const& [modId, mod] : mModuleIdToModule)
131132
{

cpp/tensorrt_llm/runtime/loraModule.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ namespace tensorrt_llm::runtime
2222
std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
2323
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
2424
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize,
25-
SizeType32 moeHiddenSize)
25+
SizeType32 moeHiddenSize, SizeType32 mambaInProjSize, SizeType32 mambaInnerSize)
2626
{
2727
auto const hidden = hiddenSize * tpSize;
2828
auto const mlpHidden = mlpHiddenSize * tpSize;
2929
auto const sharedExpertHidden = sharedExpertHiddenSize > 0 ? sharedExpertHiddenSize * tpSize : mlpHidden;
3030
auto const moeHidden = moeHiddenSize > 0 ? moeHiddenSize * tpSize : mlpHidden;
31+
// Mamba dimensions: in_proj's output dimension is d_in_proj; out_proj's input dimension is d_inner
32+
// Fall back to mlpHidden if not specified (for backward compatibility)
33+
auto const mambaInProj = mambaInProjSize > 0 ? mambaInProjSize * tpSize : mlpHidden;
34+
auto const mambaInner = mambaInnerSize > 0 ? mambaInnerSize * tpSize : mlpHidden;
3135
auto const numHeads = numAttentionHeads * tpSize;
3236
auto const numKvHeads = numKvAttentionHeads * tpSize;
3337
auto const attnHeadSize = attentionHeadSize;
@@ -74,6 +78,12 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
7478
case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
7579
case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
7680
case ModuleType::kMLP_GATE_UP: modules.emplace_back(t, hidden, 2 * mlpHidden, false, true, -1, 0); break;
81+
// Mamba modules: in_proj (hidden -> d_in_proj), out_proj (d_inner -> hidden)
82+
case ModuleType::kMAMBA_IN_PROJ: modules.emplace_back(t, hidden, mambaInProj, false, true, -1, 0); break;
83+
case ModuleType::kMAMBA_OUT_PROJ: modules.emplace_back(t, mambaInner, hidden, false, true, 1, -1); break;
84+
// MoE latent projections: up expands from hidden, down contracts back to hidden.
// NOTE(review): both modules below are sized with mlpHidden, not moeHidden — confirm that is the intended latent dimension.
85+
case ModuleType::kMOE_LATENT_UP: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
86+
case ModuleType::kMOE_LATENT_DOWN: modules.emplace_back(t, mlpHidden, hidden, false, true, 1, -1); break;
7787
case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
7888
}
7989
}

cpp/tensorrt_llm/runtime/loraUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void loraValidateRequestTensors(std::optional<std::uint64_t> const& optTaskId,
8484
? config
8585
: ITensor::view(config, ITensor::makeShape({config->getShape().d[1], config->getShape().d[2]}));
8686

87-
SizeType32 nbModelLayers = modelConfig.getNbAttentionLayers();
87+
SizeType32 nbModelLayers = modelConfig.getNbLoraLayers();
8888
TLLM_CHECK_WITH_INFO(weights->getDataType() == modelConfig.getDataType(),
8989
"Expected lora weights to be the same data type as base model");
9090

tensorrt_llm/_torch/model_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,12 @@ def get_bindings_model_config(self,
686686
num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
687687
model_config_cpp.set_num_kv_heads(num_kv_heads)
688688

689+
# For hybrid models (e.g., Nemotron-H with Mamba + Attention), LoRA can be applied
690+
# to non-attention layers (e.g., Mamba in_proj/out_proj). Set num_lora_layers to
691+
# total layers so the C++ LoRA validation accepts all layer indices.
692+
if is_nemotron_hybrid(self.pretrained_config):
693+
model_config_cpp.set_num_lora_layers(num_layers)
694+
689695
mlp_hidden_size = None
690696
if self.pretrained_config.intermediate_size is not None:
691697
mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size

0 commit comments

Comments
 (0)