Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions cpp/include/tensorrt_llm/executor/transferAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "tensorrt_llm/common/assert.h"
#include <fcntl.h>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
Expand Down Expand Up @@ -194,8 +195,57 @@ using RegisterDescs = MemoryDescs;
using SyncMessage = std::string;
using ConnectionInfoType = std::string;

/// Per-region VMM chunk info used for splitting descriptors at chunk boundaries.
///
/// Both members default to zero so that a default-constructed value never carries
/// indeterminate data. The struct is still an aggregate (C++14 and later), so
/// `VramRegionInfo{totalLen, chunkSize}` style initialization keeps working.
struct VramRegionInfo
{
    size_t totalLen{0};  ///< Total length of the region (presumably in bytes; confirm against callers)
    size_t chunkSize{0}; ///< 0 = cudaMalloc (no split), >0 = VMM chunk size
};

/// Region map: virtual base address → region info.
using VramRegionMap = std::map<uintptr_t, VramRegionInfo>;

/// Backend-agnostic VMM descriptor split utilities (no NIXL dependency).
/// Every member is a static function; only the declarations are visible here.
struct VmmDescSplitter
{
/// @brief Look up VMM chunk info for an address from a region map.
/// @param addr Address to look up.
/// @param regionMap Region map (virtual base address → region info) to search.
/// @return {chunkSize, regionBase}. Returns {0, 0} if address is not in any region.
[[nodiscard]] static std::pair<size_t, uintptr_t> lookupChunkInfo(uintptr_t addr, VramRegionMap const& regionMap);

/// @brief Split VRAM descs at chunk boundaries using a pre-built region map.
/// For non-VRAM or addresses not in the map, descs pass through unchanged.
/// @param descs Descriptors to split.
/// @param regionMap Region map describing the chunk size of each known region.
/// @return The resulting descriptor list.
[[nodiscard]] static MemoryDescs splitDescsWithRegionMap(MemoryDescs const& descs, VramRegionMap const& regionMap);

/// @brief Split paired src/dst descs using local and remote region maps.
/// src is split by localRegionMap, dst is split by remoteRegionMap.
/// The final piece size is min(srcPiece, dstPiece, remaining).
/// @param srcDescs Source descriptors (split by localRegionMap).
/// @param dstDescs Destination descriptors (split by remoteRegionMap).
/// @param localRegionMap Region map for the local side.
/// @param remoteRegionMap Region map for the remote side.
/// @return Pair of {split source descs, split destination descs}.
[[nodiscard]] static std::pair<MemoryDescs, MemoryDescs> splitTransferDescsWithRegionMaps(
MemoryDescs const& srcDescs, MemoryDescs const& dstDescs, VramRegionMap const& localRegionMap,
VramRegionMap const& remoteRegionMap);

/// @brief Split VRAM descs at VMM chunk boundaries detected via cuMemGetAddressRange.
/// For cudaMalloc memory (single allocation), descs pass through unchanged.
/// @param descs Descriptors to split.
/// @param[out] detectedChunkSize Set to the VMM chunk size if detected, 0 otherwise.
/// @return The resulting descriptor list.
[[nodiscard]] static MemoryDescs splitVmmDescs(MemoryDescs const& descs, size_t& detectedChunkSize);

/// @brief Build a VramRegionMap by probing each VRAM descriptor with cuMemGetAddressRange.
/// For each descriptor, detects whether it spans multiple VMM chunks and records {totalLen, chunkSize}.
/// @param descs VRAM memory descriptors to probe.
/// @return Region map with per-descriptor VMM info (chunkSize=0 for cudaMalloc memory).
[[nodiscard]] static VramRegionMap detectVramRegionMap(MemoryDescs const& descs);
};

/// VMM region metadata exchanged between agents for chunk boundary calculations.
///
/// All members default to zero so that a default-constructed value never carries
/// indeterminate data. The struct is still an aggregate (C++14 and later), so
/// `VramRegionMeta{baseAddr, totalLen, chunkSize}` style initialization keeps working.
struct VramRegionMeta
{
    uintptr_t baseAddr{0}; ///< Virtual base address of the region
    size_t totalLen{0};    ///< Total length of the region (presumably in bytes; confirm against callers)
    size_t chunkSize{0};   ///< 0 = cudaMalloc (no split), >0 = VMM chunk size
};

// `AgentDesc` represents the unique identifier for reading and writing to the agent.
// By accessing this identifier, the backend can establish the correct connection.
// It also carries VMM region metadata so that remote agents can split at chunk boundaries.
class AgentDesc final
{
public:
Expand All @@ -204,13 +254,31 @@ class AgentDesc final
{
}

/// Construct from a backend agent descriptor together with VMM region metadata.
/// Both arguments are taken by value and moved into the corresponding members.
AgentDesc(std::string backendAgentDesc, std::vector<VramRegionMeta> vramRegions)
: mBackendAgentDesc{std::move(backendAgentDesc)}
, mVramRegions{std::move(vramRegions)}
{
}

/// @return Reference to the stored backend agent descriptor string.
[[nodiscard]] std::string const& getBackendAgentDesc() const noexcept
{
return mBackendAgentDesc;
}

/// @return Reference to the stored list of VMM region metadata entries.
[[nodiscard]] std::vector<VramRegionMeta> const& getVramRegions() const noexcept
{
return mVramRegions;
}

/// Serialize the entire AgentDesc (backend blob + VMM regions) into an opaque string.
[[nodiscard]] std::string serialize() const;

/// Deserialize an opaque string back into an AgentDesc.
[[nodiscard]] static AgentDesc deserialize(std::string const& data);

private:
std::string mBackendAgentDesc;
std::vector<VramRegionMeta> mVramRegions;
};

// `TransferOp` is an enumeration that represents the types of transfer operations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ if(NIXL_ROOT)
# Link against all NIXL libraries
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)

# Link against CUDA
# Link against CUDA runtime (for cudaMemcpy in posix fallback)
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE CUDA::cudart)

set(NIXL_ENABLED TRUE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,21 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
auto const& desc = self.getBackendAgentDesc();
return nb::bytes(desc.data(), desc.size());
});
})
.def("serialize",
[](kvc::AgentDesc const& self)
{
auto s = self.serialize();
return nb::bytes(s.data(), s.size());
})
.def_static(
"deserialize",
[](nb::bytes data)
{
std::string str(data.c_str(), data.size());
return kvc::AgentDesc::deserialize(str);
},
nb::arg("data"));

// TransferRequest class
nb::class_<kvc::TransferRequest>(m, "TransferRequest")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,18 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)

void NixlTransferAgent::registerMemory(RegisterDescs const& descs)
{
// Split VRAM descriptors at VMM chunk boundaries so each sub-descriptor
// falls within a single cuMemCreate allocation (required by gdr_copy / cuda_ipc).
size_t detectedChunkSize = 0;
auto splitDescs = VmmDescSplitter::splitVmmDescs(descs, detectedChunkSize);

// Record per-desc VMM chunk info for use in deregisterMemory / submitTransferRequests
auto detectedRegionMap = VmmDescSplitter::detectVramRegionMap(descs);
mLocalVramRegionInfo.merge(detectedRegionMap);

// Coalesce contiguous memory regions to reduce registration overhead (disabled by default)
// Set TRTLLM_NIXL_ENABLE_COALESCE=1 to enable this optimization
auto coalescedDescs = common::getEnvNixlEnableCoalesce() ? NixlHelper::coalesceMemoryDescs(descs) : descs;
auto coalescedDescs = common::getEnvNixlEnableCoalesce() ? NixlHelper::coalesceMemoryDescs(splitDescs) : splitDescs;

nixl_status_t status;
status = mRawAgent->registerMem(NixlHelper::convertRegDlist(coalescedDescs), &mExtraParams);
Expand All @@ -609,13 +618,25 @@ void NixlTransferAgent::registerMemory(RegisterDescs const& descs)

/// Deregister memory descriptors from the underlying agent.
///
/// The descriptors are split with the locally recorded region map and, when coalescing is
/// enabled, coalesced, so that the list handed to deregisterMem matches what was registered.
/// After a successful deregistration, the entries keyed by each descriptor address are removed
/// from mLocalVramRegionInfo (VRAM descriptors only).
///
/// @param descs Descriptors to deregister.
void NixlTransferAgent::deregisterMemory(RegisterDescs const& descs)
{
    // Split using per-region registry info to match what was registered
    auto splitDescs = VmmDescSplitter::splitDescsWithRegionMap(descs, mLocalVramRegionInfo);

    // Coalesce contiguous memory regions to match what was registered (disabled by default)
    // Set TRTLLM_NIXL_ENABLE_COALESCE=1 to enable this optimization
    // splitDescs is not used again below, so it is moved instead of copied when coalescing is off.
    auto coalescedDescs
        = common::getEnvNixlEnableCoalesce() ? NixlHelper::coalesceMemoryDescs(splitDescs) : std::move(splitDescs);

    nixl_status_t const status = mRawAgent->deregisterMem(NixlHelper::convertRegDlist(coalescedDescs), &mExtraParams);
    TLLM_CHECK(status == NIXL_SUCCESS);

    // Remove entries from registry
    if (descs.getType() == MemoryType::kVRAM)
    {
        for (auto const& desc : descs.getDescs())
        {
            mLocalVramRegionInfo.erase(desc.getAddr());
        }
    }
}

void NixlTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc)
Expand All @@ -626,18 +647,44 @@ void NixlTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const
TLLM_CHECK(status == NIXL_SUCCESS);
TLLM_CHECK_WITH_INFO(
name == remoteName, "loadRemoteAgent gets error agent name: %s != %s", name.c_str(), remoteName.c_str());

// Store remote VMM region info for chunk boundary calculations in
// VmmDescSplitter::splitTransferDescsWithRegionMaps. Per-agent map because different remote agents may have
// overlapping virtual addresses.
auto const& regions = agentDesc.getVramRegions();
if (!regions.empty())
{
auto& remoteMap = mRemoteVramRegionInfo[name];
for (auto const& r : regions)
{
remoteMap[r.baseAddr] = {r.totalLen, r.chunkSize};
}
}
}

/// Build the descriptor that identifies this agent to remote peers.
///
/// Fetches the local metadata blob from the underlying agent and attaches every locally
/// recorded region whose chunkSize is non-zero, so remote agents can compute chunk boundaries.
///
/// @return AgentDesc holding the metadata blob plus the collected region metadata.
AgentDesc NixlTransferAgent::getLocalAgentDesc()
{
    nixl_blob_t nixlBlob;
    nixl_status_t const status = mRawAgent->getLocalMD(nixlBlob);
    TLLM_CHECK(status == NIXL_SUCCESS);

    // Pack local VMM region info so remote agents can compute chunk boundaries.
    // Regions with chunkSize == 0 need no splitting and are therefore not sent.
    std::vector<VramRegionMeta> regions;
    for (auto const& [base, info] : mLocalVramRegionInfo)
    {
        if (info.chunkSize > 0)
        {
            regions.push_back({base, info.totalLen, info.chunkSize});
        }
    }

    // The blob is not needed locally any more, so hand it over without copying.
    return AgentDesc{std::move(nixlBlob), std::move(regions)};
}

/// Drop the state held for the remote agent `name`.
/// Erases the per-agent entry from mRemoteVramRegionInfo, then asks the underlying agent to
/// invalidate the remote metadata for that name.
/// NOTE(review): any result returned by invalidateRemoteMD is not inspected here; confirm that
/// ignoring it is intentional.
/// @param name Name of the remote agent to invalidate.
void NixlTransferAgent::invalidateRemoteAgent(std::string const& name)
{
// Clean up remote VMM region info before invalidating the remote agent.
mRemoteVramRegionInfo.erase(name);
mRawAgent->invalidateRemoteMD(name);
}

Expand All @@ -656,32 +703,31 @@ void NixlTransferAgent::invalidateRemoteAgent(std::string const& name)
{
mExtraParams.hasNotif = false;
}
// Need to do this in a loop with NIXL_ERR_NOT_FOUND
// UCX AM with desc list is faster than listener thread can recv/load MD with sockets
// Will be deprecated with ETCD or callbacks
// Split transfer descriptors at VMM chunk boundaries to match registered memory.
// Both src and dst are split at chunk boundaries to ensure each descriptor
// falls within a single registered memory region on both local and remote sides.
// Find remote agent's VMM region map (empty map if not found).
static VramRegionMap const kEmptyMap;
auto remoteIt = mRemoteVramRegionInfo.find(request.getRemoteName());
auto const& remoteRegionMap = (remoteIt != mRemoteVramRegionInfo.end()) ? remoteIt->second : kEmptyMap;

auto [splitSrc, splitDst] = VmmDescSplitter::splitTransferDescsWithRegionMaps(
request.getSrcDescs(), request.getDstDescs(), mLocalVramRegionInfo, remoteRegionMap);

// Coalesce contiguous memory regions to reduce transfer count (disabled by default)
// This matches the coalescing done during registerMemory()
// Set TRTLLM_NIXL_ENABLE_COALESCE=1 to enable this optimization
if (common::getEnvNixlEnableCoalesce())
{
auto [coalescedSrc, coalescedDst]
= NixlHelper::coalesceTransferDescs(request.getSrcDescs(), request.getDstDescs());
// do
// {
auto [coalescedSrc, coalescedDst] = NixlHelper::coalesceTransferDescs(splitSrc, splitDst);
status
= mRawAgent->createXferReq(NixlHelper::convert(request.getOp()), NixlHelper::convertXferDist(coalescedSrc),
NixlHelper::convertXferDist(coalescedDst), request.getRemoteName(), handle, &mExtraParams);
// } while (status == NIXL_ERR_NOT_FOUND);
}
else
{
// do
// {
status = mRawAgent->createXferReq(NixlHelper::convert(request.getOp()),
NixlHelper::convertXferDist(request.getSrcDescs()), NixlHelper::convertXferDist(request.getDstDescs()),
request.getRemoteName(), handle, &mExtraParams);
// } while (status == NIXL_ERR_NOT_FOUND);
status = mRawAgent->createXferReq(NixlHelper::convert(request.getOp()), NixlHelper::convertXferDist(splitSrc),
NixlHelper::convertXferDist(splitDst), request.getRemoteName(), handle, &mExtraParams);
}

TLLM_CHECK_WITH_INFO(status == NIXL_SUCCESS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ class NixlTransferAgent final : public BaseTransferAgent

std::vector<char> mDRamSrcBuffer;
std::vector<char> mDRamDstBuffer;

/// Local VMM region info (from registerMemory). Keyed by local virtual address.
VramRegionMap mLocalVramRegionInfo;

/// Remote VMM region info (from loadRemoteAgent). Keyed by {agentName → {addr → info}}.
/// Per-agent maps because different remote agents may have overlapping virtual addresses.
std::unordered_map<std::string, VramRegionMap> mRemoteVramRegionInfo;
};

class NixlLoopbackAgent final : public BaseLoopbackAgent
Expand Down
Loading
Loading