Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion cpp/include/tensorrt_llm/runtime/loraModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class LoraModule
kSHARED_EXPERT_H_TO_4H = 19,
kSHARED_EXPERT_4H_TO_H = 20,
kSHARED_EXPERT_GATE = 21,
kMAMBA_IN_PROJ = 22,
kMAMBA_OUT_PROJ = 23,
kMOE_LATENT_FC1 = 24,
kMOE_LATENT_FC2 = 25,
};

explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
Expand Down Expand Up @@ -196,7 +200,8 @@ class LoraModule
static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize = 0,
SizeType32 moeHiddenSize = 0);
SizeType32 moeHiddenSize = 0, SizeType32 mambaInProjSize = 0, SizeType32 mambaInnerSize = 0,
SizeType32 moeLatentSize = 0);

static ModuleType constexpr toModuleType(std::string_view const& name)
{
Expand Down Expand Up @@ -244,6 +249,14 @@ class LoraModule
return ModuleType::kSHARED_EXPERT_4H_TO_H;
else if (name == "shared_expert_gate")
return ModuleType::kSHARED_EXPERT_GATE;
else if (name == "mamba_in_proj")
return ModuleType::kMAMBA_IN_PROJ;
else if (name == "mamba_out_proj")
return ModuleType::kMAMBA_OUT_PROJ;
else if (name == "moe_latent_fc1")
return ModuleType::kMOE_LATENT_FC1;
else if (name == "moe_latent_fc2")
return ModuleType::kMOE_LATENT_FC2;
else
return ModuleType::kINVALID;
}
Expand Down Expand Up @@ -274,6 +287,10 @@ class LoraModule
case ModuleType::kSHARED_EXPERT_H_TO_4H: return "shared_expert_h_to_4h";
case ModuleType::kSHARED_EXPERT_4H_TO_H: return "shared_expert_4h_to_h";
case ModuleType::kSHARED_EXPERT_GATE: return "shared_expert_gate";
case ModuleType::kMAMBA_IN_PROJ: return "mamba_in_proj";
case ModuleType::kMAMBA_OUT_PROJ: return "mamba_out_proj";
case ModuleType::kMOE_LATENT_FC1: return "moe_latent_fc1";
case ModuleType::kMOE_LATENT_FC2: return "moe_latent_fc2";
case ModuleType::kINVALID: return "INVALID";
}
return "INVALID";
Expand Down
57 changes: 57 additions & 0 deletions cpp/include/tensorrt_llm/runtime/modelConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,61 @@ class ModelConfig
return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank);
}

// Get the first LoRA layer index for a given PP rank.
// Distributes extra layers to lower ranks when num_lora_layers is not evenly divisible by PP size.
// @param pipelineParallelism      total number of pipeline-parallel ranks (must be > 0)
// @param pipelineParallelismRank  this rank's index in [0, pipelineParallelism)
// @return global index of the first LoRA-capable layer owned by this rank
[[nodiscard]] SizeType32 getFirstLoraLayer(
    SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
    TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
    // Guard the rank as well: an out-of-range (or negative) rank would otherwise
    // yield a nonsensical first-layer index instead of failing loudly.
    TLLM_CHECK_WITH_INFO(pipelineParallelismRank >= 0 && pipelineParallelismRank < pipelineParallelism,
        "Invalid pipelineParallelismRank: %d", pipelineParallelismRank);
    if (mNbLoraLayers > 0)
    {
        auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
        auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
        // If num_lora_layers % pp_size = n != 0, first n ranks get one extra layer
        return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
    }
    // Fall back to attention layer distribution
    if (mLayerTypes.empty())
    {
        // When layer types aren't populated, assume uniform attention layer distribution
        auto const numBaseLayers = mNbAttentionLayers / pipelineParallelism;
        auto const numExtraLayers = mNbAttentionLayers % pipelineParallelism;
        return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
    }
    return countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
}

// Get number of layers that can have LoRA applied for the given PP rank.
// For hybrid models (e.g., Nemotron-H with Mamba + Attention), this may differ from num_attention_layers
// because LoRA can be applied to non-attention layers (e.g., Mamba in_proj/out_proj).
// Handles uneven PP splits by distributing extra layers to lower ranks.
[[nodiscard]] SizeType32 getNbLoraLayers(
    SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
    TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
    // Shared helper: evenly split totalLayers across ranks; when totalLayers % pp_size = n != 0,
    // the first n ranks each take one extra layer (matches getFirstLoraLayer's formula).
    auto const layersOnThisRank = [&](SizeType32 totalLayers)
    {
        auto const perRank = totalLayers / pipelineParallelism;
        auto const remainder = totalLayers % pipelineParallelism;
        return perRank + (pipelineParallelismRank < remainder ? 1 : 0);
    };
    // If mNbLoraLayers is set (non-zero), use it with proper PP distribution.
    if (mNbLoraLayers > 0)
    {
        return layersOnThisRank(mNbLoraLayers);
    }
    // Fall back to attention layer distribution; without populated layer types,
    // assume attention layers are spread uniformly across ranks.
    if (mLayerTypes.empty())
    {
        return layersOnThisRank(mNbAttentionLayers);
    }
    return countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
}

// Set the number of layers that can have LoRA applied; for hybrid models this
// may exceed the number of attention layers (see mNbLoraLayers).
void setNbLoraLayers(SizeType32 nbLoraLayers)
{
    mNbLoraLayers = nbLoraLayers;
}

[[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept
{
return mNbHeads;
Expand Down Expand Up @@ -922,6 +977,8 @@ class ModelConfig
std::vector<LoraModule> mLoraModules;
SizeType32 mMlpHiddenSize;
SizeType32 mMaxLoraRank;
// Number of layers that can have LoRA applied (for hybrid models this may be > num_attention_layers)
SizeType32 mNbLoraLayers{0};

std::optional<RnnConfig> mRnnConfig;

Expand Down
14 changes: 12 additions & 2 deletions cpp/tensorrt_llm/nanobind/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP)
.value("SHARED_EXPERT_H_TO_4H", tr::LoraModule::ModuleType::kSHARED_EXPERT_H_TO_4H)
.value("SHARED_EXPERT_4H_TO_H", tr::LoraModule::ModuleType::kSHARED_EXPERT_4H_TO_H)
.value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE);
.value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE)
.value("MAMBA_IN_PROJ", tr::LoraModule::ModuleType::kMAMBA_IN_PROJ)
.value("MAMBA_OUT_PROJ", tr::LoraModule::ModuleType::kMAMBA_OUT_PROJ)
.value("MOE_LATENT_FC1", tr::LoraModule::ModuleType::kMOE_LATENT_FC1)
.value("MOE_LATENT_FC2", tr::LoraModule::ModuleType::kMOE_LATENT_FC2);

nb::class_<tr::LoraModule>(m, "LoraModule")
.def(nb::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
Expand All @@ -236,7 +240,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"),
nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"),
nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1,
nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0);
nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0,
nb::arg("mamba_in_proj_size") = 0, nb::arg("mamba_inner_size") = 0, nb::arg("moe_latent_size") = 0);

nb::class_<tc::QuantMode>(m, "QuantMode")
.def_static("none", &tc::QuantMode::none)
Expand Down Expand Up @@ -342,6 +347,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.def_prop_rw("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
.def_prop_rw("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
.def_prop_rw("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
.def("num_lora_layers", &tr::ModelConfig::getNbLoraLayers, nb::arg("pipeline_parallelism") = 1,
nb::arg("pipeline_parallelism_rank") = 0)
.def("first_lora_layer", &tr::ModelConfig::getFirstLoraLayer, nb::arg("pipeline_parallelism") = 1,
nb::arg("pipeline_parallelism_rank") = 0)
.def("set_num_lora_layers", &tr::ModelConfig::setNbLoraLayers, nb::arg("num_lora_layers"))
.def_prop_rw("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);

nb::class_<tr::WorldConfig>(m, "WorldConfig")
Expand Down
12 changes: 6 additions & 6 deletions cpp/tensorrt_llm/runtime/loraCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,10 @@ SizeType32 LoraCache::determineNumPages(TaskIdType taskId) const
SizeType32 LoraCache::determineNumPages(TensorPtr loraConfig) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const localNumLayers = mModelConfig.getNbAttentionLayers(
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers;
auto const localNumLayers
= mModelConfig.getNbLoraLayers(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
auto const firstLayerId
= mModelConfig.getFirstLoraLayer(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
auto const lastLayerId = firstLayerId + localNumLayers;

SizeType32 currPage = 0;
Expand Down Expand Up @@ -579,9 +580,8 @@ std::vector<LoraCache::TaskLayerModuleConfig> LoraCache::copyToPages(TensorPtr s
auto const tpRank = worldConfig.getTensorParallelRank();
auto const ppSize = worldConfig.getPipelineParallelism();
auto const ppRank = worldConfig.getPipelineParallelRank();
// TODO(oargov): why *attention* layers?
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
auto const firstLayerId = ppRank * localNumLayers;
auto const localNumLayers = modelConfig.getNbLoraLayers(ppSize, ppRank);
auto const firstLayerId = modelConfig.getFirstLoraLayer(ppSize, ppRank);
auto const lastLayerId = firstLayerId + localNumLayers;

SizeType32 currPage = 0;
Expand Down
8 changes: 3 additions & 5 deletions cpp/tensorrt_llm/runtime/loraManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes

auto const ppSize = worldConfig.getPipelineParallelism();
auto const ppRank = worldConfig.getPipelineParallelRank();
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
auto const firstLayerId = ppRank * localNumLayers;
auto const firstLayerId = modelConfig.getFirstLoraLayer(ppSize, ppRank);

auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes);
Expand Down Expand Up @@ -123,9 +122,8 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
ModelConfig const& modelConfig, WorldConfig const& worldConfig) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
auto firstLayerId
= modelConfig.getFirstLoraLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());

for (auto const& [modId, mod] : mModuleIdToModule)
{
Expand Down
17 changes: 16 additions & 1 deletion cpp/tensorrt_llm/runtime/loraModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,28 @@
*/

#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/common/assert.h"

namespace tensorrt_llm::runtime
{

std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize,
SizeType32 moeHiddenSize)
SizeType32 moeHiddenSize, SizeType32 mambaInProjSize, SizeType32 mambaInnerSize, SizeType32 moeLatentSize)
{
auto const hidden = hiddenSize * tpSize;
auto const mlpHidden = mlpHiddenSize * tpSize;
auto const sharedExpertHidden = sharedExpertHiddenSize > 0 ? sharedExpertHiddenSize * tpSize : mlpHidden;
auto const moeHidden = moeHiddenSize > 0 ? moeHiddenSize * tpSize : mlpHidden;
// Mamba dimensions: in_proj outputs d_in_proj, out_proj inputs d_inner
TLLM_CHECK_WITH_INFO((mambaInProjSize > 0) == (mambaInnerSize > 0),
"mambaInProjSize and mambaInnerSize must both be zero or both be non-zero");
auto const mambaInProj = mambaInProjSize * tpSize;
auto const mambaInner = mambaInnerSize * tpSize;
// MoE latent projections are replicated (not TP-sharded), so moeLatentSize
// is the actual per-GPU dimension and should not be scaled by tpSize.
auto const moeLatent = moeLatentSize;
auto const numHeads = numAttentionHeads * tpSize;
auto const numKvHeads = numKvAttentionHeads * tpSize;
auto const attnHeadSize = attentionHeadSize;
Expand Down Expand Up @@ -74,6 +83,12 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
case ModuleType::kMLP_GATE_UP: modules.emplace_back(t, hidden, 2 * mlpHidden, false, true, -1, 0); break;
// Mamba modules: in_proj (hidden -> d_in_proj), out_proj (d_inner -> hidden)
case ModuleType::kMAMBA_IN_PROJ: modules.emplace_back(t, hidden, mambaInProj, false, true, -1, 0); break;
case ModuleType::kMAMBA_OUT_PROJ: modules.emplace_back(t, mambaInner, hidden, false, true, 1, -1); break;
// MoE latent projections: replicated (not TP-sharded), no TP split dims
case ModuleType::kMOE_LATENT_FC1: modules.emplace_back(t, hidden, moeLatent, false, true, -1, -1); break;
case ModuleType::kMOE_LATENT_FC2: modules.emplace_back(t, moeLatent, hidden, false, true, -1, -1); break;
case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
}
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/runtime/loraUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void loraValidateRequestTensors(std::optional<std::uint64_t> const& optTaskId,
? config
: ITensor::view(config, ITensor::makeShape({config->getShape().d[1], config->getShape().d[2]}));

SizeType32 nbModelLayers = modelConfig.getNbAttentionLayers();
SizeType32 nbModelLayers = modelConfig.getNbLoraLayers();
TLLM_CHECK_WITH_INFO(weights->getDataType() == modelConfig.getDataType(),
"Expected lora weights to be the same data type as base model");

Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,12 @@ def ceil_div(a, b):
attn_tp_size * attn_cp_size)
model_config_cpp.set_num_kv_heads(num_kv_heads)

# For hybrid models (e.g., Nemotron-H with Mamba + Attention), LoRA can be applied
# to non-attention layers (e.g., Mamba in_proj/out_proj). Set num_lora_layers to
# total layers so the C++ LoRA validation accepts all layer indices.
if is_nemotron_hybrid(self.pretrained_config):
model_config_cpp.set_num_lora_layers(num_layers)

mlp_hidden_size = None
if self.pretrained_config.intermediate_size is not None:
mlp_hidden_size = ceil_div(self.pretrained_config.intermediate_size,
Expand Down
Loading
Loading