@@ -3,11 +3,13 @@
 import time
 
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
                                  CudaGraphConfig, DraftTargetDecodingConfig,
                                  Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  TorchCompileConfig)
+from tensorrt_llm.lora_helper import LoraConfig
 
 example_prompts = [
     "Hello, my name is",
@@ -198,6 +200,18 @@ def add_llm_args(parser):
     parser.add_argument('--relaxed_topk', type=int, default=1)
     parser.add_argument('--relaxed_delta', type=float, default=0.)
 
+    # LoRA
+    parser.add_argument('--lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to LoRA adapter directory.')
+    parser.add_argument(
+        '--max_lora_rank',
+        type=int,
+        default=None,
+        help='Maximum LoRA rank. If not specified, inferred from adapter config.'
+    )
+
     # HF
     parser.add_argument('--trust_remote_code',
                         default=False,
@@ -292,6 +306,18 @@ def setup_llm(args, **kwargs):
         batching_wait_iters=args.attention_dp_batching_wait_iters,
     )
 
+    lora_config = None
+    lora_request = None
+    if args.lora_dir:
+        max_lora_rank = args.max_lora_rank if args.max_lora_rank is not None else 64
+        lora_config = LoraConfig(lora_dir=[args.lora_dir],
+                                 max_lora_rank=max_lora_rank)
+        lora_request = LoRARequest(
+            lora_name="lora_adapter",
+            lora_int_id=0,  # First adapter ID
+            lora_path=args.lora_dir,
+        )
+
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
@@ -327,6 +353,7 @@ def setup_llm(args, **kwargs):
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width,
         orchestrator_type=args.orchestrator_type,
+        lora_config=lora_config,
         **kwargs)
 
     use_beam_search = args.max_beam_width > 1
@@ -352,14 +379,14 @@ def setup_llm(args, **kwargs):
         use_beam_search=use_beam_search,
         additional_model_outputs=args.additional_model_outputs,
     )
-    return llm, sampling_params
+    return llm, sampling_params, lora_request
 
 
 def main():
     args = parse_arguments()
     prompts = args.prompt if args.prompt else example_prompts
 
-    llm, sampling_params = setup_llm(args)
+    llm, sampling_params, lora_request = setup_llm(args)
     new_prompts = []
     if args.apply_chat_template:
         for prompt in prompts:
@@ -369,7 +396,7 @@ def main():
                     tokenize=False,
                     add_generation_prompt=True))
         prompts = new_prompts
-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
 
     for i, output in enumerate(outputs):
         prompt = output.prompt
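
For reference, a minimal standalone sketch of the pattern this diff wires up: `lora_config` configures LoRA support when the `LLM` is constructed, while `lora_request` selects the adapter for a particular `generate()` call. The imports and call signatures below mirror the diff itself; the model path, adapter path, prompt, and token budget are placeholders, not values from this PR.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.executor.request import LoRARequest
    from tensorrt_llm.lora_helper import LoraConfig

    lora_dir = "./my_lora_adapter"  # placeholder adapter directory

    # Engine-level configuration: registers the adapter directory and the
    # maximum rank the engine should support (the diff falls back to 64).
    lora_config = LoraConfig(lora_dir=[lora_dir], max_lora_rank=64)
    llm = LLM(model="./base-model",  # placeholder model path
              backend='pytorch',
              lora_config=lora_config)

    # Per-request selection: applies the named adapter to this generate call.
    lora_request = LoRARequest(lora_name="lora_adapter",
                               lora_int_id=0,
                               lora_path=lora_dir)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32),
                           lora_request=lora_request)
    print(outputs[0].outputs[0].text)

Through the modified script, the rough equivalent would be `python <script>.py --model_dir <model> --lora_dir ./my_lora_adapter`, with `--max_lora_rank` optional.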