Commit 6428102

[None][fix] Enable LoRA in EAGLE3 speculative decoding (NVIDIA#13005)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent 098495e commit 6428102

5 files changed: 119 additions & 8 deletions
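
For context, the feature this commit enables can be exercised end to end through the LLM API. Below is a minimal usage sketch (not part of the commit) that combines a LoRA adapter with EAGLE3 speculative decoding, mirroring the test added in this change; the model and adapter paths are placeholders.

# Minimal usage sketch (not part of this commit): LoRA adapter + EAGLE3 speculative
# decoding via the LLM API, mirroring the test added below. Model and adapter paths
# are placeholders.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import Eagle3DecodingConfig, KvCacheConfig
from tensorrt_llm.lora_helper import LoraConfig

llm = LLM(
    model="path/to/Llama-3.1-8B-Instruct",  # placeholder target model
    lora_config=LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False, max_tokens=8192),
    speculative_config=Eagle3DecodingConfig(
        max_draft_len=4,
        speculative_model="path/to/EAGLE3-LLaMA3.1-Instruct-8B",  # placeholder draft model
        eagle3_one_model=True,
    ),
)
prompts = ["The capital of France is"]
lora_requests = [LoRARequest("my-adapter", 1, "path/to/lora-adapter")] * len(prompts)
outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=20, temperature=0),
                       lora_request=lora_requests)
llm.shutdown()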

tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py

Lines changed: 4 additions & 1 deletion
@@ -71,10 +71,13 @@ def get_or_assign_task(self, task_id: int) -> tuple[int, Optional[int]]:
             self.task2slot[task_id] = evicted_slot
         return self.task2slot[task_id], evicted_task
 
-    def remove_evicted_slots_in_cpp(self, peft_cache_manager: PeftCacheManager):
+    def remove_evicted_slots_in_cpp(self, peft_cache_manager: Optional[PeftCacheManager]):
         """
         Validate slots by removing tasks that are not cached in PeftCacheManager.
         """
+        if peft_cache_manager is None:
+            return
+
         for task_id in self.slot2task:
             if task_id is not None:
                 if not peft_cache_manager.is_task_cached_device(task_id):

tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py

Lines changed: 8 additions & 5 deletions
@@ -213,14 +213,17 @@ def update_sorted_indices(self, slot_ids: List[int], tokens_per_seq: int = 1):
         self.sorted_ids[:num_tokens].copy_(sorted_ids_host, non_blocking=True)
 
     def update_weight_pointers(
-        self, peft_table: Dict[int, List], slot_to_task_mapping: tuple[Optional[int], ...]
+        self,
+        peft_table: Optional[Dict[int, List]],
+        slot_to_task_mapping: tuple[Optional[int], ...],
     ):
         """
         Update weight pointers from PEFT cache manager.
 
         Args:
             peft_table: PEFT table from cache manager containing weight pointers, map task id to list of layer
-                module configs
+                module configs. Can be None when slot membership changes without any newly prepared PEFT
+                entries in the current batch.
             slot_to_task_mapping: Mapping from slot_id to task_id, tuple of None for empty slots
         """
 
@@ -241,9 +244,9 @@ def zero_out_weight_pointers(slot_id: int):
             if task_id is None:  # empty slot
                 self.slot_ranks_host[slot_id] = 0
                 zero_out_weight_pointers(slot_id)
-            elif (
-                task_id not in peft_table
-            ):  # task has not changed in the slot, retain old rank / weight pointers
+            elif peft_table is None or task_id not in peft_table:
+                # No new PEFT entry was prepared for this task in the current batch, so retain
+                # the existing rank and weight pointers for the occupied slot.
                 continue
             else:  # task might have changed in the slot, update its rank
                 task_configs = peft_table[task_id]
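
The new branch makes the per-slot update rule three-way: empty slots are zeroed, occupied slots with no freshly prepared PEFT entry keep their existing rank and pointers, and only slots whose task appears in peft_table are refreshed. A standalone restatement of that rule (illustrative only; resolve_slot_action is a hypothetical helper, not library code):

# Illustrative restatement of the per-slot rule followed by update_weight_pointers
# after this change. resolve_slot_action is a hypothetical helper, not library code.
from typing import Dict, List, Optional


def resolve_slot_action(task_id: Optional[int],
                        peft_table: Optional[Dict[int, List]]) -> str:
    if task_id is None:
        # Empty slot: rank becomes 0 and its weight pointers are zeroed.
        return "zero"
    if peft_table is None or task_id not in peft_table:
        # No newly prepared PEFT entry for this task: keep existing rank/pointers.
        return "retain"
    # A fresh entry exists: rank and weight pointers are refreshed from the table.
    return "update"


assert resolve_slot_action(None, None) == "zero"
assert resolve_slot_action(123, None) == "retain"
assert resolve_slot_action(123, {123: []}) == "update"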

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ l0_h100:
   - unittest/_torch/compilation
   - unittest/_torch/debugger
   - unittest/_torch/executor
+  - unittest/_torch/lora
   - unittest/_torch/misc
   # ------------- modules (non-MoE) ---------------
   - unittest/_torch/modules/test_mla_helix.py
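
The added entry puts the new LoRA unit tests on the H100 pre-merge list. Assuming the standard repository layout (test-db paths are relative to the tests/ directory), they can be run locally with pytest, for example:

# Hypothetical local invocation of the newly listed tests (path assumed to be
# tests/unittest/_torch/lora relative to the repository root).
import pytest

pytest.main(["tests/unittest/_torch/lora", "-q"])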
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import torch
+
+from tensorrt_llm._torch.peft.lora.adapter_slot_manager import AdapterSlotManager
+from tensorrt_llm._torch.peft.lora.cuda_graph_lora_params import CudaGraphLoraParams
+
+
+def test_cuda_graph_lora_params_handle_missing_peft_table():
+    layer_key = CudaGraphLoraParams.LoraLayerKey(layer_idx=0, module_ids=(1, 2))
+    layer_info = {layer_key: CudaGraphLoraParams.LoraLayerInfo(module_num=2, output_sizes=[16, 32])}
+    params = CudaGraphLoraParams(
+        max_batch_size=2, max_lora_size=2, max_rank=8, layer_info=layer_info
+    )
+    layer_params = params.layer_params[layer_key]
+
+    layer_params.h_b_ptrs[:, 0] = torch.tensor([11, 22], dtype=torch.int64)
+    layer_params.h_b_prime_ptrs[:, 0] = torch.tensor([33, 44], dtype=torch.int64)
+    layer_params.h_b_ptrs[:, 1] = torch.tensor([55, 66], dtype=torch.int64)
+    layer_params.h_b_prime_ptrs[:, 1] = torch.tensor([77, 88], dtype=torch.int64)
+    params.slot_ranks_host[:] = torch.tensor([4, 7], dtype=torch.int32)
+
+    params.update_weight_pointers(None, (123, None))
+
+    assert params.slot_ranks_host.tolist() == [4, 0]
+    assert layer_params.h_b_ptrs[:, 0].tolist() == [11, 22]
+    assert layer_params.h_b_prime_ptrs[:, 0].tolist() == [33, 44]
+    assert layer_params.h_b_ptrs[:, 1].tolist() == [0, 0]
+    assert layer_params.h_b_prime_ptrs[:, 1].tolist() == [0, 0]
+
+
+def test_adapter_slot_manager_handles_missing_peft_cache_manager():
+    manager = AdapterSlotManager(max_num_adapters=2)
+    manager.slot2task[0] = 123
+    manager.task2slot[123] = 0
+
+    manager.remove_evicted_slots_in_cpp(None)
+
+    assert manager.get_slot_to_task_mapping() == (123, None)
+    assert manager.task2slot[123] == 0

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 68 additions & 2 deletions
@@ -14,8 +14,10 @@
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
 from tensorrt_llm._torch.metadata import KVCacheParams
+from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                  KvCacheConfig)
+from tensorrt_llm.lora_helper import LoraConfig
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 

@@ -756,8 +758,9 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
 
     prompts = [
-        "The capital of France is", "The president of the United States is",
-        "The future of AI is"
+        "The capital of France is",
+        "The president of the United States is",
+        "The future of AI is",
     ]
 
     sampling_params = SamplingParams(max_tokens=2048, temperature=0)
@@ -815,5 +818,68 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
     llm_spec.shutdown()
 
 
+@pytest.mark.parametrize("use_cuda_graph", [True, False])
+def test_eagle3_lora(use_cuda_graph: bool):
+    """Test LoRA with 3 requests and max_batch_size=4.
+
+    This test verifies that when using LoRA modules,
+    the system properly applies the LoRA configurations.
+    """
+    attn_backend = "TRTLLM"
+    enable_block_reuse = False
+    use_one_model = True
+    enable_chunked_prefill = False
+
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    hf_lora_dir = f"{models_path}/llama-models/luotuo-lora-7b-0.1"
+
+    # Test with 3 requests and max_batch_size=4 to trigger padding
+    max_batch_size = 4
+    max_draft_len = 4
+    kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                    max_tokens=8192)
+    cuda_graph_config = CudaGraphConfig(
+        batch_sizes=[1, 2, 4], enable_padding=True) if use_cuda_graph else None
+    lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2)
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend=attn_backend,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        max_seq_len=1024,
+        enable_chunked_prefill=enable_chunked_prefill,
+        lora_config=lora_config,
+    )
+
+    spec_config = Eagle3DecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model=eagle_model_dir,
+        eagle3_one_model=use_one_model,
+    )
+
+    # Create the LLM instance
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+    prompts = [
+        "The capital of France is",
+        "The president of the United States is",
+        "The future of AI is",
+    ]
+    lora_requests = [LoRARequest("luotuo", 1, hf_lora_dir)] * len(prompts)
+
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    llm_spec.generate(prompts, sampling_params, lora_request=lora_requests)
+    llm_spec.shutdown()
+
+
 if __name__ == "__main__":
     unittest.main()
