Skip to content

Commit dbb1c8c

Browse files
authored
[TRTLLM-10232][feat] Support LoRA adapter for nemotron-h models (NVIDIA#12154)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 2b4f54c commit dbb1c8c

File tree

16 files changed

+545
-44
lines changed

16 files changed

+545
-44
lines changed

cpp/include/tensorrt_llm/runtime/loraModule.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class LoraModule
5353
kSHARED_EXPERT_H_TO_4H = 19,
5454
kSHARED_EXPERT_4H_TO_H = 20,
5555
kSHARED_EXPERT_GATE = 21,
56+
kMAMBA_IN_PROJ = 22,
57+
kMAMBA_OUT_PROJ = 23,
58+
kMOE_LATENT_FC1 = 24,
59+
kMOE_LATENT_FC2 = 25,
5660
};
5761

5862
explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
@@ -196,7 +200,8 @@ class LoraModule
196200
static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
197201
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
198202
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize = 0,
199-
SizeType32 moeHiddenSize = 0);
203+
SizeType32 moeHiddenSize = 0, SizeType32 mambaInProjSize = 0, SizeType32 mambaInnerSize = 0,
204+
SizeType32 moeLatentSize = 0);
200205

201206
static ModuleType constexpr toModuleType(std::string_view const& name)
202207
{
@@ -244,6 +249,14 @@ class LoraModule
244249
return ModuleType::kSHARED_EXPERT_4H_TO_H;
245250
else if (name == "shared_expert_gate")
246251
return ModuleType::kSHARED_EXPERT_GATE;
252+
else if (name == "mamba_in_proj")
253+
return ModuleType::kMAMBA_IN_PROJ;
254+
else if (name == "mamba_out_proj")
255+
return ModuleType::kMAMBA_OUT_PROJ;
256+
else if (name == "moe_latent_fc1")
257+
return ModuleType::kMOE_LATENT_FC1;
258+
else if (name == "moe_latent_fc2")
259+
return ModuleType::kMOE_LATENT_FC2;
247260
else
248261
return ModuleType::kINVALID;
249262
}
@@ -274,6 +287,10 @@ class LoraModule
274287
case ModuleType::kSHARED_EXPERT_H_TO_4H: return "shared_expert_h_to_4h";
275288
case ModuleType::kSHARED_EXPERT_4H_TO_H: return "shared_expert_4h_to_h";
276289
case ModuleType::kSHARED_EXPERT_GATE: return "shared_expert_gate";
290+
case ModuleType::kMAMBA_IN_PROJ: return "mamba_in_proj";
291+
case ModuleType::kMAMBA_OUT_PROJ: return "mamba_out_proj";
292+
case ModuleType::kMOE_LATENT_FC1: return "moe_latent_fc1";
293+
case ModuleType::kMOE_LATENT_FC2: return "moe_latent_fc2";
277294
case ModuleType::kINVALID: return "INVALID";
278295
}
279296
return "INVALID";

cpp/include/tensorrt_llm/runtime/modelConfig.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,61 @@ class ModelConfig
228228
return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank);
229229
}
230230

231+
// Get the first LoRA layer index for a given PP rank.
232+
// Distributes extra layers to lower ranks when num_lora_layers is not evenly divisible by PP size.
233+
[[nodiscard]] SizeType32 getFirstLoraLayer(
234+
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
235+
{
236+
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
237+
if (mNbLoraLayers > 0)
238+
{
239+
auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
240+
auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
241+
// If n = num_lora_layers % pp_size is non-zero, the first n ranks each get one extra layer
242+
return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
243+
}
244+
// Fall back to attention layer distribution
245+
if (mLayerTypes.empty())
246+
{
247+
// When layer types aren't populated, distribute attention layers evenly, giving any remainder to the lower ranks
248+
auto const numBaseLayers = mNbAttentionLayers / pipelineParallelism;
249+
auto const numExtraLayers = mNbAttentionLayers % pipelineParallelism;
250+
return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
251+
}
252+
return countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
253+
}
254+
255+
// Get number of layers that can have LoRA applied for the given PP rank.
256+
// For hybrid models (e.g., Nemotron-H with Mamba + Attention), this may differ from num_attention_layers
257+
// because LoRA can be applied to non-attention layers (e.g., Mamba in_proj/out_proj).
258+
// Handles uneven PP splits by distributing extra layers to lower ranks.
259+
[[nodiscard]] SizeType32 getNbLoraLayers(
260+
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
261+
{
262+
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
263+
// If mNbLoraLayers is set (non-zero), use it with proper PP distribution
264+
if (mNbLoraLayers > 0)
265+
{
266+
auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
267+
auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
268+
// If n = num_lora_layers % pp_size is non-zero, the first n ranks each get one extra layer
269+
return numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
270+
}
271+
// Fall back to attention layer distribution, matching getFirstLoraLayer's formula
272+
if (mLayerTypes.empty())
273+
{
274+
auto const numBaseLayers = mNbAttentionLayers / pipelineParallelism;
275+
auto const numExtraLayers = mNbAttentionLayers % pipelineParallelism;
276+
return numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
277+
}
278+
return countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
279+
}
280+
281+
void setNbLoraLayers(SizeType32 nbLoraLayers)
282+
{
283+
mNbLoraLayers = nbLoraLayers;
284+
}
285+
231286
[[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept
232287
{
233288
return mNbHeads;
@@ -922,6 +977,8 @@ class ModelConfig
922977
std::vector<LoraModule> mLoraModules;
923978
SizeType32 mMlpHiddenSize;
924979
SizeType32 mMaxLoraRank;
980+
// Number of layers that can have LoRA applied (for hybrid models this may be > num_attention_layers)
981+
SizeType32 mNbLoraLayers{0};
925982

926983
std::optional<RnnConfig> mRnnConfig;
927984

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
220220
.value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP)
221221
.value("SHARED_EXPERT_H_TO_4H", tr::LoraModule::ModuleType::kSHARED_EXPERT_H_TO_4H)
222222
.value("SHARED_EXPERT_4H_TO_H", tr::LoraModule::ModuleType::kSHARED_EXPERT_4H_TO_H)
223-
.value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE);
223+
.value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE)
224+
.value("MAMBA_IN_PROJ", tr::LoraModule::ModuleType::kMAMBA_IN_PROJ)
225+
.value("MAMBA_OUT_PROJ", tr::LoraModule::ModuleType::kMAMBA_OUT_PROJ)
226+
.value("MOE_LATENT_FC1", tr::LoraModule::ModuleType::kMOE_LATENT_FC1)
227+
.value("MOE_LATENT_FC2", tr::LoraModule::ModuleType::kMOE_LATENT_FC2);
224228

225229
nb::class_<tr::LoraModule>(m, "LoraModule")
226230
.def(nb::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
@@ -236,7 +240,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
236240
.def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"),
237241
nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"),
238242
nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1,
239-
nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0);
243+
nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0,
244+
nb::arg("mamba_in_proj_size") = 0, nb::arg("mamba_inner_size") = 0, nb::arg("moe_latent_size") = 0);
240245

241246
nb::class_<tc::QuantMode>(m, "QuantMode")
242247
.def_static("none", &tc::QuantMode::none)
@@ -342,6 +347,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
342347
.def_prop_rw("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
343348
.def_prop_rw("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
344349
.def_prop_rw("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
350+
.def("num_lora_layers", &tr::ModelConfig::getNbLoraLayers, nb::arg("pipeline_parallelism") = 1,
351+
nb::arg("pipeline_parallelism_rank") = 0)
352+
.def("first_lora_layer", &tr::ModelConfig::getFirstLoraLayer, nb::arg("pipeline_parallelism") = 1,
353+
nb::arg("pipeline_parallelism_rank") = 0)
354+
.def("set_num_lora_layers", &tr::ModelConfig::setNbLoraLayers, nb::arg("num_lora_layers"))
345355
.def_prop_rw("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);
346356

347357
nb::class_<tr::WorldConfig>(m, "WorldConfig")

cpp/tensorrt_llm/runtime/loraCache.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,10 @@ SizeType32 LoraCache::determineNumPages(TaskIdType taskId) const
454454
SizeType32 LoraCache::determineNumPages(TensorPtr loraConfig) const
455455
{
456456
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
457-
auto const localNumLayers = mModelConfig.getNbAttentionLayers(
458-
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
459-
auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers;
457+
auto const localNumLayers
458+
= mModelConfig.getNbLoraLayers(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
459+
auto const firstLayerId
460+
= mModelConfig.getFirstLoraLayer(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
460461
auto const lastLayerId = firstLayerId + localNumLayers;
461462

462463
SizeType32 currPage = 0;
@@ -579,9 +580,8 @@ std::vector<LoraCache::TaskLayerModuleConfig> LoraCache::copyToPages(TensorPtr s
579580
auto const tpRank = worldConfig.getTensorParallelRank();
580581
auto const ppSize = worldConfig.getPipelineParallelism();
581582
auto const ppRank = worldConfig.getPipelineParallelRank();
582-
// TODO(oargov): why *attention* layers?
583-
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
584-
auto const firstLayerId = ppRank * localNumLayers;
583+
auto const localNumLayers = modelConfig.getNbLoraLayers(ppSize, ppRank);
584+
auto const firstLayerId = modelConfig.getFirstLoraLayer(ppSize, ppRank);
585585
auto const lastLayerId = firstLayerId + localNumLayers;
586586

587587
SizeType32 currPage = 0;

cpp/tensorrt_llm/runtime/loraManager.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
7272

7373
auto const ppSize = worldConfig.getPipelineParallelism();
7474
auto const ppRank = worldConfig.getPipelineParallelRank();
75-
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
76-
auto const firstLayerId = ppRank * localNumLayers;
75+
auto const firstLayerId = modelConfig.getFirstLoraLayer(ppSize, ppRank);
7776

7877
auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
7978
auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes);
@@ -123,9 +122,8 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
123122
ModelConfig const& modelConfig, WorldConfig const& worldConfig) const
124123
{
125124
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
126-
auto localNbLayers
127-
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
128-
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
125+
auto firstLayerId
126+
= modelConfig.getFirstLoraLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
129127

130128
for (auto const& [modId, mod] : mModuleIdToModule)
131129
{

cpp/tensorrt_llm/runtime/loraModule.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,28 @@
1515
*/
1616

1717
#include "tensorrt_llm/runtime/loraModule.h"
18+
#include "tensorrt_llm/common/assert.h"
1819

1920
namespace tensorrt_llm::runtime
2021
{
2122

2223
std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
2324
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
2425
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize,
25-
SizeType32 moeHiddenSize)
26+
SizeType32 moeHiddenSize, SizeType32 mambaInProjSize, SizeType32 mambaInnerSize, SizeType32 moeLatentSize)
2627
{
2728
auto const hidden = hiddenSize * tpSize;
2829
auto const mlpHidden = mlpHiddenSize * tpSize;
2930
auto const sharedExpertHidden = sharedExpertHiddenSize > 0 ? sharedExpertHiddenSize * tpSize : mlpHidden;
3031
auto const moeHidden = moeHiddenSize > 0 ? moeHiddenSize * tpSize : mlpHidden;
32+
// Mamba dimensions: in_proj maps hidden -> d_in_proj; out_proj maps d_inner -> hidden
33+
TLLM_CHECK_WITH_INFO((mambaInProjSize > 0) == (mambaInnerSize > 0),
34+
"mambaInProjSize and mambaInnerSize must both be zero or both be non-zero");
35+
auto const mambaInProj = mambaInProjSize * tpSize;
36+
auto const mambaInner = mambaInnerSize * tpSize;
37+
// MoE latent projections are replicated (not TP-sharded), so moeLatentSize
38+
// is the actual per-GPU dimension and should not be scaled by tpSize.
39+
auto const moeLatent = moeLatentSize;
3140
auto const numHeads = numAttentionHeads * tpSize;
3241
auto const numKvHeads = numKvAttentionHeads * tpSize;
3342
auto const attnHeadSize = attentionHeadSize;
@@ -74,6 +83,12 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
7483
case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
7584
case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
7685
case ModuleType::kMLP_GATE_UP: modules.emplace_back(t, hidden, 2 * mlpHidden, false, true, -1, 0); break;
86+
// Mamba modules: in_proj (hidden -> d_in_proj), out_proj (d_inner -> hidden)
87+
case ModuleType::kMAMBA_IN_PROJ: modules.emplace_back(t, hidden, mambaInProj, false, true, -1, 0); break;
88+
case ModuleType::kMAMBA_OUT_PROJ: modules.emplace_back(t, mambaInner, hidden, false, true, 1, -1); break;
89+
// MoE latent projections: replicated (not TP-sharded), no TP split dims
90+
case ModuleType::kMOE_LATENT_FC1: modules.emplace_back(t, hidden, moeLatent, false, true, -1, -1); break;
91+
case ModuleType::kMOE_LATENT_FC2: modules.emplace_back(t, moeLatent, hidden, false, true, -1, -1); break;
7792
case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
7893
}
7994
}

cpp/tensorrt_llm/runtime/loraUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void loraValidateRequestTensors(std::optional<std::uint64_t> const& optTaskId,
8484
? config
8585
: ITensor::view(config, ITensor::makeShape({config->getShape().d[1], config->getShape().d[2]}));
8686

87-
SizeType32 nbModelLayers = modelConfig.getNbAttentionLayers();
87+
SizeType32 nbModelLayers = modelConfig.getNbLoraLayers();
8888
TLLM_CHECK_WITH_INFO(weights->getDataType() == modelConfig.getDataType(),
8989
"Expected lora weights to be the same data type as base model");
9090

tensorrt_llm/_torch/model_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,12 @@ def ceil_div(a, b):
702702
attn_tp_size * attn_cp_size)
703703
model_config_cpp.set_num_kv_heads(num_kv_heads)
704704

705+
# For hybrid models (e.g., Nemotron-H with Mamba + Attention), LoRA can be applied
706+
# to non-attention layers (e.g., Mamba in_proj/out_proj). Set num_lora_layers to
707+
# total layers so the C++ LoRA validation accepts all layer indices.
708+
if is_nemotron_hybrid(self.pretrained_config):
709+
model_config_cpp.set_num_lora_layers(num_layers)
710+
705711
mlp_hidden_size = None
706712
if self.pretrained_config.intermediate_size is not None:
707713
mlp_hidden_size = ceil_div(self.pretrained_config.intermediate_size,

0 commit comments

Comments
 (0)