Enable ZeRO set/get APIs for NVMe offload #7046

Open · wants to merge 34 commits into base: master

Changes from 10 commits

Commits (34)
ad5f833
Control trace cache warnings
tjruwase Feb 15, 2025
9c4a177
Update docs
tjruwase Feb 15, 2025
8c25621
Merge branch 'master' into olruwase/control_trace_cache_warnings
tjruwase Feb 15, 2025
4fd1b05
Enable safe_get/set APIs for NVMe offload
tjruwase Feb 17, 2025
6fd12e1
Formatting fixes
tjruwase Feb 17, 2025
e23bfab
Add vectorized update API
tjruwase Feb 18, 2025
de6d8b1
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 19, 2025
7e250d9
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 19, 2025
76da050
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 21, 2025
cc6ed24
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 21, 2025
f9ecab7
PR feedback
tjruwase Feb 25, 2025
48e5ad7
PR feedback
tjruwase Feb 25, 2025
f20abc1
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 25, 2025
28ba8af
Code cleanup
tjruwase Feb 25, 2025
3872984
Merge branch 'olruwase/update_nvme_offload_states' of github.com:deep…
tjruwase Feb 25, 2025
8bc000c
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 26, 2025
6ac306a
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 27, 2025
f86e3ca
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 27, 2025
6c1ba6e
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 28, 2025
17935e9
Handle offload_states
tjruwase Feb 28, 2025
61685dc
Use new dlpack api; Formatting fixes
tjruwase Mar 3, 2025
1667758
Merge branch 'olruwase/new_dlpack_api' of github.com:deepspeedai/Deep…
tjruwase Mar 3, 2025
66b40ce
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 3, 2025
5a215cc
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 4, 2025
797bb15
Revert change
tjruwase Mar 4, 2025
f033827
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 5, 2025
044db61
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 7, 2025
b0f1391
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 10, 2025
0099333
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 11, 2025
e203e8f
Add -x to test failure/debug
loadams Mar 11, 2025
d6c2999
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 12, 2025
03765b8
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 17, 2025
9ecbce8
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 17, 2025
5897a8b
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 21, 2025
1 change: 1 addition & 0 deletions deepspeed/runtime/swap_tensor/__init__.py
@@ -2,3 +2,4 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .utils import MIN_SWAPPABLE_BYTES
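
The re-export above makes MIN_SWAPPABLE_BYTES usable outside the swap_tensor package. A minimal sketch of how a caller might check whether a tensor is big enough to be swapped to NVMe; the helper name and the exact threshold semantics are illustrative assumptions, not part of this PR:

import torch
from deepspeed.runtime.swap_tensor import MIN_SWAPPABLE_BYTES

def is_nvme_swappable(tensor: torch.Tensor) -> bool:
    # Hypothetical helper: only tensors whose byte size reaches the package's
    # minimum swappable threshold are worth offloading to NVMe.
    return tensor.numel() * tensor.element_size() >= MIN_SWAPPABLE_BYTES
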
94 changes: 74 additions & 20 deletions deepspeed/runtime/swap_tensor/optimizer_utils.py
@@ -26,45 +26,82 @@ def __init__(self, path, length, offset):
self.length = length


class SwapTensorContext(object):

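# Pairs an optimizer state tensor with its NVMe swap file path and, when the
# backing buffer is pinned, an AIO-aligned swap view of that buffer:
# compute_tensor exposes exactly numel() elements to the optimizer, while
# swap_tensor covers the aligned region that is read from / written to disk
# (see set_swap_buffers below).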
def __init__(self, tensor, swap_folder):
self.compute_tensor = tensor
self.swap_tensor = torch.Tensor()
self.swap_path = os.path.join(swap_folder, f'{OptimizerSwapper.parameter_id(tensor)}.tensor.swp')

def release_memory(self):
self.compute_tensor.data = torch.Tensor()
self.swap_tensor.data = torch.Tensor()

def set_buffers(self, compute_buffer, swap_buffer):
self.compute_tensor.data = compute_buffer.data
self.swap_tensor.data = swap_buffer.data


class OptimizerStateSwapInfo(object):

def __init__(self, parameter, numel, base_folder):
self.tensors = []
self.param_id = OptimizerSwapper.parameter_id(parameter)
self.swap_folder = base_folder
self.swap_paths = []
self.swapped_gradients = {}
self.unswapped_gradients = {}
self.tensor_numel = numel
self.tensor_dtype = parameter.dtype
self.tensor_device = parameter.device
self.has_state_tensors = False
self.swap_buffers = []
self._add_tensors([parameter])

def numel(self):
return self.tensor_numel

def has_gradients(self):
return self.swapped_gradients or self.unswapped_gradients
return bool(self.swapped_gradients) or bool(self.unswapped_gradients)

def _add_tensors(self, tensor_list):
for t in tensor_list:
self.tensors.append(t)
self.swap_paths.append(os.path.join(self.swap_folder, f'{OptimizerSwapper.parameter_id(t)}.tensor.swp'))
self.tensors.append(SwapTensorContext(t, self.swap_folder))

def add_state_tensors(self, tensor_list):
self.has_state_tensors = True
self._add_tensors(tensor_list)

def num_tensors(self):
return len(self.tensors)

def device(self):
return self.tensor_device

def dtype(self):
return self.tensor_dtype

def release_memory(self):
for tensor in self.tensors:
tensor.data = torch.Tensor()
for t in self.tensors:
t.release_memory()

def get_compute_tensors(self):
return [t.compute_tensor for t in self.tensors]

def get_swap_paths(self):
return [t.swap_path for t in self.tensors]

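# Splits the state tensors by pinned-ness: tensors backed by pinned memory
# hand their aligned swap_tensor view to the AIO engine directly, while
# unpinned tensors return their compute_tensor and are staged through
# separate pinned buffers by the caller (see _swap_out_unpinned_tensors).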
def get_swap_buffers_and_paths(self, pinned):
swap_buffers = []
swap_paths = []
select_tensors = [t for t in self.tensors if get_accelerator().is_pinned(t.compute_tensor) == pinned]
for t in select_tensors:
swap_buffers.append(t.swap_tensor if pinned else t.compute_tensor)
swap_paths.append(t.swap_path)
return swap_buffers, swap_paths

def get_or_create_gradient_paths(self, offsets, lengths):
gradient_paths = []
@@ -77,11 +114,17 @@ def get_or_create_gradient_paths(self, offsets, lengths):

return gradient_paths

def set_swap_buffers(self, buffers):
compute_lengths = [self.numel()] * len(self.tensors)
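# The new aligned_numel argument lets each tensor carve two views out of the
# same buffer: a compute view of exactly numel() elements and a swap view of
# aligned_numel elements, keeping NVMe reads and writes aligned.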
def set_swap_buffers(self, buffers, aligned_numel):
num_tensors = len(self.tensors)
compute_lengths = [self.numel()] * num_tensors
compute_buffers = get_sized_buffers(buffers, compute_lengths)
for t, buffer in zip(self.tensors, compute_buffers):
t.data = buffer.data
swap_lengths = [aligned_numel] * num_tensors
swap_buffers = get_sized_buffers(buffers, swap_lengths)

for i, t in enumerate(self.tensors):
t.set_buffers(compute_buffer=compute_buffers[i], swap_buffer=swap_buffers[i])

def get_swap_gradient_buffers(self, swap_buffer):
assert self.numel() <= swap_buffer.numel()
@@ -91,7 +134,8 @@ def get_swap_gradient_paths(self):
return [grad.path for grad in self.swapped_gradients.values()]

def get_unpinned_state_tensors(self):
return [t for t in self.tensors if not get_accelerator().is_pinned(t)]
return [t.compute_tensor for t in self.tensors if not get_accelerator().is_pinned(t.compute_tensor)]

def read_unswapped_gradients(self, dest_buffer):
num_elem_count = 0
@@ -102,6 +146,15 @@ def read_unswapped_gradients(self, dest_buffer):

return num_elem_count

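# Write-back counterpart of read_unswapped_gradients: copies data from
# src_buffer into the gradient partitions that were too small to swap and
# therefore stay in memory, returning the number of elements written.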
def write_unswapped_gradients(self, src_buffer):
num_elem_count = 0
for offset, grad_partition in self.unswapped_gradients.items():
src_tensor = src_buffer.narrow(0, offset, grad_partition.numel())
grad_partition.data.copy_(src_tensor.data)
num_elem_count += grad_partition.numel()

return num_elem_count

def release_unswapped_gradients(self):
self.unswapped_gradients = {}

@@ -158,10 +211,10 @@ def purge_state(self):
swap_info.tensors = [swap_info.tensors[0]]
swap_info.has_state_tensors = False

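# Renamed from swappable_tensor, with the argument generalized from param to
# tensor so that optimizer state tensors (not only parameters) can be checked
# against the minimum swap size.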
def swappable_tensor(self, param=None, numel=None):
assert param is not None or numel is not None, "Either param or numel must be provided"
if param is not None:
return self.min_aio_bytes <= (param.numel() * self.swap_element_size)
def is_swappable_tensor(self, tensor=None, numel=None):
assert tensor is not None or numel is not None, "Either tensor or numel must be provided"
if tensor is not None:
return self.min_aio_bytes <= (tensor.numel() * self.swap_element_size)
return self.min_aio_bytes <= (numel * self.swap_element_size)

def init_timers(self):
@@ -201,7 +254,7 @@ def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gra

self._start_timer(SWAP_OUT_GRADIENT_TIMER)
for tensor, offset in zip(aligned_gradients, aligned_offsets):
if not self.swappable_tensor(param=tensor):
if not self.is_swappable_tensor(tensor=tensor):
swap_info.unswapped_gradients[offset] = tensor
continue

@@ -355,7 +408,8 @@ def _get_swap_paths(self, parameters, num_elems):
]
assert len(swap_info_list) == len(num_elems)

swap_paths = [info.swap_paths[0] for info in swap_info_list]
swap_paths = [info.tensors[0].swap_path for info in swap_info_list]
return swap_paths

def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
@@ -386,7 +440,7 @@ def _adjust_for_misaligned_lengths(self, tensors, offsets):
new_offsets = []

for orig_tensor, orig_offset in zip(tensors, offsets):
if not self.swappable_tensor(param=orig_tensor):
if not self.is_swappable_tensor(tensor=orig_tensor):
new_tensors.append(orig_tensor)
new_offsets.append(orig_offset)
continue
@@ -417,7 +471,7 @@ def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
self._log_timers([UNSWAPPED_READ_GRADIENTS])

# It should be safe to discard unswapped gradient partitions
swap_info.release_unswapped_gradients()
# swap_info.release_unswapped_gradients()

if SWAPPER_DEBUG_MODE:
logger.info(
@@ -430,7 +484,7 @@ def _get_state_tensors(self, parameter):

tensor_list = []
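# Only state tensors large enough to meet the swap threshold are tagged with
# a ds_id and registered for swapping; smaller state tensors stay in memory.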
for state_name, value in self.optimizer.state[parameter].items():
if torch.is_tensor(value):
if torch.is_tensor(value) and self.is_swappable_tensor(tensor=value):
value.ds_id = state_name + '-' + parameter.ds_id
tensor_list.append(value)

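Taken together, these changes let the ZeRO tensor-fragment APIs read and update optimizer state that is offloaded to NVMe. A minimal usage sketch, assuming the deepspeed.utils.safe_get_full_optimizer_state / safe_set_full_optimizer_state helpers this PR targets; the config values, nvme_path, and toy model are placeholders:

import torch
import deepspeed
from deepspeed.utils import safe_get_full_optimizer_state, safe_set_full_optimizer_state

# Illustrative ZeRO-3 config with optimizer state offloaded to NVMe.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "nvme", "nvme_path": "/local_nvme"},
    },
}

model = torch.nn.Linear(4096, 4096)
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# After at least one training step has populated the optimizer states:
for p in engine.module.parameters():
    # Read a full state fragment ("exp_avg" is Adam's first moment), even when
    # the underlying state tensor currently lives on NVMe.
    exp_avg = safe_get_full_optimizer_state(p, "exp_avg")
    # Modify it and write it back; the swapper persists the update to the swap file.
    safe_set_full_optimizer_state(p, torch.zeros_like(exp_avg), "exp_avg")

The exact keys required for NVMe offload (for example the aio section and buffer counts) depend on the deployment; the point is that the get/set round trip now behaves the same whether a state tensor is resident in memory or swapped out.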