Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable ZeRO set/get APIs for NVMe offload #7046

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ad5f833
Control trace cache warnings
tjruwase Feb 15, 2025
9c4a177
Update docs
tjruwase Feb 15, 2025
8c25621
Merge branch 'master' into olruwase/control_trace_cache_warnings
tjruwase Feb 15, 2025
4fd1b05
Enable safe_get/set APIs for NVMe offload
tjruwase Feb 17, 2025
6fd12e1
Formatting fixes
tjruwase Feb 17, 2025
e23bfab
Add vectorized update API
tjruwase Feb 18, 2025
de6d8b1
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 19, 2025
7e250d9
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 19, 2025
76da050
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 21, 2025
cc6ed24
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 21, 2025
f9ecab7
PR feedback
tjruwase Feb 25, 2025
48e5ad7
PR feedback
tjruwase Feb 25, 2025
f20abc1
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 25, 2025
28ba8af
Code cleanup
tjruwase Feb 25, 2025
3872984
Merge branch 'olruwase/update_nvme_offload_states' of github.com:deepspeedai/DeepSpeed
tjruwase Feb 25, 2025
8bc000c
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 26, 2025
6ac306a
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 27, 2025
f86e3ca
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 27, 2025
6c1ba6e
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 28, 2025
17935e9
Handle offload_states
tjruwase Feb 28, 2025
61685dc
Use new dlpack api; Formatting fixes
tjruwase Mar 3, 2025
1667758
Merge branch 'olruwase/new_dlpack_api' of github.com:deepspeedai/DeepSpeed
tjruwase Mar 3, 2025
66b40ce
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 3, 2025
5a215cc
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 4, 2025
797bb15
Revert change
tjruwase Mar 4, 2025
f033827
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 5, 2025
044db61
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 7, 2025
b0f1391
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 10, 2025
0099333
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 11, 2025
e203e8f
Add -x to test failure/debug
loadams Mar 11, 2025
d6c2999
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Code cleanup
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
tjruwase committed Feb 25, 2025
commit 28ba8afea85f0f64978b0c791e2d81b6a8b5897d
8 changes: 0 additions & 8 deletions deepspeed/runtime/swap_tensor/optimizer_utils.py
Original file line number Diff line number Diff line change
@@ -48,7 +48,6 @@ def __init__(self, parameter, numel, base_folder):
self.tensors = []
self.param_id = OptimizerSwapper.parameter_id(parameter)
self.swap_folder = base_folder
# self.swap_paths = []
self.swapped_gradients = {}
self.unswapped_gradients = {}
self.tensor_numel = numel
@@ -67,8 +66,6 @@ def has_gradients(self):
def _add_tensors(self, tensor_list):
for t in tensor_list:
self.tensors.append(SwapTensorContext(t, self.swap_folder))
# self.tensors.append(t)
# self.swap_paths.append(os.path.join(self.swap_folder, f'{OptimizerSwapper.parameter_id(t)}.tensor.swp'))

def add_state_tensors(self, tensor_list):
self.has_state_tensors = True
@@ -86,7 +83,6 @@ def dtype(self):
def release_memory(self):
for t in self.tensors:
t.release_memory()
# tensor.data = torch.Tensor()

def get_compute_tensors(self):
return [t.compute_tensor for t in self.tensors]
@@ -123,8 +119,6 @@ def set_swap_buffers(self, buffers, aligned_numel):

for i, t in enumerate(self.tensors):
t.set_buffers(compute_buffer=compute_buffers[i], swap_buffer=swap_buffers[i])
# for t, buffer in zip(self.tensors, compute_buffers):
# t.data = buffer.data

def get_swap_gradient_buffers(self, swap_buffer):
assert self.numel() <= swap_buffer.numel()
@@ -135,7 +129,6 @@ def get_swap_gradient_paths(self):

def get_unpinned_state_tensors(self):
return [t.compute_tensor for t in self.tensors if not get_accelerator().is_pinned(t.compute_tensor)]
# return [t for t in self.tensors if not get_accelerator().is_pinned(t)]

def read_unswapped_gradients(self, dest_buffer):
num_elem_count = 0
@@ -409,7 +402,6 @@ def _get_swap_paths(self, parameters, num_elems):
assert len(swap_info_list) == len(num_elems)

swap_paths = [info.tensors[0].swap_path for info in swap_info_list]
# swap_paths = [info.swap_paths[0] for info in swap_info_list]
return swap_paths

def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
55 changes: 0 additions & 55 deletions deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py
Original file line number Diff line number Diff line change
@@ -65,10 +65,7 @@ def release_swap_buffers(self, parameter):
swap_info = self._get_param_swap_info(parameter)
if swap_info is None:
return
# import pdb; pdb.set_trace()
swap_info.release_memory()
# for t in swap_info.tensors:
# t.data = torch.Tensor()

self.swap_buffer_manager.free(swap_info.swap_buffers)
swap_info.swap_buffers = []
@@ -81,9 +78,7 @@ def swap_in_optimizer_state(self, parameter, async_parameter=None):
self._flush_gradient_swapper(self.gradient_swapper)

required_buffer_count = swap_info.num_tensors() + (1 if swap_info.has_gradients() else 0)
# required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
aligned_numel = self._io_aligned_numel(swap_info.numel())
# import pdb; pdb.set_trace()
pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
count=required_buffer_count,
dtype=parameter.dtype)
@@ -94,7 +89,6 @@ def swap_in_optimizer_state(self, parameter, async_parameter=None):
self._swap_in_parameter(aio_handle=self.aio_handle,
parameter=parameter,
dest_buffers=pinned_buffers[:swap_info.num_tensors()])
# dest_buffers=pinned_buffers[:len(swap_info.tensors)])
self._stop_timer(SWAP_IN_PARAM_TIMER)
self.timer_names.add(SWAP_IN_PARAM_TIMER)

@@ -105,17 +99,10 @@ def swap_in_optimizer_state(self, parameter, async_parameter=None):
self.timer_names.add(SWAP_IN_GRADIENT_TIMER)

def _swap_out_optimizer_state(self, swap_info):
# pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
pinned_tensors, pinned_paths = swap_info.get_swap_buffers_and_paths(True)
WRITE_TIMER = 'swap_submit_write'
self._start_timer(WRITE_TIMER)

# compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
# compute_buffers = get_sized_buffers(pinned_tensors, compute_lengths)
# swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
# swap_buffers = get_sized_buffers(pinned_tensors, swap_lengths)
# import pdb; pdb.set_trace()

swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
assert self.aio_handle.wait() == len(pinned_tensors)

@@ -141,7 +128,6 @@ def writeback_optimizer_state_and_gradients(self, parameter, write_opt_state, wr
self._swap_out_optimizer_state(swap_info)

if write_gradients and swap_info.has_gradients():
# import pdb; pdb.set_trace()
param_gradients = swap_info.swapped_gradients.values()
swap_buffers = [parameter.grad.narrow(0, grad.offset, grad.length) for grad in param_gradients]
swap_paths = [grad.path for grad in param_gradients]
@@ -160,26 +146,8 @@ def swap_out_optimizer_state(self, parameter, async_swap=False):

swap_bytes = sum(
[self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.get_compute_tensors()])
# swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors])

self._start_timer(SWAP_OUT_PARAM_TIMER)
# pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)

# WRITE_TIMER = 'swap_submit_write'
# self._start_timer(WRITE_TIMER)

# swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
# assert self.aio_handle.wait() == len(pinned_tensors)

# if len(unpinned_tensors) > 0:
# pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
# self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
# unpinned_tensors=unpinned_tensors,
# dest_paths=unpinned_paths,
# pinned_buffers=pinned_buffers)
# swap_info.swap_buffers += pinned_buffers.copy()

# self._stop_timer(WRITE_TIMER)
self._swap_out_optimizer_state(swap_info)
self.release_swap_buffers(parameter)
self._stop_timer(SWAP_OUT_PARAM_TIMER)
@@ -213,7 +181,6 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):

self._start_timer(READ_TIMER)
swap_in_tensors(aio_handle, swap_buffers, swap_info.get_swap_paths())
# swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
self._stop_timer(READ_TIMER)

swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])
@@ -222,41 +189,19 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
aio_handle.wait()
self._stop_timer(WAIT_TIMER)

# compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
# compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
swap_info.set_swap_buffers(dest_buffers, self._io_aligned_numel(swap_info.numel()))
# for t, buffer in zip(swap_info.tensors, compute_buffers):
# t.data = buffer.data

self._log_timers([READ_TIMER, WAIT_TIMER])
if DEBUG_MODE and dist.get_rank() == 0:
logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')

# def _separate_pinned_tensors(self, swap_info):
# pinned_tensors = []
# pinned_paths = []

# unpinned_tensors = []
# unpinned_paths = []

# for tensor, path in zip(swap_info.tensors, swap_info.swap_paths):
# if get_accelerator().is_pinned(tensor.compute_tensor):
# pinned_tensors.append(tensor)
# pinned_paths.append(path)
# else:
# unpinned_tensors.append(tensor)
# unpinned_paths.append(path)

# return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths

def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]
param_gradients = swap_info.swapped_gradients.values()
swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients]
swap_paths = [grad.path for grad in param_gradients]
SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'
# import pdb; pdb.set_trace()
self._start_timer(SWAP_READ_GRADIENTS)
swap_in_tensors(aio_handle, swap_buffers, swap_paths)
self._stop_timer(SWAP_READ_GRADIENTS)
Original file line number Diff line number Diff line change
@@ -188,7 +188,6 @@ def _swap_out_optimizer_state(self, aio_handle, parameter, swap_in_op):
dst.data.copy_(unpinned_src.data)

swap_paths = param_info.get_swap_paths()
# swap_paths = param_info.swap_paths.copy()
assert len(swap_paths) == len(swap_buffers)

swap_out_tensors(aio_handle, swap_buffers, swap_paths)
@@ -221,7 +220,6 @@ def _swap_in_optimizer_state(self, aio_handle, parameter):

swap_buffers = state_buffers.copy()
swap_paths = param_info.get_swap_paths()
# swap_paths = param_info.swap_paths.copy()

if param_info.has_gradients():
parameter.grad = allocated_buffers[-1].narrow(0, 0, param_info.numel())