|
| 1 | +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | + |
| 3 | +"""Lazy-initialized symmetric memory manager for inference. |
| 4 | +
|
| 5 | +Provides a registry of SymmetricMemoryBuffer instances keyed by a |
| 6 | +user-supplied identifier (e.g. "tp", "ep"). Buffers are created on first |
| 7 | +access so that callers never need to worry about initialization ordering |
| 8 | +relative to the inference context. |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import operator |
| 14 | +from functools import reduce |
| 15 | +from typing import Optional |
| 16 | + |
| 17 | +import torch |
| 18 | + |
| 19 | +try: |
| 20 | + import torch.distributed._symmetric_memory as symm_mem |
| 21 | + |
| 22 | + HAVE_TORCH_SYMM_MEM = True |
| 23 | +except ImportError: |
| 24 | + HAVE_TORCH_SYMM_MEM = False |
| 25 | + |
| 26 | +try: |
| 27 | + import triton # pylint: disable=unused-import |
| 28 | + |
| 29 | + HAVE_TRITON = True |
| 30 | +except ImportError: |
| 31 | + HAVE_TRITON = False |
| 32 | + |
| 33 | + |
| 34 | +class SymmetricMemoryBuffer: |
| 35 | + """ |
| 36 | + symmetric memory buffer used in inference. |
| 37 | + This buffer is used by mcore-inference's low-latency |
| 38 | + NVLS all-gather and reduce-scatter collectives. |
| 39 | + """ |
| 40 | + |
| 41 | + def __init__(self, size_in_mb, process_group): |
| 42 | + if not HAVE_TORCH_SYMM_MEM or not HAVE_TRITON: |
| 43 | + # This should be hit if the user is running an older |
| 44 | + # version of torch, or if they do not have triton |
| 45 | + # installed. |
| 46 | + self.symm_buffer = None |
| 47 | + self.symm_mem_hdl = None |
| 48 | + else: |
| 49 | + numel = int(size_in_mb * 1024 * 1024) # size in bytes |
| 50 | + try: |
| 51 | + symm_mem.enable_symm_mem_for_group(process_group.group_name) |
| 52 | + self.symm_buffer = symm_mem.empty(numel, dtype=torch.uint8, device='cuda') |
| 53 | + self.symm_mem_hdl = symm_mem.rendezvous(self.symm_buffer, process_group) |
| 54 | + except RuntimeError as e: |
| 55 | + # If symmetric memory initialization fails, set buffer and handle to None |
| 56 | + # This should happen if the process group is not contained within NVlink |
| 57 | + self.symm_buffer = None |
| 58 | + self.symm_mem_hdl = None |
| 59 | + |
| 60 | + def _can_allocate(self, numel, dtype) -> bool: |
| 61 | + """ |
| 62 | + Returns whether enough symmetric memory is available |
| 63 | + for the given tensor shape and dtype. |
| 64 | + """ |
| 65 | + if self.symm_mem_hdl is None: |
| 66 | + return False |
| 67 | + size_of_dtype = torch.tensor([], dtype=dtype).element_size() |
| 68 | + required_len = numel * size_of_dtype |
| 69 | + return required_len <= self.symm_buffer.numel() |
| 70 | + |
| 71 | + def _allocate(self, numel, dtype) -> torch.Tensor: |
| 72 | + """ |
| 73 | + Allocates a sub-tensor from the self.symm_buffer for the given numel and dtype""" |
| 74 | + required_bytes = numel * torch.tensor([], dtype=dtype).element_size() |
| 75 | + return self.symm_buffer[0:required_bytes].view(dtype).view(numel) |
| 76 | + |
| 77 | + def maybe_get_tensors(self, tensor_specs, alignment=16): |
| 78 | + """ |
| 79 | + Pack multiple tensors contiguously in the symmetric buffer with alignment. |
| 80 | +
|
| 81 | + Each tensor's starting offset is aligned to `alignment` bytes (default 16 |
| 82 | + for 128-bit multimem access). |
| 83 | +
|
| 84 | + Args: |
| 85 | + tensor_specs: list of (numel, dtype) tuples. |
| 86 | + alignment: byte alignment for each tensor's start offset (default 16). |
| 87 | +
|
| 88 | + Returns: |
| 89 | + {"handle": None, "tensors": None} if unavailable or insufficient space. |
| 90 | + {"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]} |
| 91 | + on success, where raw_byte_view is a uint8 slice of the buffer. |
| 92 | + """ |
| 93 | + _NONE_RESULT = {"handle": None, "tensors": None} |
| 94 | + if self.symm_mem_hdl is None: |
| 95 | + return _NONE_RESULT |
| 96 | + |
| 97 | + # Compute aligned byte sizes and running offsets |
| 98 | + slices = [] |
| 99 | + current_offset = 0 |
| 100 | + for numel, dtype in tensor_specs: |
| 101 | + nbytes = numel * torch.tensor([], dtype=dtype).element_size() |
| 102 | + aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment |
| 103 | + slices.append((current_offset, nbytes)) |
| 104 | + current_offset += aligned_nbytes |
| 105 | + |
| 106 | + if not self._can_allocate(current_offset, torch.uint8): |
| 107 | + return _NONE_RESULT |
| 108 | + |
| 109 | + tensors = [] |
| 110 | + for offset, nbytes in slices: |
| 111 | + tensors.append((self.symm_buffer[offset : offset + nbytes], offset)) |
| 112 | + |
| 113 | + return {"handle": self.symm_mem_hdl, "tensors": tensors} |
| 114 | + |
| 115 | + def maybe_get_tensor(self, tensor_shape, dtype): |
| 116 | + """ |
| 117 | + Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. |
| 118 | + If enough symmetric memory is not available, returns None. |
| 119 | + """ |
| 120 | + if self.symm_mem_hdl is None: |
| 121 | + return {"tensor": None, "handle": None} |
| 122 | + numel = reduce(operator.mul, tensor_shape, 1) |
| 123 | + if not self._can_allocate(numel, dtype): |
| 124 | + return {"tensor": None, "handle": None} |
| 125 | + return { |
| 126 | + "tensor": self._allocate(numel, dtype).view(*tensor_shape), |
| 127 | + "handle": self.symm_mem_hdl, |
| 128 | + } |
| 129 | + |
| 130 | + |
| 131 | +class SymmetricMemoryManager: |
| 132 | + """Registry of lazily-initialized symmetric memory buffers. |
| 133 | +
|
| 134 | + Usage:: |
| 135 | +
|
| 136 | + buf = SymmetricMemoryManager.get_buffer("tp", process_group=tp_group) |
| 137 | + result = buf.maybe_get_tensor(shape, dtype) |
| 138 | + """ |
| 139 | + |
| 140 | + _buffers: dict[str, SymmetricMemoryBuffer] = {} |
| 141 | + _default_size_mb: int = 256 |
| 142 | + |
| 143 | + @classmethod |
| 144 | + def get_buffer( |
| 145 | + cls, |
| 146 | + key: str, |
| 147 | + process_group: Optional[torch.distributed.ProcessGroup] = None, |
| 148 | + size_mb: Optional[int] = None, |
| 149 | + ) -> SymmetricMemoryBuffer: |
| 150 | + """Return the buffer for *key*, creating it on first call. |
| 151 | +
|
| 152 | + Args: |
| 153 | + key: Unique identifier (e.g. "tp", "ep"). |
| 154 | + process_group: Required on the first call for a given key. |
| 155 | + Subsequent calls may omit it. |
| 156 | + size_mb: Buffer size in MiB (default 256). |
| 157 | + """ |
| 158 | + if key not in cls._buffers: |
| 159 | + assert ( |
| 160 | + process_group is not None |
| 161 | + ), f"SymmetricMemoryManager: process_group is required on first access for key='{key}'" |
| 162 | + cls._buffers[key] = SymmetricMemoryBuffer( |
| 163 | + size_in_mb=size_mb or cls._default_size_mb, process_group=process_group |
| 164 | + ) |
| 165 | + return cls._buffers[key] |
| 166 | + |
| 167 | + @classmethod |
| 168 | + def destroy(cls, key: Optional[str] = None) -> None: |
| 169 | + """Destroy one or all buffers. |
| 170 | +
|
| 171 | + Args: |
| 172 | + key: If provided, destroy only that buffer. Otherwise destroy all. |
| 173 | + """ |
| 174 | + if key is not None: |
| 175 | + cls._buffers.pop(key, None) |
| 176 | + else: |
| 177 | + cls._buffers.clear() |
| 178 | + |
| 179 | + @classmethod |
| 180 | + def is_initialized(cls, key: str) -> bool: |
| 181 | + """Check whether a buffer has been created for *key*.""" |
| 182 | + return key in cls._buffers |
0 commit comments