
Commit 462def4

HeyangQin and awan-10 authored
Enable hpz when running with torch.no_grad (#4232)
* enable hpz when running with torch.no_grad
* change the way to detect no_grad
* fix format

Co-authored-by: Ammar Ahmad Awan <[email protected]>
1 parent 6cbf666 commit 462def4

File tree

1 file changed (+5, -4 lines)

deepspeed/runtime/zero/parameter_offload.py (+5, -4)
@@ -490,19 +490,20 @@ def _run_after_backward_function(sub_module):
         # post backward hook
         self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
 
-    @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
-
+        prev_grad_state = torch.is_grad_enabled(
+        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
+        torch.set_grad_enabled(False)
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
 
         param_coordinator = self.get_param_coordinator(training=sub_module.training)
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module, forward=True)
-
+        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
+        torch.set_grad_enabled(prev_grad_state)
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
 
     @torch.no_grad()
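As context for the diff above, here is a minimal, self-contained sketch of the grad-mode pattern the commit switches to. It is not DeepSpeed code: fetch_params and pre_forward are hypothetical stand-ins for param_coordinator.fetch_sub_module and pre_sub_module_forward_function. The old @torch.no_grad() decorator meant torch.is_grad_enabled() always reported False inside the hook, so the fetch path could not tell whether the caller was actually running under torch.no_grad(); the commit instead records the caller's grad mode, disables grad only while fetching, passes the recorded mode along (presumably so hpz can react to it), and then restores it.

import torch

def fetch_params(module, forward_with_grad):
    # Hypothetical stand-in for param_coordinator.fetch_sub_module(sub_module, forward=...):
    # downstream logic can branch on whether the surrounding forward pass has grad enabled.
    print(f"fetching params for {module.__class__.__name__}, caller grad enabled: {forward_with_grad}")

def pre_forward(module):
    prev_grad_state = torch.is_grad_enabled()   # record the caller's grad mode first
    torch.set_grad_enabled(False)               # fetch parameters without building a graph
    fetch_params(module, forward_with_grad=prev_grad_state)
    torch.set_grad_enabled(prev_grad_state)     # restore whatever mode the caller was in

model = torch.nn.Linear(4, 4)
pre_forward(model)                              # prints: caller grad enabled: True
with torch.no_grad():
    pre_forward(model)                          # prints: caller grad enabled: False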

Comments (0)