The Weight Propagation Interface (WPI) is a Kubernetes-native orchestration framework designed to enable high-speed, zero-copy movement of large ML model weights between AI accelerators (GPUs, TPUs) across nodes in a cluster.
As models grow to hundreds of billions of parameters, the traditional path of saving weights to shared storage (like NFS or GCS) and having each pod independently download and load them into GPU RAM becomes a severe bottleneck. WPI solves this by treating Model Weights as first-class scheduling and hardware resources, leveraging native hardware interconnects (like NVLink and InfiniBand via NCCL) and Dynamic Resource Allocation (DRA) patterns to securely and efficiently distribute weights directly into accelerator memory.
WPI consists of three main architectural layers, plus the consumer that uses them:
- The API / CRDs (The Abstraction): Kubernetes Custom Resources that define logical blocks of weights and how workloads bind to them.
- The WPI Operator (The Brain): A Kubernetes controller that reconciles the desired distribution of weights with the cluster's physical topology.
- The WPI Driver / Node Agent (The Mover): A privileged DaemonSet running on accelerator nodes that executes hardware-specific commands (CUDA IPC, NCCL, TPUNode) to allocate, share, and transmit memory.
- The Consumer (The ML Workload): The ML framework (e.g., PyTorch, vLLM) that natively binds to the shared weight memory without allocating a duplicate copy.
WPI introduces two core custom resources acting as the Control Plane Interface:
- `WeightBuffer` (cluster-scoped): Represents a global, logical reservation of model weights.
  - Spec fields:
    - `provider`: Identifies the backend managing the allocation (e.g., `wpi-driver`).
    - `size`: The total byte size requested.
    - `sourcePath`: The location from which to initially stage the data (e.g., a path to Safetensors files).
    - `layout`: Describes the tensor layout.
    - `retentionPolicy`: Determines whether the weights persist beyond the lifecycle of the requesting pods.
    - `sharding` (optional): Enables automatic model sharding. Contains:
      - `strategy`: One of `TensorParallel`, `ExpertParallel`, `PipelineParallel`, or `Custom`.
      - `numShards`: The number of shards to split the model into.
      - `shardFiles` (optional): An explicit list of `{index, path, sizeBytes}` entries for pre-split models.
      - `filePattern` (optional): A pattern for auto-discovering shard files (e.g., `model-{index:05d}-of-{total:05d}.safetensors`).
  - Status fields (populated by the operator):
    - `totalShards`: The discovered number of shards.
    - `discoveredShards`: A list of `{index, path, sizeBytes, offsetBytes}` entries resolved from the sharding spec.
- `WeightClaim` (namespace-scoped): Represents a specific pod or job's request to use a `WeightBuffer`, mirroring the PersistentVolumeClaim (PVC) and DRA resource-claim patterns.
  - Spec fields:
    - `sourceBuffer`: A reference to the underlying `WeightBuffer`.
    - `propagationPolicy`: Defines caching or locality requirements.
    - `targetLayout`: Can request reshaping of the tensor layout for the specific consumer.
    - `shardIndex` (optional): The shard this claim requests. If omitted, the operator auto-assigns one based on pod annotations (`wpi.sig.k8s.io/shard-index`, `batch.kubernetes.io/job-completion-index`, or `ray.io/rank`).
  - Status fields (populated by the operator):
    - `assignedShardIndex`: The resolved shard index for this claim.
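Concretely, the two resources might be written as the following manifests. This is an illustrative sketch: the API group `wpi.sig.k8s.io` appears in the shard annotation above, but the `v1alpha1` version, the Kubernetes-quantity syntax for `size`, and the `propagationPolicy` value are assumptions.

```yaml
apiVersion: wpi.sig.k8s.io/v1alpha1   # assumed group/version
kind: WeightBuffer
metadata:
  name: llama-405b-weights            # cluster-scoped, no namespace
spec:
  provider: wpi-driver
  size: 810Gi                         # assumed quantity syntax
  sourcePath: /mnt/nfs/llama-405b     # Safetensors staging location
  retentionPolicy: Retain
  sharding:
    strategy: TensorParallel
    numShards: 8
---
apiVersion: wpi.sig.k8s.io/v1alpha1
kind: WeightClaim
metadata:
  name: llama-405b-claim
  namespace: inference
spec:
  sourceBuffer: llama-405b-weights
  propagationPolicy: NodeLocal        # hypothetical policy value
  shardIndex: 0
```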
The WPI Operator watches `WeightBuffer` and `WeightClaim` objects. When a pod that references a `WeightClaim` is scheduled, the operator initiates provisioning: it makes gRPC calls to the target node's WPI driver to establish the memory and authorize connections.
The driver runs as a DaemonSet (typically in the `wpi-system` namespace) on nodes with accelerators. It implements three primary gRPC services defined in `wpi.proto`:

- An identity service, used by the operator to discover the driver's capabilities (e.g., whether it supports `ON_THE_FLY_RESHAPING` or `CROSS_VENDOR_TRANSFER`).
- A controller service, handling cluster-level abstractions:
  - `CreateWeightBuffer`: Instructs the system to track a new buffer request.
  - `AuthorizeTransfer`: Validates that a requested weight-transfer mapping is authorized.
  - `QueryTopology`: Retrieves the accelerator interconnect topology (e.g., NVLink hierarchies) to optimize scheduling.
- A node service: the core data-plane interface, responsible for moving and exposing memory through the methods below.
- `NodeStageWeight`:
  - Interacts directly with driver APIs (e.g., `libcuda`).
  - Allocates, exports, and maps memory (`cuMemCreate`, `cuMemExportToShareableHandle`, `cuMemAddressReserve`, `cuMemMap`).
  - If a `source_path` is provided, parses the weights (e.g., Safetensors) and initiates a high-bandwidth zero-copy `HostToDevice` copy directly into the VRAM block.
  - Starts a background thread hosting a UNIX socket (e.g., `/run/wpi/sockets/<buffer_id>.sock`). The socket uses `SCM_RIGHTS` `sendmsg` to pass the raw OS file descriptor (FD) to a connecting consumer.
- `NodePropagate`: Initiates multi-node weight transfer. Supports two modes:
  - BROADCAST (default): Rank 0 (the source) uses `ncclBcast` to push the full VRAM buffer over the high-speed network (InfiniBand/RoCE) to all target nodes simultaneously. All targets receive identical data.
  - SCATTER: Each target receives a different shard of the buffer. The source uses `ncclSend`/`ncclRecv` group operations to send specific byte ranges (`offset_bytes` to `offset_bytes + length_bytes`) to specific targets, distributing a single large model across multiple nodes with each receiving only its assigned shard.

  The propagation mode and shard assignments are communicated to targets via the TCP handshake that precedes the NCCL operations. The driver also sends `PRE_UPDATE` notifications to local consumers before overwriting buffer contents, allowing them to flush caches (e.g., the KV cache).
- `NodeRegisterWeight`: Returns the internal DMA-buf or shareable handle ID.
- `NodeTranslateAndMap`: Returns the device path (the UNIX socket path) to be mounted into the consumer pod.
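The FD-passing handshake behind `NodeStageWeight` can be illustrated with plain UNIX sockets. The following is a CPU-only sketch of the mechanism (an ordinary file FD stands in for the exported CUDA shareable handle), using Python 3.9's `socket.send_fds`/`recv_fds` wrappers around `sendmsg`/`recvmsg` with `SCM_RIGHTS`:

```python
import os
import socket
import tempfile
import threading

def serve_fd(server_sock, fd):
    """Driver side: accept one consumer and pass the raw FD via SCM_RIGHTS."""
    conn, _ = server_sock.accept()
    with conn:
        # sendmsg requires at least one byte of ordinary payload.
        socket.send_fds(conn, [b"\x00"], [fd])

def receive_fd(sock_path):
    """Consumer side: connect and receive the FD via ancillary data (recvmsg)."""
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
        client.connect(sock_path)
        _, fds, _, _ = socket.recv_fds(client, 1, maxfds=1)
    return fds[0]

with tempfile.TemporaryDirectory() as d:
    # Stands in for /run/wpi/sockets/<buffer_id>.sock.
    sock_path = os.path.join(d, "demo__buffer.sock")
    with open(os.path.join(d, "weights.bin"), "w+b") as f:
        f.write(b"weights")
        f.flush()
        server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        server.bind(sock_path)
        server.listen(1)
        t = threading.Thread(target=serve_fd, args=(server, f.fileno()))
        t.start()
        received = receive_fd(sock_path)
        t.join()
        server.close()
    # The duplicated FD outlives the sender's copy of the file.
    data = os.pread(received, 7, 0)
    os.close(received)
```

The same kernel facility is what lets the driver hand a GPU memory handle to a pod it shares no process ancestry with.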
The consumer is standard ML framework code deployed as a pod, with a few modifications to support WPI memory mapping.
- Initial connection: The pod waits for the WPI UNIX socket to appear in its mounted directory (`/run/wpi/sockets`).
- FD reception: It connects to the socket and receives the shared file descriptor via ancillary data (`recvmsg`).
- CUDA import: Using `libcuda` via ctypes, the consumer calls `cuMemImportFromShareableHandle(fd)`, reserves an address space (`cuMemAddressReserve`), and maps the memory (`cuMemMap`).
- Zero-copy wrapping: Using PyTorch's `__cuda_array_interface__`, the raw physical GPU pointer is exposed as a native PyTorch tensor:

  ```python
  class RawCUDATensor:
      def __init__(self, ptr, nbytes):
          self.__cuda_array_interface__ = {
              "shape": (nbytes,),    # flat byte view of the buffer
              "typestr": "|u1",      # unsigned 8-bit elements
              "data": (ptr, False),  # (device pointer, read-only flag)
              "version": 3,
          }

  tensor = torch.as_tensor(RawCUDATensor(device_ptr, size), device="cuda:0")
  weights = tensor.view(torch.float16)
  ```
This bypasses memory allocation inside the pod entirely. The weights exist in the memory allocated by the WPI driver, and the PyTorch process merely holds a reference to that physical memory.
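The same wrapping pattern can be demonstrated on CPU with NumPy's analogous `__array_interface__`. In this sketch (a CPU-only analogy, not the WPI consumer path) the array is a view over an externally owned buffer rather than a copy, so writes through either handle are visible to both:

```python
import ctypes
import numpy as np

# Externally owned memory, standing in for the driver-allocated VRAM block.
buf = bytearray(8)
c_view = (ctypes.c_char * len(buf)).from_buffer(buf)

class RawTensor:
    # Mirrors the RawCUDATensor pattern, using the CPU-side array interface.
    __array_interface__ = {
        "shape": (len(buf),),
        "typestr": "|u1",                           # unsigned 8-bit bytes
        "data": (ctypes.addressof(c_view), False),  # (pointer, read_only)
        "version": 3,
    }

view = np.asarray(RawTensor())  # zero-copy wrap: no allocation, no memcpy
view[0] = 0xFF                  # a write through the view...
assert buf[0] == 0xFF           # ...is visible in the underlying buffer
```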
- Cluster admin: Creates a `WeightBuffer` pointing to an NFS directory of Safetensors files.
- Operator: Sees the `WeightBuffer` and waits for consumers.
- ML job: Creates a pod with a `WeightClaim` for that buffer.
- Operator/scheduler: Schedules the pod to Node A and calls `NodeStageWeight` on Node A's driver.
- Driver on Node A:
  - Allocates contiguous VRAM.
  - Reads the Safetensors data from NFS directly into the VRAM chunk.
  - Exports the memory as an FD and creates an `SCM_RIGHTS` UNIX socket.
- Pod on Node A: Starts up, queries the WPI socket, receives the FD, maps it using `cuMemImportFromShareableHandle`, wraps the pointer in a framework tensor, and begins inference immediately.
- ML distributed job: A massive inference job is scheduled across Nodes A, B, C, and D.
- Operator: Detects that the model weights are currently staged only on Node A. It pre-allocates uninitialized memory on Nodes B, C, and D by calling `NodeStageWeight` (without a source path).
- Operator: Calls `NodePropagate` on Node A's WPI driver, providing the IPs of Nodes B, C, and D as targets.
- Drivers: Node A generates an NCCL unique ID. Nodes B, C, and D connect via TCP to obtain the ID and their collective ranks.
- NCCL operation: Node A calls `ncclBcast()`. The weights move over the RDMA backend at 200+ Gbps fabric speeds into the VRAM of Nodes B, C, and D simultaneously.
- Pods: All pods across the four nodes map the memory locally and begin distributed inference.
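The TCP rendezvous step above can be sketched with stdlib sockets: the source listens, and each connecting target receives the unique ID plus an assigned rank. This stands in for NCCL's actual bootstrap protocol; the wire format (one small JSON message per target) is an illustrative assumption:

```python
import json
import socket
import threading

def serve_bootstrap(server, unique_id, num_targets):
    """Source (rank 0): hand each connecting target the NCCL unique ID and its rank."""
    for rank in range(1, num_targets + 1):
        conn, _ = server.accept()
        with conn:
            conn.sendall(json.dumps({"unique_id": unique_id, "rank": rank}).encode())

def join_bootstrap(addr):
    """Target: connect via TCP, learn the unique ID and collective rank."""
    with socket.create_connection(addr) as s:
        # The message is tiny and the sender closes after writing,
        # so a single recv is sufficient here.
        return json.loads(s.recv(4096).decode())

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))
server.listen(3)
addr = server.getsockname()

t = threading.Thread(target=serve_bootstrap, args=(server, "nccl-uid-demo", 3))
t.start()
results = [join_bootstrap(addr) for _ in range(3)]  # Nodes B, C, D
t.join()
server.close()

ranks = sorted(r["rank"] for r in results)
```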
- P2P topology awareness: Integrating with PCIe/NVLink topology APIs to schedule readers optimally based on hardware distance (NUMA node, NVLink bridge).
- TPU support: Extending the memory abstraction from `libcuda` and `cuMem*` to the equivalent `libtpu` memory-export constructs.
- On-the-fly reshaping: Implementing custom CUDA kernels in the WPI driver to transpose or shard memory blocks (e.g., column-major to row-major) dynamically as they are handed to the consumer, accommodating different framework needs from the same physical copy.
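The reshaping idea can be previewed on CPU: the same bytes reinterpreted column-major versus row-major yield transposed logical views of one physical buffer, which is what the proposed driver-side kernels would expose in VRAM (a NumPy analogy, not driver code):

```python
import numpy as np

flat = np.arange(6, dtype=np.float32)         # the physical buffer: one copy
row_major = flat.reshape(2, 3, order="C")     # consumer A's expected layout
col_major = flat.reshape(3, 2, order="F")     # consumer B reads the same bytes column-major

assert np.array_equal(col_major, row_major.T) # same memory, transposed logical view
flat[0] = 99.0                                # mutate the shared buffer...
assert row_major[0, 0] == 99.0                # ...both views observe the write
assert col_major[0, 0] == 99.0
```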
WPI supports first-class model sharding, enabling a single `WeightBuffer` to be automatically distributed across multiple GPUs and nodes. This is critical for models that exceed single-GPU memory (e.g., Kimi K2 at 1T parameters, Llama 405B).
| Strategy | Description | Use Case |
|---|---|---|
| `TensorParallel` | Even byte-range splits across GPUs | Most common; each GPU gets a contiguous slice |
| `ExpertParallel` | Expert-to-GPU mapping for MoE models | Kimi K2 (384 experts), Mixtral |
| `PipelineParallel` | Layer blocks assigned to different stages | Deep models split by layer groups |
| `Custom` | User-defined shard-to-file mapping | Pre-sharded checkpoints |
The WPI operator resolves shards via three paths (in priority order):
1. Explicit `shardFiles` — the user provides an exact `{index, path, sizeBytes}` entry per shard.
2. `filePattern` — the operator constructs paths like `model-00001-of-00008.safetensors`.
3. Capacity split — the operator divides `size / numShards` into even byte ranges with `offsetBytes`.
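The capacity-split path is simple arithmetic. A sketch of how the operator might derive `discoveredShards` entries from `size` and `numShards` (an illustrative helper, not the operator's actual code):

```python
def capacity_split(size_bytes, num_shards):
    """Divide a buffer into num_shards contiguous byte ranges.

    The last shard absorbs any remainder so the ranges tile the
    buffer exactly, mirroring the {index, offsetBytes, sizeBytes}
    shape of discoveredShards.
    """
    base = size_bytes // num_shards
    shards, offset = [], 0
    for i in range(num_shards):
        length = base if i < num_shards - 1 else size_bytes - offset
        shards.append({"index": i, "offsetBytes": offset, "sizeBytes": length})
        offset += length
    return shards

# The 600 GiB / 8-shard benchmark configuration splits evenly.
shards = capacity_split(600 * 2**30, 8)
```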
The driver tracks sharded buffers using a naming convention: `<buffer_id>__shard_<N>`. Each shard gets its own:

- VRAM allocation
- FD-passing UNIX socket (`<buffer_id>__shard_<N>.sock`)
- Notification socket (`<buffer_id>__shard_<N>_notify.sock`)

This is transparent to the consumer — the `WPIClient` resolves the correct shard-scoped ID automatically.
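The convention maps a (buffer, shard) pair to its per-shard resources. A small sketch of the resolution (the helper name and directory default are illustrative; the ID and socket names follow the convention above):

```python
def shard_resources(buffer_id: str, index: int, sock_dir: str = "/run/wpi/sockets"):
    """Resolve the shard-scoped ID and its two UNIX socket paths."""
    shard_id = f"{buffer_id}__shard_{index}"
    return {
        "id": shard_id,
        "fd_socket": f"{sock_dir}/{shard_id}.sock",
        "notify_socket": f"{sock_dir}/{shard_id}_notify.sock",
    }

res = shard_resources("llama405b", 3)
```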
- Admin: Creates a `WeightBuffer` with `sharding: {strategy: TensorParallel, numShards: 8}`.
- Operator: Discovers 8 shards and populates `status.discoveredShards` with their paths and sizes.
- Job: Creates 8 `WeightClaim` objects. Shard indices are auto-assigned from pod annotations.
- Driver (source): Stages the full model into a single VRAM buffer.
- Propagation: The operator calls `NodePropagate` with `mode: SCATTER` and `shard_assignments` mapping each target to its byte range.
- Driver (targets): Each target receives only its shard via `ncclRecv` and deposits it at the correct offset in its local buffer.
- Pods: Each pod maps its local shard and begins parallel execution.
| Component | Specification |
|---|---|
| Cluster | GKE (us-central1), 2× A4 nodes |
| GPUs | 8× NVIDIA GPUs per node (A100-equivalent) |
| Interconnect | InfiniBand with RDMA/GDR (GPUDirect RDMA) |
| NCCL | With IB plugin, adaptive routing enabled |
| Kubernetes | DRA (Dynamic Resource Allocation) for GPU memory scheduling |
Full 8-GPU concurrent transfer: each of the 8 source GPUs sends its 75 GB shard to the corresponding target GPU on the remote node over a dedicated NCCL communicator. All 8 streams run simultaneously.
Configuration:
- `WeightBuffer`: 600 GiB, `TensorParallel`, `numShards: 8`
- 16 DRA pods total (8 source + 8 target)
- 8 independent NCCL communicators (world_size=2 each)
Results (best of 3 runs):
| Metric | Value |
|---|---|
| Total Data Transferred | 600 GB |
| Number of Concurrent Streams | 8 |
| Shard Size | 75 GB |
| Average Stream Latency | 2.31 s |
| Max Stream Latency (wall clock) | 2.39 s |
| Per-Stream Bandwidth (avg) | 32.43 GB/s |
| Aggregate Cross-Node Throughput | 251 GB/s |
Per-Shard Breakdown:
| Shard | Latency (s) | Bandwidth (GB/s) |
|---|---|---|
| 0 | 2.2654 | 33.11 |
| 1 | 2.2887 | 32.77 |
| 2 | 2.2536 | 33.28 |
| 3 | 2.3122 | 32.44 |
| 4 | 2.3310 | 32.17 |
| 5 | 2.3901 | 31.38 |
| 6 | 2.3543 | 31.86 |
| 7 | 2.3034 | 32.56 |
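The headline figures follow directly from the per-shard table: each stream's bandwidth is its 75 GB shard over its own latency, and because all eight streams run concurrently, the aggregate is the full 600 GB over the slowest (wall-clock) stream. Recomputing:

```python
# Per-shard latencies from the table above, shards 0..7, in seconds.
latencies = [2.2654, 2.2887, 2.2536, 2.3122, 2.3310, 2.3901, 2.3543, 2.3034]
shard_gb = 75.0

per_stream = [shard_gb / t for t in latencies]           # e.g. 75 / 2.2654 ≈ 33.11 GB/s
aggregate = len(latencies) * shard_gb / max(latencies)   # 600 GB over the slowest stream ≈ 251 GB/s
```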
- Near-linear scaling: 8 concurrent NCCL streams achieve ~251 GB/s aggregate, close to the theoretical 8× single-stream bandwidth. The IB fabric distributes traffic evenly across NIC ports.
- Low variance: The spread between fastest (shard 2: 2.25s) and slowest (shard 5: 2.39s) is only ~140ms, indicating minimal contention.
- NCCL transport: All streams use `NET/IB/GDRDMA/Shared` — zero-copy GPU-to-GPU over InfiniBand with GPUDirect RDMA, bypassing host memory entirely.
- Concurrent bootstrap: Each shard's NCCL communicator initializes independently in parallel. No serialization locks are needed; the `/dev/shm` volume is sized to 1 GiB to accommodate the shared-memory segments required by 8 concurrent NCCL proxy threads.