Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data)
int parseNpyHeader(FILE*& f_ptr, uint32_t header_len, nvinfer1::DataType& type, std::vector<size_t>& shapeVec)
{
char* header_c = (char*) malloc(header_len * sizeof(char));
TLLM_CHECK_WITH_INFO(header_c != nullptr, "Failed to allocate memory for npy header");
size_t n_elems = fread((void*) header_c, sizeof(char), header_len, f_ptr);
if (n_elems != header_len)
{
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def create_mpi_comm_session(
logger_debug(
f"Using RemoteMpiPoolSessionClient to bind to external MPI processes at {get_spawn_proxy_process_ipc_addr_env()}\n",
"yellow")
get_spawn_proxy_process_ipc_hmac_key_env()
hmac_key = get_spawn_proxy_process_ipc_hmac_key_env()
return RemoteMpiCommSessionClient(
addr=get_spawn_proxy_process_ipc_addr_env())
addr=get_spawn_proxy_process_ipc_addr_env(), hmac_key=hmac_key)
else:
logger_debug(
f"Using MpiCommSession to bind to external MPI processes\n",
Expand Down
10 changes: 7 additions & 3 deletions tensorrt_llm/llmapi/mgmn_leader_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from tensorrt_llm._utils import global_mpi_rank, mpi_world_size
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env
from tensorrt_llm.executor.utils import (
get_spawn_proxy_process_ipc_addr_env,
get_spawn_proxy_process_ipc_hmac_key_env)
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer
from tensorrt_llm.llmapi.utils import logger_debug

Expand All @@ -23,6 +25,7 @@ def launch_server_main(sub_comm=None):
comm=sub_comm,
n_workers=num_ranks,
addr=get_spawn_proxy_process_ipc_addr_env(),
hmac_key=get_spawn_proxy_process_ipc_hmac_key_env(),
is_comm=True)
logger_debug(
f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}")
Expand All @@ -32,8 +35,9 @@ def launch_server_main(sub_comm=None):


def stop_server_main():
queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), None),
use_hmac_encryption=False,
hmac_key = get_spawn_proxy_process_ipc_hmac_key_env()
queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), hmac_key),
use_hmac_encryption=bool(hmac_key),
is_server=False,
socket_type=zmq.PAIR)

Expand Down
31 changes: 29 additions & 2 deletions tensorrt_llm/llmapi/rlhf_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import base64
import pickle # nosec B403
from typing import Optional

import torch

from tensorrt_llm import serialization
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
Expand Down Expand Up @@ -62,8 +62,35 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
serialized_handles = ipc_handles[device_uuid]
if isinstance(serialized_handles, str):
# Data is base64-encoded pickled bytes - deserialize it
# using restricted unpickler from tensorrt_llm.serialization
logger.info("Deserializing base64-encoded weight handles")
all_handles = pickle.loads(base64.b64decode(serialized_handles)) # nosec B301
decoded_data = base64.b64decode(serialized_handles)
# Allow basic builtins and all torch modules
approved_imports = {
"builtins": [
"list",
"tuple",
"str",
"int",
"float",
"bool",
"bytes",
"dict",
"NoneType",
"type",
],
}
all_handles = serialization.loads(
decoded_data,
approved_imports=approved_imports,
approved_module_patterns=[r"^torch.*"],
)

# Verify the result is a list as expected
if not isinstance(all_handles, list):
raise ValueError(
f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
)
else:
# Data is already in the correct format (backward compatibility)
all_handles = serialized_handles
Expand Down
41 changes: 29 additions & 12 deletions tensorrt_llm/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pickle is not secure, but this whole file is a wrapper to make it
# possible to mitigate the primary risk of code injection via pickle.
import pickle # nosec B403
import re
from functools import partial

# This is an example class (white list) to showcase how to guard serialization with approved classes.
Expand Down Expand Up @@ -126,19 +127,31 @@ def register_approved_class(obj):

class Unpickler(pickle.Unpickler):
    """Restricted unpickler that only resolves explicitly approved globals.

    A global reference found in the pickle stream is resolved only when
    either:

    * ``name`` is listed under ``module`` in ``approved_imports``, or
    * ``module`` matches (via ``re.match``) one of the regular expressions in
      ``approved_module_patterns``.

    Anything else raises ``ValueError``.

    Args:
        *args: Positional arguments forwarded to ``pickle.Unpickler``.
        approved_imports: Mapping of module name to a collection of names
            that may be imported from that module. ``None`` means no exact
            matches are approved.
        approved_module_patterns: Iterable of regex patterns tested against
            the module name. ``None`` means no patterns are approved.
        **kwargs: Keyword arguments forwarded to ``pickle.Unpickler``.
    """

    def __init__(self,
                 *args,
                 approved_imports=None,
                 approved_module_patterns=None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # A None sentinel avoids sharing one mutable default dict between all
        # instances; a dict supplied by the caller is kept by reference.
        self.approved_imports = ({} if approved_imports is None else
                                 approved_imports)
        self.approved_module_patterns = approved_module_patterns or []

    # only import approved classes, this is the security boundary.
    def find_class(self, module, name):
        """Resolve ``module.name`` if approved, otherwise raise ``ValueError``."""
        # Check exact match in approved_imports
        if name in self.approved_imports.get(module, []):
            return super().find_class(module, name)

        # Check regex pattern match in approved_module_patterns.
        # NOTE: a pattern match approves *every* name in the matching module,
        # and re.match only anchors at the start of the string, so patterns
        # should be written as tightly as possible.
        if any(
                re.match(pattern, module)
                for pattern in self.approved_module_patterns):
            return super().find_class(module, name)

        # If this is triggered when it shouldn't be, then the module
        # and class should be added to the approved_imports. If the class
        # is being used as part of a routine scenario, then it should be added
        # to the appropriate base classes above.
        raise ValueError(f"Import {module} | {name} is not allowed")


# these are taken from the pickle module to allow for this to be a drop in replacement
Expand All @@ -156,13 +169,15 @@ def load(file,
encoding="ASCII",
errors="strict",
buffers=None,
approved_imports={}):
approved_imports={},
approved_module_patterns=None):
return Unpickler(file,
fix_imports=fix_imports,
buffers=buffers,
encoding=encoding,
errors=errors,
approved_imports=approved_imports).load()
approved_imports=approved_imports,
approved_module_patterns=approved_module_patterns).load()


def loads(s,
Expand All @@ -172,7 +187,8 @@ def loads(s,
encoding="ASCII",
errors="strict",
buffers=None,
approved_imports={}):
approved_imports={},
approved_module_patterns=None):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
Expand All @@ -181,4 +197,5 @@ def loads(s,
buffers=buffers,
encoding=encoding,
errors=errors,
approved_imports=approved_imports).load()
approved_imports=approved_imports,
approved_module_patterns=approved_module_patterns).load()
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import pickle
from typing import Callable, List, Optional

import pytest
Expand Down Expand Up @@ -71,6 +73,40 @@ def get_weight_ipc_handles(

return ret

def get_weight_ipc_handles_serialized(
    self,
    cuda_device: Optional[List[int]] = None,
    weight_filter: Optional[Callable[[str], bool]] = None,
):
    """
    Build per-device weight IPC handles and return them as base64 text.

    Args:
        cuda_device: CUDA device indices to collect weights from; when None,
            every index reported by torch.cuda.device_count() is used
        weight_filter: Optional predicate on the weight name; weights for
            which it returns False are left out
    Returns:
        Dictionary keyed by self.device_uuid[device] whose values are the
        base64-encoded (UTF-8 decoded) pickle of a list of (name, handle) pairs
    """
    if cuda_device is None:
        devices = list(range(torch.cuda.device_count()))
    else:
        devices = cuda_device

    serialized_by_uuid = {}
    for dev in devices:
        # The filter runs on the name before reduce_tensor is called, so
        # excluded weights are never reduced.
        named_handles = [
            (weight_name, reduce_tensor(tensor))
            for weight_name, tensor in self.all_weights[dev]
            if weight_filter is None or weight_filter(weight_name)
        ]

        # Pickle the list, then base64-encode it into a str
        pickled = pickle.dumps(named_handles)
        serialized_by_uuid[self.device_uuid[dev]] = base64.b64encode(pickled).decode("utf-8")

    return serialized_by_uuid

def generate_batch_incremental(
self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
):
Expand Down Expand Up @@ -153,11 +189,13 @@ def run_generate(llm, hf_model, prompts, sampling_params):
return llm_logits, ref_logits


@pytest.mark.parametrize("use_serialized_handles", [True, False])
@pytest.mark.parametrize(
"model_dir",
["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
)
def test_llm_update_weights(model_dir):
def test_llm_update_weights(model_dir, use_serialized_handles):
"""Test LLM update_weights with both serialized and direct IPC handle formats."""
model_dir = str(llm_models_root() / model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)

Expand All @@ -182,7 +220,11 @@ def test_llm_update_weights(model_dir):

sampling_params = SamplingParams(temperature=0, return_generation_logits=True)

ipc_handles = hf_model.get_weight_ipc_handles([0])
# Get IPC handles in either serialized or direct format
if use_serialized_handles:
ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
else:
ipc_handles = hf_model.get_weight_ipc_handles([0])

llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the update weights
Expand Down
4 changes: 3 additions & 1 deletion tests/unittest/llmapi/_run_mpi_comm_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import click

from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_hmac_key_env
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
from tensorrt_llm.llmapi.utils import print_colored

Expand All @@ -15,8 +16,9 @@ def main(task_type: Literal["submit", "submit_sync"]):
tasks = [0]
assert os.environ[
'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
hmac_key = get_spawn_proxy_process_ipc_hmac_key_env()
client = RemoteMpiCommSessionClient(
os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'])
os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'], hmac_key=hmac_key)
for task in tasks:
if task_type == "submit":
client.submit(print_colored, f"{task}\n", "green")
Expand Down
7 changes: 5 additions & 2 deletions tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import click

from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.executor.utils import (
LlmLauncherEnvs, get_spawn_proxy_process_ipc_hmac_key_env)
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
from tensorrt_llm.llmapi.utils import print_colored

Expand All @@ -13,8 +14,10 @@ def run_task(task_type: Literal["submit", "submit_sync"]):
assert os.environ[
LlmLauncherEnvs.
TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
hmac_key = get_spawn_proxy_process_ipc_hmac_key_env()
client = RemoteMpiCommSessionClient(
os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR])
os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR],
hmac_key=hmac_key)

for task in tasks:
if task_type == "submit":
Expand Down