Commit 3960f97

syed-ahmed authored and pytorchmergebot committed
Documents torch.cuda.MemPool API (pytorch#148374)
Pull Request resolved: pytorch#148374 Approved by: https://github.com/eqy, https://github.com/ngimel
1 parent ed9c8a5 commit 3960f97
1 file changed: +201 −0 lines changed

docs/source/notes/cuda.rst

Lines changed: 201 additions & 0 deletions
@@ -635,6 +635,207 @@ of the alloc/free functions that match the signatures specified above.
.. _cublas-workspaces:

Mixing different CUDA system allocators in the same program
-----------------------------------------------------------
Depending on your use case, :meth:`~torch.cuda.change_current_allocator` may not be what you
want, since it swaps the CUDA allocator for the entire program (similar to
``PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync``). For instance, if the swapped allocator doesn't
have a caching mechanism, you will lose all the benefits of PyTorch's CUDACachingAllocator. Instead,
you can selectively mark a region of PyTorch code to use a custom allocator with
:class:`torch.cuda.MemPool`. This lets you use multiple CUDA system allocators in the same
PyTorch program while retaining most of the benefits of the CUDACachingAllocator (e.g. caching).
Using :class:`torch.cuda.MemPool`, you can utilize custom allocators that enable several features,
such as:

* Allocating output buffers for an all-reduce with the ``ncclMemAlloc`` allocator can enable NVLink
  Switch Reductions (NVLS). This can reduce contention between overlapping compute and communication
  kernels on GPU resources (SMs and Copy Engines), especially in tensor-parallel workloads.
* On Grace CPU based systems, allocating host output buffers for an all-gather using ``cuMemCreate``
  with ``CU_MEM_LOCATION_TYPE_HOST_NUMA`` can enable Extended GPU Memory (EGM) based memory
  transfers from the source GPUs to the destination CPU. This accelerates the all-gather, since the
  transfer happens over NVLink rather than over bandwidth-limited Network Interface Card (NIC)
  links. Such an accelerated all-gather can in turn speed up model checkpointing.
* If you are crafting a model and don't yet want to think about the optimal memory placement of a
  memory-intensive module (e.g. an embedding table), or you have a module that is not performance
  sensitive and doesn't fit in GPU memory, you can allocate that module with ``cudaMallocManaged``
  with a preferred CPU location and get your model working first.

.. note::

    While ``cudaMallocManaged`` offers convenient automatic memory management using CUDA Unified
    Virtual Memory (UVM), it is not recommended for DL workloads. For DL workloads that fit in GPU
    memory, explicit placement consistently outperforms UVM, since there are no page faults and
    access patterns remain predictable. When GPU memory gets saturated, UVM has to perform costly
    double transfers, evicting pages to CPU before bringing in new ones.
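For the ``cudaMallocManaged`` case mentioned above, a minimal alloc/free pair that could be
compiled and loaded through :class:`torch.cuda.memory.CUDAPluggableAllocator` (in the same way as
the ``ncclMemAlloc`` example below) might look as follows. This is an illustrative, untested
sketch; the function names are hypothetical:

.. code:: cpp

    #include <cuda_runtime_api.h>

    extern "C" {

    void* managed_alloc_plug(size_t size, int device, void* stream) {
      void* ptr = nullptr;
      // Allocate UVM memory, accessible from both CPU and GPU.
      if (cudaMallocManaged(&ptr, size) != cudaSuccess) {
        return nullptr;
      }
      // Hint the driver to keep the pages resident on the CPU;
      // the GPU then accesses them on demand.
      cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId);
      return ptr;
    }

    void managed_free_plug(void* ptr, size_t size, int device, void* stream) {
      cudaFree(ptr);
    }

    }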
The code below shows ``ncclMemAlloc`` wrapped in a :class:`torch.cuda.memory.CUDAPluggableAllocator`.

.. code:: python

    import os

    import torch
    import torch.distributed as dist
    from torch.cuda.memory import CUDAPluggableAllocator
    from torch.distributed.distributed_c10d import _get_default_group
    from torch.utils import cpp_extension


    # create allocator
    nccl_allocator_source = """
    #include <nccl.h>
    #include <iostream>
    extern "C" {

    void* nccl_alloc_plug(size_t size, int device, void* stream) {
      std::cout << "Using ncclMemAlloc" << std::endl;
      void* ptr = nullptr;
      ncclResult_t err = ncclMemAlloc(&ptr, size);
      return err == ncclSuccess ? ptr : nullptr;
    }

    void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
      std::cout << "Using ncclMemFree" << std::endl;
      ncclResult_t err = ncclMemFree(ptr);
    }

    }
    """
    nccl_allocator_libname = "nccl_allocator"
    nccl_allocator = torch.utils.cpp_extension.load_inline(
        name=nccl_allocator_libname,
        cpp_sources=nccl_allocator_source,
        with_cuda=True,
        extra_ldflags=["-lnccl"],
        verbose=True,
        is_python_module=False,
        build_directory="./",
    )

    allocator = CUDAPluggableAllocator(
        f"./{nccl_allocator_libname}.so", "nccl_alloc_plug", "nccl_free_plug"
    ).allocator()

    # setup distributed
    rank = int(os.getenv("RANK"))
    local_rank = int(os.getenv("LOCAL_RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")
    default_pg = _get_default_group()
    backend = default_pg._get_backend(device)

    # Note: for convenience, the ProcessGroupNCCL backend provides
    # the ncclMemAlloc allocator as backend.mem_allocator
    allocator = backend.mem_allocator

You can now define a new memory pool by passing this allocator to :class:`torch.cuda.MemPool`:

.. code:: python

    pool = torch.cuda.MemPool(allocator)

The pool can then be used with the :class:`torch.cuda.use_mem_pool` context manager to
allocate tensors into that pool:

.. code:: python

    with torch.cuda.use_mem_pool(pool):
        # tensor gets allocated with the ncclMemAlloc passed in the pool
        tensor = torch.arange(1024 * 1024 * 2, device=device)
    print(f"tensor ptr on rank {rank} is {hex(tensor.data_ptr())}")

    # register user buffers using ncclCommRegister (called under the hood)
    backend.register_mem_pool(pool)

    # the collective uses Zero Copy NVLS
    dist.all_reduce(tensor[0:4])
    torch.cuda.synchronize()
    print(tensor[0:4])

Note the usage of ``register_mem_pool`` in the above example. This is an extra step for
NVLS reductions, where the user buffers need to be registered with NCCL. A user can
de-register the buffers with a similar ``deregister_mem_pool`` call.

To reclaim memory, users first need to ensure that nothing is using the pool. Once no
tensor holds a reference to the pool, :meth:`~torch.cuda.empty_cache` is called internally
on deletion of the pool, returning all of its memory to the system.

.. code:: python

    del tensor, pool

The :meth:`torch.cuda.MemPool.use_count` and :meth:`torch.cuda.MemPool.snapshot`
APIs can be used for debugging purposes:

.. code:: python

    pool = torch.cuda.MemPool(allocator)

    # pool's use count should be 1 at this point, as the MemPool object
    # holds a reference
    assert pool.use_count() == 1

    nelem_1mb = 1024 * 1024 // 4

    with torch.cuda.use_mem_pool(pool):
        out_0 = torch.randn(nelem_1mb, device="cuda")

        # pool's use count should be 2 at this point, as use_mem_pool
        # holds a reference
        assert pool.use_count() == 2

    # pool's use count should be back to 1 at this point, as use_mem_pool
    # released its reference
    assert pool.use_count() == 1

    with torch.cuda.use_mem_pool(pool):
        # pool should have 1 segment, since we made a small (1 MB) allocation
        # above and the CUDACachingAllocator packed it into a 2 MB buffer
        assert len(pool.snapshot()) == 1

        out_1 = torch.randn(nelem_1mb, device="cuda")

        # pool should still have 1 segment, since we made another small (1 MB)
        # allocation that got packed into the existing 2 MB buffer
        assert len(pool.snapshot()) == 1

        out_2 = torch.randn(nelem_1mb, device="cuda")

        # pool should now have 2 segments, since the CUDACachingAllocator had
        # to make a new 2 MB buffer to accommodate out_2
        assert len(pool.snapshot()) == 2
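The segment counts asserted above can be sketched with plain Python. The helper below is a
hypothetical first-fit model, not PyTorch API; only the 2 MB small-buffer segment size mirrors
the CUDACachingAllocator behavior described above:

.. code:: python

    # Illustrative first-fit model of packing small allocations into 2 MB
    # segments; NOT the actual CUDACachingAllocator implementation.
    SMALL_BUFFER = 2 * 1024 * 1024  # 2 MB segment used for small allocations

    def segments_needed(alloc_sizes):
        segments = []  # free bytes remaining in each segment
        for size in alloc_sizes:
            for i, free in enumerate(segments):
                if free >= size:
                    segments[i] -= size  # packed into an existing segment
                    break
            else:
                segments.append(SMALL_BUFFER - size)  # open a new 2 MB segment
        return len(segments)

    one_mb = 1024 * 1024
    assert segments_needed([one_mb]) == 1      # out_0 alone
    assert segments_needed([one_mb] * 2) == 1  # out_0 and out_1 share a segment
    assert segments_needed([one_mb] * 3) == 2  # out_2 forces a second segment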

.. note::

    * :class:`torch.cuda.MemPool` holds a reference to the pool. When you use the
      :class:`torch.cuda.use_mem_pool` context manager, it acquires another reference to the
      pool and releases it on exit. After that, ideally only tensors hold references to the
      pool. Once the tensors release their references, the pool's use count drops to 1,
      reflecting that only the :class:`torch.cuda.MemPool` object holds a reference. Only at
      that point can the memory held by the pool be returned to the system, when the pool's
      destructor is called via ``del``.
    * :class:`torch.cuda.MemPool` doesn't currently support the ``expandable_segments`` mode
      of the CUDACachingAllocator.
    * `NCCL has specific requirements`_ for a buffer to be compatible with NVLS reductions.
      These requirements can be broken in a dynamic workload: for instance, the buffer handed
      to NCCL by the CUDACachingAllocator might be split and hence not correctly aligned. In
      those cases, NCCL can use a fallback algorithm instead of NVLS.
    * Allocators like ``ncclMemAlloc`` can use more memory than requested due to alignment
      requirements (``CU_MULTICAST_GRANULARITY_RECOMMENDED``, ``CU_MULTICAST_GRANULARITY_MINIMUM``),
      and can cause your workload to run out of memory.

.. _NCCL has specific requirements:
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#memory-allocator


cuBLAS workspaces
-----------------
