diff --git a/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp b/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp index 4c37f480bd4..6f95704455d 100644 --- a/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp +++ b/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp @@ -105,6 +105,7 @@ void parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data) int parseNpyHeader(FILE*& f_ptr, uint32_t header_len, nvinfer1::DataType& type, std::vector& shapeVec) { char* header_c = (char*) malloc(header_len * sizeof(char)); + TLLM_CHECK_WITH_INFO(header_c != nullptr, "Failed to allocate memory for npy header"); size_t n_elems = fread((void*) header_c, sizeof(char), header_len, f_ptr); if (n_elems != header_len) { diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py index e52ea481fb0..9801ee1cd23 100644 --- a/tensorrt_llm/executor/utils.py +++ b/tensorrt_llm/executor/utils.py @@ -55,9 +55,9 @@ def create_mpi_comm_session( logger_debug( f"Using RemoteMpiPoolSessionClient to bind to external MPI processes at {get_spawn_proxy_process_ipc_addr_env()}\n", "yellow") - get_spawn_proxy_process_ipc_hmac_key_env() + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() return RemoteMpiCommSessionClient( - addr=get_spawn_proxy_process_ipc_addr_env()) + addr=get_spawn_proxy_process_ipc_addr_env(), hmac_key=hmac_key) else: logger_debug( f"Using MpiCommSession to bind to external MPI processes\n", diff --git a/tensorrt_llm/llmapi/mgmn_leader_node.py b/tensorrt_llm/llmapi/mgmn_leader_node.py index 85f8561ebe4..2b1d11b0ccf 100644 --- a/tensorrt_llm/llmapi/mgmn_leader_node.py +++ b/tensorrt_llm/llmapi/mgmn_leader_node.py @@ -9,7 +9,9 @@ from tensorrt_llm._utils import global_mpi_rank, mpi_world_size from tensorrt_llm.executor.ipc import ZeroMqQueue -from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env +from tensorrt_llm.executor.utils import ( + get_spawn_proxy_process_ipc_addr_env, + get_spawn_proxy_process_ipc_hmac_key_env) from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer from tensorrt_llm.llmapi.utils import logger_debug @@ -23,6 +25,7 @@ def launch_server_main(sub_comm=None): comm=sub_comm, n_workers=num_ranks, addr=get_spawn_proxy_process_ipc_addr_env(), + hmac_key=get_spawn_proxy_process_ipc_hmac_key_env(), is_comm=True) logger_debug( f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}") @@ -32,8 +35,9 @@ def launch_server_main(sub_comm=None): def stop_server_main(): - queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), None), - use_hmac_encryption=False, + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() + queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), hmac_key), + use_hmac_encryption=bool(hmac_key), is_server=False, socket_type=zmq.PAIR) diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py index ce6eaa5b4ff..52219f99852 100644 --- a/tensorrt_llm/llmapi/rlhf_utils.py +++ b/tensorrt_llm/llmapi/rlhf_utils.py @@ -1,9 +1,9 @@ import base64 -import pickle # nosec B403 from typing import Optional import torch +from tensorrt_llm import serialization from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid @@ -62,8 +62,35 @@ def update_weights(self, ipc_handles: Optional[dict] = None): serialized_handles = ipc_handles[device_uuid] if isinstance(serialized_handles, str): # Data is base64-encoded pickled bytes - deserialize it + # using restricted unpickler from tensorrt_llm.serialization logger.info("Deserializing base64-encoded weight handles") - all_handles = pickle.loads(base64.b64decode(serialized_handles)) # nosec B301 + decoded_data = base64.b64decode(serialized_handles) + # Allow basic builtins and all torch modules + approved_imports = { + "builtins": [ + "list", + "tuple", + "str", + "int", + "float", + "bool", + "bytes", + "dict", + "NoneType", + "type", + ], + } + all_handles = serialization.loads( + decoded_data, + approved_imports=approved_imports, + approved_module_patterns=[r"^torch.*"], + ) + + # Verify the result is a list as expected + if not isinstance(all_handles, list): + raise ValueError( + f"Deserialized data must be a list, got {type(all_handles).__name__} instead" + ) else: # Data is already in the correct format (backward compatibility) all_handles = serialized_handles diff --git a/tensorrt_llm/serialization.py b/tensorrt_llm/serialization.py index 4045df36177..ac303fec9f8 100644 --- a/tensorrt_llm/serialization.py +++ b/tensorrt_llm/serialization.py @@ -2,6 +2,7 @@ # pickle is not secure, but but this whole file is a wrapper to make it # possible to mitigate the primary risk of code injection via pickle. import pickle # nosec B403 +import re from functools import partial # This is an example class (white list) to showcase how to guard serialization with approved classes. @@ -126,19 +127,31 @@ def register_approved_class(obj): class Unpickler(pickle.Unpickler): - def __init__(self, *args, approved_imports={}, **kwargs): + def __init__(self, + *args, + approved_imports={}, + approved_module_patterns=None, + **kwargs): super().__init__(*args, **kwargs) self.approved_imports = approved_imports + self.approved_module_patterns = approved_module_patterns or [] # only import approved classes, this is the security boundary. def find_class(self, module, name): - if name not in self.approved_imports.get(module, []): - # If this is triggered when it shouldn't be, then the module - # and class should be added to the approved_imports. If the class - # is being used as part of a routine scenario, then it should be added - # to the appropriate base classes above. - raise ValueError(f"Import {module} | {name} is not allowed") - return super().find_class(module, name) + # Check exact match in approved_imports + if name in self.approved_imports.get(module, []): + return super().find_class(module, name) + + # Check regex pattern match in approved_module_patterns + for pattern in self.approved_module_patterns: + if re.match(pattern, module): + return super().find_class(module, name) + + # If this is triggered when it shouldn't be, then the module + # and class should be added to the approved_imports. If the class + # is being used as part of a routine scenario, then it should be added + # to the appropriate base classes above. + raise ValueError(f"Import {module} | {name} is not allowed") # these are taken from the pickle module to allow for this to be a drop in replacement @@ -156,13 +169,15 @@ def load(file, encoding="ASCII", errors="strict", buffers=None, - approved_imports={}): + approved_imports={}, + approved_module_patterns=None): return Unpickler(file, fix_imports=fix_imports, buffers=buffers, encoding=encoding, errors=errors, - approved_imports=approved_imports).load() + approved_imports=approved_imports, + approved_module_patterns=approved_module_patterns).load() def loads(s, @@ -172,7 +187,8 @@ def loads(s, encoding="ASCII", errors="strict", buffers=None, - approved_imports={}): + approved_imports={}, + approved_module_patterns=None): if isinstance(s, str): raise TypeError("Can't load pickle from unicode string") file = io.BytesIO(s) @@ -181,4 +197,5 @@ def loads(s, buffers=buffers, encoding=encoding, errors=errors, - approved_imports=approved_imports).load() + approved_imports=approved_imports, + approved_module_patterns=approved_module_patterns).load() diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py index 96e88226122..9914913c2ff 100644 --- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py +++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py @@ -1,3 +1,5 @@ +import base64 +import pickle from typing import Callable, List, Optional import pytest @@ -71,6 +73,40 @@ def get_weight_ipc_handles( return ret + def get_weight_ipc_handles_serialized( + self, + cuda_device: Optional[List[int]] = None, + weight_filter: Optional[Callable[[str], bool]] = None, + ): + """ + Get base64-encoded serialized IPC handles for model weights. + + Args: + cuda_device: List of CUDA device indices to get weights from + weight_filter: Optional function that takes weight name and returns True if weight should be included + + Returns: + ret: Dictionary mapping device UUIDs to base64-encoded pickled handles + """ + ret = {} + device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device + + for device in device_list: + all_handles = [] + for item in self.all_weights[device]: + name, p = item + # Apply filter if provided + if weight_filter is not None and not weight_filter(name): + continue + handle = reduce_tensor(p) + all_handles.append((name, handle)) + + # Serialize with base64-encoded pickle + serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8") + ret[self.device_uuid[device]] = serialized + + return ret + def generate_batch_incremental( self, original_prompts: List[str], generated_token_ids_list: List[List[int]] ): @@ -153,11 +189,13 @@ def run_generate(llm, hf_model, prompts, sampling_params): return llm_logits, ref_logits +@pytest.mark.parametrize("use_serialized_handles", [True, False]) @pytest.mark.parametrize( "model_dir", ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"], ) -def test_llm_update_weights(model_dir): +def test_llm_update_weights(model_dir, use_serialized_handles): + """Test LLM update_weights with both serialized and direct IPC handle formats.""" model_dir = str(llm_models_root() / model_dir) kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) @@ -182,7 +220,11 @@ def test_llm_update_weights(model_dir): sampling_params = SamplingParams(temperature=0, return_generation_logits=True) - ipc_handles = hf_model.get_weight_ipc_handles([0]) + # Get IPC handles in either serialized or direct format + if use_serialized_handles: + ipc_handles = hf_model.get_weight_ipc_handles_serialized([0]) + else: + ipc_handles = hf_model.get_weight_ipc_handles([0]) llm._collective_rpc("update_weights", (ipc_handles,)) # Finalize the update weights diff --git a/tests/unittest/llmapi/_run_mpi_comm_task.py b/tests/unittest/llmapi/_run_mpi_comm_task.py index b60b7a1efdc..94ca4d38659 100644 --- a/tests/unittest/llmapi/_run_mpi_comm_task.py +++ b/tests/unittest/llmapi/_run_mpi_comm_task.py @@ -3,6 +3,7 @@ import click +from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_hmac_key_env from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient from tensorrt_llm.llmapi.utils import print_colored @@ -15,8 +16,9 @@ def main(task_type: Literal["submit", "submit_sync"]): tasks = [0] assert os.environ[ 'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set" + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() client = RemoteMpiCommSessionClient( - os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR']) + os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'], hmac_key=hmac_key) for task in tasks: if task_type == "submit": client.submit(print_colored, f"{task}\n", "green") diff --git a/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py index 5b50df94f2d..440d07149cb 100644 --- a/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py +++ b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py @@ -3,7 +3,8 @@ import click -from tensorrt_llm.executor.utils import LlmLauncherEnvs +from tensorrt_llm.executor.utils import ( + LlmLauncherEnvs, get_spawn_proxy_process_ipc_hmac_key_env) from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient from tensorrt_llm.llmapi.utils import print_colored @@ -13,8 +14,10 @@ def run_task(task_type: Literal["submit", "submit_sync"]): assert os.environ[ LlmLauncherEnvs. TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set" + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() client = RemoteMpiCommSessionClient( - os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR]) + os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR], + hmac_key=hmac_key) for task in tasks: if task_type == "submit":