@@ -220,7 +220,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
220220 .value (" MLP_GATE_UP" , tr::LoraModule::ModuleType::kMLP_GATE_UP )
221221 .value (" SHARED_EXPERT_H_TO_4H" , tr::LoraModule::ModuleType::kSHARED_EXPERT_H_TO_4H )
222222 .value (" SHARED_EXPERT_4H_TO_H" , tr::LoraModule::ModuleType::kSHARED_EXPERT_4H_TO_H )
223- .value (" SHARED_EXPERT_GATE" , tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE );
223+ .value (" SHARED_EXPERT_GATE" , tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE )
224+ .value (" MAMBA_IN_PROJ" , tr::LoraModule::ModuleType::kMAMBA_IN_PROJ )
225+ .value (" MAMBA_OUT_PROJ" , tr::LoraModule::ModuleType::kMAMBA_OUT_PROJ )
226+ .value (" MOE_LATENT_UP" , tr::LoraModule::ModuleType::kMOE_LATENT_UP )
227+ .value (" MOE_LATENT_DOWN" , tr::LoraModule::ModuleType::kMOE_LATENT_DOWN );
224228
225229 nb::class_<tr::LoraModule>(m, " LoraModule" )
226230 .def (nb::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool , bool , SizeType32, SizeType32>(),
@@ -236,7 +240,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
236240 .def_static (" create_lora_modules" , &tr::LoraModule::createLoraModules, nb::arg (" lora_module_names" ),
237241 nb::arg (" hidden_size" ), nb::arg (" mlp_hidden_size" ), nb::arg (" num_attention_heads" ),
238242 nb::arg (" num_kv_attention_heads" ), nb::arg (" attention_head_size" ), nb::arg (" tp_size" ) = 1 ,
239- nb::arg (" num_experts" ) = 0 , nb::arg (" shared_expert_hidden_size" ) = 0 , nb::arg (" moe_hidden_size" ) = 0 );
243+ nb::arg (" num_experts" ) = 0 , nb::arg (" shared_expert_hidden_size" ) = 0 , nb::arg (" moe_hidden_size" ) = 0 ,
244+ nb::arg (" mamba_in_proj_size" ) = 0 , nb::arg (" mamba_inner_size" ) = 0 );
240245
241246 nb::class_<tc::QuantMode>(m, " QuantMode" )
242247 .def_static (" none" , &tc::QuantMode::none)
@@ -342,6 +347,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
342347 .def_prop_rw (" lora_modules" , &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
343348 .def_prop_rw (" max_lora_rank" , &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
344349 .def_prop_rw (" mlp_hidden_size" , &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
350+ .def (" num_lora_layers" , &tr::ModelConfig::getNbLoraLayers, nb::arg (" pipeline_parallelism" ) = 1 ,
351+ nb::arg (" pipeline_parallelism_rank" ) = 0 )
352+ .def (" first_lora_layer" , &tr::ModelConfig::getFirstLoraLayer, nb::arg (" pipeline_parallelism" ) = 1 ,
353+ nb::arg (" pipeline_parallelism_rank" ) = 0 )
354+ .def (" set_num_lora_layers" , &tr::ModelConfig::setNbLoraLayers, nb::arg (" num_lora_layers" ))
345355 .def_prop_rw (" size_per_head" , &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);
346356
347357 nb::class_<tr::WorldConfig>(m, " WorldConfig" )
0 commit comments