
Commit a2bbc7f

[TRTLLM-10232][feat] Support LoRA adapter for nemotron-h models
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent: 94f9489

18 files changed, +495 -44 lines

cpp/include/tensorrt_llm/runtime/loraModule.h

Lines changed: 17 additions & 1 deletion
@@ -53,6 +53,10 @@ class LoraModule
         kSHARED_EXPERT_H_TO_4H = 19,
         kSHARED_EXPERT_4H_TO_H = 20,
         kSHARED_EXPERT_GATE = 21,
+        kMAMBA_IN_PROJ = 22,
+        kMAMBA_OUT_PROJ = 23,
+        kMOE_LATENT_UP = 24,
+        kMOE_LATENT_DOWN = 25,
     };
 
     explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
@@ -196,7 +200,7 @@ class LoraModule
     static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
         SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
         SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize = 0,
-        SizeType32 moeHiddenSize = 0);
+        SizeType32 moeHiddenSize = 0, SizeType32 mambaInProjSize = 0, SizeType32 mambaInnerSize = 0);
 
     static ModuleType constexpr toModuleType(std::string_view const& name)
     {
@@ -244,6 +248,14 @@ class LoraModule
             return ModuleType::kSHARED_EXPERT_4H_TO_H;
         else if (name == "shared_expert_gate")
            return ModuleType::kSHARED_EXPERT_GATE;
+        else if (name == "mamba_in_proj")
+            return ModuleType::kMAMBA_IN_PROJ;
+        else if (name == "mamba_out_proj")
+            return ModuleType::kMAMBA_OUT_PROJ;
+        else if (name == "moe_latent_up")
+            return ModuleType::kMOE_LATENT_UP;
+        else if (name == "moe_latent_down")
+            return ModuleType::kMOE_LATENT_DOWN;
         else
             return ModuleType::kINVALID;
     }
@@ -274,6 +286,10 @@ class LoraModule
         case ModuleType::kSHARED_EXPERT_H_TO_4H: return "shared_expert_h_to_4h";
         case ModuleType::kSHARED_EXPERT_4H_TO_H: return "shared_expert_4h_to_h";
         case ModuleType::kSHARED_EXPERT_GATE: return "shared_expert_gate";
+        case ModuleType::kMAMBA_IN_PROJ: return "mamba_in_proj";
+        case ModuleType::kMAMBA_OUT_PROJ: return "mamba_out_proj";
+        case ModuleType::kMOE_LATENT_UP: return "moe_latent_up";
+        case ModuleType::kMOE_LATENT_DOWN: return "moe_latent_down";
         case ModuleType::kINVALID: return "INVALID";
     }
     return "INVALID";

cpp/include/tensorrt_llm/runtime/modelConfig.h

Lines changed: 43 additions & 0 deletions
@@ -228,6 +228,47 @@ class ModelConfig
         return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank);
     }
 
+    // Get the first LoRA layer index for a given PP rank.
+    // Distributes extra layers to lower ranks when num_lora_layers is not evenly divisible by PP size.
+    [[nodiscard]] SizeType32 getFirstLoraLayer(
+        SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+    {
+        TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
+        if (mNbLoraLayers > 0)
+        {
+            auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
+            auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
+            // If num_lora_layers % pp_size = n != 0, first n ranks get one extra layer
+            return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
+        }
+        // Fall back to attention layer distribution
+        return countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
+    }
+
+    // Get number of layers that can have LoRA applied for the given PP rank.
+    // For hybrid models (e.g., Nemotron-H with Mamba + Attention), this may differ from num_attention_layers
+    // because LoRA can be applied to non-attention layers (e.g., Mamba in_proj/out_proj).
+    // Handles uneven PP splits by distributing extra layers to lower ranks.
+    [[nodiscard]] SizeType32 getNbLoraLayers(
+        SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+    {
+        TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
+        // If mNbLoraLayers is set (non-zero), use it with proper PP distribution
+        if (mNbLoraLayers > 0)
+        {
+            auto const numBaseLayers = mNbLoraLayers / pipelineParallelism;
+            auto const numExtraLayers = mNbLoraLayers % pipelineParallelism;
+            // If num_lora_layers % pp_size = n != 0, first n ranks get one extra layer
+            return numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
+        }
+        return getNbAttentionLayers(pipelineParallelism, pipelineParallelismRank);
+    }
+
+    void setNbLoraLayers(SizeType32 nbLoraLayers)
+    {
+        mNbLoraLayers = nbLoraLayers;
+    }
+
     [[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept
     {
         return mNbHeads;
@@ -922,6 +963,8 @@ class ModelConfig
     std::vector<LoraModule> mLoraModules;
     SizeType32 mMlpHiddenSize;
     SizeType32 mMaxLoraRank;
+    // Number of layers that can have LoRA applied (for hybrid models this may be > num_attention_layers)
+    SizeType32 mNbLoraLayers{0};
 
     std::optional<RnnConfig> mRnnConfig;

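Note on the pipeline-parallel split above: a rank's LoRA layer range is derived from mNbLoraLayers, with the remainder layers given to the lowest ranks. The small Python sketch below just mirrors that arithmetic for illustration (the helper name and the example numbers are made up, not part of TensorRT-LLM):

    def lora_layer_split(num_lora_layers: int, pp_size: int) -> list:
        """Return (first_layer, num_layers) per PP rank, mirroring getFirstLoraLayer/getNbLoraLayers."""
        base, extra = divmod(num_lora_layers, pp_size)
        split = []
        for rank in range(pp_size):
            first = rank * base + min(rank, extra)       # getFirstLoraLayer
            count = base + (1 if rank < extra else 0)    # getNbLoraLayers
            split.append((first, count))
        return split

    # 10 LoRA layers over 4 PP ranks: 10 % 4 = 2, so the first two ranks get one extra layer.
    print(lora_layer_split(10, 4))  # [(0, 3), (3, 3), (6, 2), (8, 2)]
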
cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 12 additions & 2 deletions
@@ -220,7 +220,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP)
         .value("SHARED_EXPERT_H_TO_4H", tr::LoraModule::ModuleType::kSHARED_EXPERT_H_TO_4H)
         .value("SHARED_EXPERT_4H_TO_H", tr::LoraModule::ModuleType::kSHARED_EXPERT_4H_TO_H)
-        .value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE);
+        .value("SHARED_EXPERT_GATE", tr::LoraModule::ModuleType::kSHARED_EXPERT_GATE)
+        .value("MAMBA_IN_PROJ", tr::LoraModule::ModuleType::kMAMBA_IN_PROJ)
+        .value("MAMBA_OUT_PROJ", tr::LoraModule::ModuleType::kMAMBA_OUT_PROJ)
+        .value("MOE_LATENT_UP", tr::LoraModule::ModuleType::kMOE_LATENT_UP)
+        .value("MOE_LATENT_DOWN", tr::LoraModule::ModuleType::kMOE_LATENT_DOWN);
 
     nb::class_<tr::LoraModule>(m, "LoraModule")
         .def(nb::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
@@ -236,7 +240,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         .def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"),
             nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"),
             nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1,
-            nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0);
+            nb::arg("num_experts") = 0, nb::arg("shared_expert_hidden_size") = 0, nb::arg("moe_hidden_size") = 0,
+            nb::arg("mamba_in_proj_size") = 0, nb::arg("mamba_inner_size") = 0);
 
     nb::class_<tc::QuantMode>(m, "QuantMode")
         .def_static("none", &tc::QuantMode::none)
@@ -342,6 +347,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         .def_prop_rw("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
         .def_prop_rw("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
         .def_prop_rw("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
+        .def("num_lora_layers", &tr::ModelConfig::getNbLoraLayers, nb::arg("pipeline_parallelism") = 1,
+            nb::arg("pipeline_parallelism_rank") = 0)
+        .def("first_lora_layer", &tr::ModelConfig::getFirstLoraLayer, nb::arg("pipeline_parallelism") = 1,
+            nb::arg("pipeline_parallelism_rank") = 0)
+        .def("set_num_lora_layers", &tr::ModelConfig::setNbLoraLayers, nb::arg("num_lora_layers"))
         .def_prop_rw("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);
 
     nb::class_<tr::WorldConfig>(m, "WorldConfig")

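With the extra nb::arg entries above, the new Mamba sizes can be passed from Python when building LoRA module descriptors. A hedged sketch of such a call is below; the import path assumes the standard bindings module, and all numeric sizes are placeholders rather than values from a real checkpoint:

    from tensorrt_llm.bindings import LoraModule  # assumed bindings module path

    modules = LoraModule.create_lora_modules(
        lora_module_names=["mamba_in_proj", "mamba_out_proj"],
        hidden_size=4096,             # placeholder model width
        mlp_hidden_size=14336,        # placeholder FFN width
        num_attention_heads=32,
        num_kv_attention_heads=8,
        attention_head_size=128,
        tp_size=1,
        mamba_in_proj_size=8192,      # placeholder d_in_proj of the Mamba mixer
        mamba_inner_size=4096,        # placeholder d_inner of the Mamba mixer
    )
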
cpp/tensorrt_llm/runtime/loraCache.cpp

Lines changed: 3 additions & 4 deletions
@@ -454,8 +454,8 @@ SizeType32 LoraCache::determineNumPages(TaskIdType taskId) const
 SizeType32 LoraCache::determineNumPages(TensorPtr loraConfig) const
 {
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
-    auto const localNumLayers = mModelConfig.getNbAttentionLayers(
-        mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
+    auto const localNumLayers
+        = mModelConfig.getNbLoraLayers(mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
     auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers;
     auto const lastLayerId = firstLayerId + localNumLayers;
 
@@ -579,8 +579,7 @@ std::vector<LoraCache::TaskLayerModuleConfig> LoraCache::copyToPages(TensorPtr s
     auto const tpRank = worldConfig.getTensorParallelRank();
     auto const ppSize = worldConfig.getPipelineParallelism();
     auto const ppRank = worldConfig.getPipelineParallelRank();
-    // TODO(oargov): why *attention* layers?
-    auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
+    auto const localNumLayers = modelConfig.getNbLoraLayers(ppSize, ppRank);
     auto const firstLayerId = ppRank * localNumLayers;
     auto const lastLayerId = firstLayerId + localNumLayers;

cpp/tensorrt_llm/runtime/loraManager.cpp

Lines changed: 2 additions & 2 deletions
@@ -72,7 +72,7 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
 
     auto const ppSize = worldConfig.getPipelineParallelism();
     auto const ppRank = worldConfig.getPipelineParallelRank();
-    auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
+    auto const localNumLayers = modelConfig.getNbLoraLayers(ppSize, ppRank);
     auto const firstLayerId = ppRank * localNumLayers;
 
     auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
@@ -124,7 +124,7 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     auto localNbLayers
-        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        = modelConfig.getNbLoraLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
 
     for (auto const& [modId, mod] : mModuleIdToModule)

cpp/tensorrt_llm/runtime/loraModule.cpp

Lines changed: 11 additions & 1 deletion
@@ -22,12 +22,16 @@ namespace tensorrt_llm::runtime
 std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
     SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
     SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts, SizeType32 sharedExpertHiddenSize,
-    SizeType32 moeHiddenSize)
+    SizeType32 moeHiddenSize, SizeType32 mambaInProjSize, SizeType32 mambaInnerSize)
 {
     auto const hidden = hiddenSize * tpSize;
     auto const mlpHidden = mlpHiddenSize * tpSize;
     auto const sharedExpertHidden = sharedExpertHiddenSize > 0 ? sharedExpertHiddenSize * tpSize : mlpHidden;
     auto const moeHidden = moeHiddenSize > 0 ? moeHiddenSize * tpSize : mlpHidden;
+    // Mamba dimensions: in_proj outputs d_in_proj, out_proj inputs d_inner
+    // Fall back to mlpHidden if not specified (for backward compatibility)
+    auto const mambaInProj = mambaInProjSize > 0 ? mambaInProjSize * tpSize : mlpHidden;
+    auto const mambaInner = mambaInnerSize > 0 ? mambaInnerSize * tpSize : mlpHidden;
     auto const numHeads = numAttentionHeads * tpSize;
     auto const numKvHeads = numKvAttentionHeads * tpSize;
     auto const attnHeadSize = attentionHeadSize;
@@ -74,6 +78,12 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
         case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
         case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
         case ModuleType::kMLP_GATE_UP: modules.emplace_back(t, hidden, 2 * mlpHidden, false, true, -1, 0); break;
+        // Mamba modules: in_proj (hidden -> d_in_proj), out_proj (d_inner -> hidden)
+        case ModuleType::kMAMBA_IN_PROJ: modules.emplace_back(t, hidden, mambaInProj, false, true, -1, 0); break;
+        case ModuleType::kMAMBA_OUT_PROJ: modules.emplace_back(t, mambaInner, hidden, false, true, 1, -1); break;
+        // MoE latent projections: up expands to moe_hidden, down contracts back
+        case ModuleType::kMOE_LATENT_UP: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
+        case ModuleType::kMOE_LATENT_DOWN: modules.emplace_back(t, mlpHidden, hidden, false, true, 1, -1); break;
         case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
         }
     }

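For reference, the new cases above give each module type the following (input dim, output dim) pair. The Python snippet restates those shapes for a hypothetical configuration (the dictionary and the sizes are illustrative only, not part of the API):

    hidden, mlp_hidden = 4096, 14336          # hypothetical base-model sizes
    mamba_in_proj, mamba_inner = 8192, 4096   # hypothetical Mamba mixer sizes

    # (in_dim, out_dim) per new LoRA module type, as wired up in createLoraModules above
    new_module_dims = {
        "mamba_in_proj": (hidden, mamba_in_proj),   # hidden -> d_in_proj
        "mamba_out_proj": (mamba_inner, hidden),    # d_inner -> hidden
        "moe_latent_up": (hidden, mlp_hidden),      # hidden -> latent
        "moe_latent_down": (mlp_hidden, hidden),    # latent -> hidden
    }
    for name, (d_in, d_out) in new_module_dims.items():
        print(f"{name}: {d_in} -> {d_out}")
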
cpp/tensorrt_llm/runtime/loraUtils.cpp

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ void loraValidateRequestTensors(std::optional<std::uint64_t> const& optTaskId,
         ? config
         : ITensor::view(config, ITensor::makeShape({config->getShape().d[1], config->getShape().d[2]}));
 
-    SizeType32 nbModelLayers = modelConfig.getNbAttentionLayers();
+    SizeType32 nbModelLayers = modelConfig.getNbLoraLayers();
    TLLM_CHECK_WITH_INFO(weights->getDataType() == modelConfig.getDataType(),
        "Expected lora weights to be the same data type as base model");

examples/llm-api/quickstart_advanced.py

Lines changed: 30 additions & 3 deletions
@@ -3,11 +3,13 @@
 import time
 
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
                                  CudaGraphConfig, DraftTargetDecodingConfig,
                                  Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  TorchCompileConfig)
+from tensorrt_llm.lora_helper import LoraConfig
 
 example_prompts = [
     "Hello, my name is",
@@ -198,6 +200,18 @@ def add_llm_args(parser):
     parser.add_argument('--relaxed_topk', type=int, default=1)
     parser.add_argument('--relaxed_delta', type=float, default=0.)
 
+    # LoRA
+    parser.add_argument('--lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to LoRA adapter directory.')
+    parser.add_argument(
+        '--max_lora_rank',
+        type=int,
+        default=None,
+        help='Maximum LoRA rank. If not specified, inferred from adapter config.'
+    )
+
     # HF
     parser.add_argument('--trust_remote_code',
                         default=False,
@@ -292,6 +306,18 @@ def setup_llm(args, **kwargs):
         batching_wait_iters=args.attention_dp_batching_wait_iters,
     )
 
+    lora_config = None
+    lora_request = None
+    if args.lora_dir:
+        max_lora_rank = args.max_lora_rank if args.max_lora_rank is not None else 64
+        lora_config = LoraConfig(lora_dir=[args.lora_dir],
+                                 max_lora_rank=max_lora_rank)
+        lora_request = LoRARequest(
+            lora_name="lora_adapter",
+            lora_int_id=0,  # First adapter ID
+            lora_path=args.lora_dir,
+        )
+
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
@@ -327,6 +353,7 @@ def setup_llm(args, **kwargs):
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width,
         orchestrator_type=args.orchestrator_type,
+        lora_config=lora_config,
         **kwargs)
 
     use_beam_search = args.max_beam_width > 1
@@ -352,14 +379,14 @@
         use_beam_search=use_beam_search,
         additional_model_outputs=args.additional_model_outputs,
     )
-    return llm, sampling_params
+    return llm, sampling_params, lora_request
 
 
 def main():
     args = parse_arguments()
     prompts = args.prompt if args.prompt else example_prompts
 
-    llm, sampling_params = setup_llm(args)
+    llm, sampling_params, lora_request = setup_llm(args)
     new_prompts = []
     if args.apply_chat_template:
         for prompt in prompts:
@@ -369,7 +396,7 @@ def main():
                 tokenize=False,
                 add_generation_prompt=True))
         prompts = new_prompts
-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
 
     for i, output in enumerate(outputs):
         prompt = output.prompt

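With these flags in place, the example can be driven from the command line, e.g. python examples/llm-api/quickstart_advanced.py --model_dir <nemotron-h checkpoint> --lora_dir <adapter dir>. The equivalent programmatic flow, condensed from the diff above into a minimal sketch (paths are placeholders):

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.executor.request import LoRARequest
    from tensorrt_llm.lora_helper import LoraConfig

    lora_dir = "/path/to/nemotron-h-lora"   # placeholder adapter directory

    llm = LLM(model="/path/to/nemotron-h",  # placeholder base checkpoint
              backend='pytorch',
              lora_config=LoraConfig(lora_dir=[lora_dir], max_lora_rank=64))
    lora_request = LoRARequest(lora_name="lora_adapter",
                               lora_int_id=0,
                               lora_path=lora_dir)

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32),
                           lora_request=lora_request)
    print(outputs[0].outputs[0].text)
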
tensorrt_llm/_torch/model_config.py

Lines changed: 6 additions & 0 deletions
@@ -686,6 +686,12 @@ def get_bindings_model_config(self,
         num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
         model_config_cpp.set_num_kv_heads(num_kv_heads)
 
+        # For hybrid models (e.g., Nemotron-H with Mamba + Attention), LoRA can be applied
+        # to non-attention layers (e.g., Mamba in_proj/out_proj). Set num_lora_layers to
+        # total layers so the C++ LoRA validation accepts all layer indices.
+        if is_nemotron_hybrid(self.pretrained_config):
+            model_config_cpp.set_num_lora_layers(num_layers)
+
         mlp_hidden_size = None
         if self.pretrained_config.intermediate_size is not None:
             mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size

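The reason for using the total layer count: in a hybrid stack only some layers are attention layers, but Mamba layers can still carry LoRA weights, so sizing validation by attention layers alone would reject valid adapters. A toy illustration with made-up layer counts (not taken from any real Nemotron-H config):

    # Hypothetical hybrid layer pattern: mostly Mamba, a few attention layers.
    layer_types = ["mamba"] * 20 + ["attention"] * 4 + ["mamba"] * 20

    num_attention_layers = sum(t == "attention" for t in layer_types)  # 4
    num_lora_layers = len(layer_types)                                 # 44

    # Validating adapter layer indices against 4 would reject LoRA weights on
    # Mamba layers; validating against 44 accepts every layer in the model.
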
tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 21 additions & 1 deletion
@@ -43,6 +43,7 @@
     ConsumableWeightsDict
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
@@ -852,8 +853,11 @@ def __init__(
         fuse_routing_kernel: bool = True,
         apply_routing: bool = False,
         moe_backend: str = 'CUTLASS',
+        lora_config: Optional[LoraConfig] = None,
     ):
         super().__init__()
+        self.hidden_size = hidden_size
+        self.num_experts = num_experts
         self.weight = nn.Parameter(torch.empty((num_experts, hidden_size),
                                                dtype=dtype),
                                    requires_grad=False)
@@ -877,11 +881,27 @@ def __init__(
             routed_scaling_factor=routed_scaling_factor,
             is_fused=fuse_routing_kernel)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # LoRA for gate (router) - only create when LoRA is configured
+        from ..peft.lora.layer import LoraModuleType
+        self.gate_lora = (LoraLayer([LoraModuleType.MOE_ROUTER], [num_experts])
+                          if lora_config is not None else None)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        lora_params: Optional[dict] = None,
+        layer_idx: Optional[int] = None,
+    ) -> torch.Tensor:
         logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states,
                                                       self.weight.t(),
                                                       bias=None,
                                                       out_dtype=torch.float32)
+        # Apply LoRA to gate (if LoRA is configured and weights are loaded)
+        if self.gate_lora is not None and bool(
+                lora_params) and layer_idx is not None:
+            lora_output = self.gate_lora(hidden_states, lora_params, layer_idx)
+            if lora_output is not None:
+                logits = logits + lora_output.to(logits.dtype)
         return logits
 
     def load_weights(self, weights: List[Dict]):