
Commit 78485d5

styusuf authored and facebook-github-bot committed
Reduce GPU OOM in layer gradient computation by offloading tensors to CPU
Summary: A number of LayerGradientXActivation jobs are failing with GPU OOM errors. The errors stem from how GPU memory is consumed during the forward and backward passes that produce layer evaluations, copies of those evaluations, and layer gradients:

* After the forward pass, all activation tensors are gathered across devices onto [one device](https://www.internalfb.com/code/fbsource/[0579e5aab76b1f89fc82d27913cbb3a6e0160b5f]/fbcode/pytorch/captum/captum/_utils/common.py?lines=785), so that device holds its own layer activations plus a copy of the activations from every other device. This is the peak of memory utilization, on top of the model graph and the original layer activations already held in memory.
* After the backward pass on `saved_layer` (still on GPU), the gradients are collected the same way onto the first device. At that point memory pressure is lower, since the backward pass has already completed, so it is safe to collect all gradients first and offload them to CPU afterwards.

This change offloads the layer activations to CPU during the period of peak GPU utilization: before running the expensive torch.cat, all activation tensors are moved to CPU, so the concatenation happens on CPU and GPU memory is freed. For the gradients, the tensors are offloaded to CPU after they have been gathered. This simple change significantly reduces peak GPU memory usage, and it is gated behind a flag that enables the efficient path.

Adds a `memory_efficient` mode to `LayerGradientXActivation` and an `offload_to_cpu` parameter to `compute_layer_gradients_and_eval` to reduce peak GPU memory usage during multi-layer attribution.

Differential Revision: D94915367
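The offload-before-reduction idea described in the summary can be sketched in plain PyTorch. This is an illustration only, not Captum's actual `_reduce_list`; the name `reduce_with_offload` and the single-tensor-per-device simplification are hypothetical:

```python
import torch

def reduce_with_offload(per_device_outputs, offload_to_cpu=True):
    """Concatenate per-device activation shards.

    Minimal sketch of the commit's idea: detach and move each shard to CPU
    *before* the expensive torch.cat, so the concatenated copy never
    occupies GPU memory. (Not Captum's real implementation.)
    """
    if offload_to_cpu:
        per_device_outputs = [t.detach().cpu() for t in per_device_outputs]
    # torch.cat now runs on CPU when offloading is enabled.
    return torch.cat(per_device_outputs, dim=0)

# Simulated shards; in the real setting each would live on its own GPU.
shards = [torch.randn(4, 8) for _ in range(3)]
merged = reduce_with_offload(shards)
print(merged.shape, merged.device)  # torch.Size([12, 8]) cpu
```

Because the CPU copies are detached, they carry no autograd history, which is what makes the early offload safe.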
1 parent f48fa42 commit 78485d5

File tree

2 files changed: +58 −24 lines changed

captum/_utils/gradient.py

Lines changed: 38 additions & 18 deletions

```diff
@@ -555,6 +555,7 @@ def compute_layer_gradients_and_eval(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     output_fn: Union[None, Callable] = None,
     grad_kwargs: Optional[Dict[str, Any]] = None,
+    offload_to_cpu: bool = False,
 ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...


@@ -574,6 +575,7 @@ def compute_layer_gradients_and_eval(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     output_fn: Union[None, Callable] = None,
     grad_kwargs: Optional[Dict[str, Any]] = None,
+    offload_to_cpu: bool = False,
 ) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]: ...


@@ -593,6 +595,7 @@ def compute_layer_gradients_and_eval(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     output_fn: Union[None, Callable] = None,
     grad_kwargs: Optional[Dict[str, Any]] = None,
+    offload_to_cpu: bool = False,
 ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...


@@ -612,6 +615,7 @@ def compute_layer_gradients_and_eval(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     output_fn: Union[None, Callable] = None,
     grad_kwargs: Optional[Dict[str, Any]] = None,
+    offload_to_cpu: bool = False,
 ) -> Union[
     Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
     Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]],
@@ -669,6 +673,7 @@ def compute_layer_gradients_and_eval(
     with torch.autograd.set_grad_enabled(True):
         # saved_layer is a dictionary mapping device to a tuple of
         # layer evaluations on that device.
+        saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]]
         saved_layer, output = _forward_layer_distributed_eval(
             forward_fn,
             inputs,
@@ -693,35 +698,44 @@ def compute_layer_gradients_and_eval(
             list(next(iter(saved_layer.values())).keys()), device_ids
         )
         all_outputs: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
+
+        def _get_layer_output(
+            single_layer: Module, device_id: device
+        ) -> Tuple[Tensor, ...]:
+            layer_out = saved_layer[single_layer][device_id]
+            if output_fn is not None:
+                layer_out = output_fn(layer_out)
+            # When offloading to CPU, move tensors before reduction (torch.cat)
+            # to avoid GPU OOM. This is safe because all_outputs is not used by
+            # torch.autograd.grad, which reads from saved_layer directly.
+            if offload_to_cpu:
+                layer_out = tuple(t.detach().cpu() for t in layer_out)
+            return layer_out
+
+        # pyre-fixme[9]: all_layers has type `List[Module]`; used as
+        # `Union[List[Variable[ModuleOrModuleList <: [Module, List[Module]]]],
+        # Variable[ModuleOrModuleList <: [Module, List[Module]]]]`.
+        all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
+
+        # Build all_outputs before backward pass. _get_layer_output detaches
+        # and moves tensors to CPU when offload_to_cpu is set, so these copies
+        # do not participate in the autograd graph and won't affect GPU memory
+        # during torch.autograd.grad (which reads from saved_layer directly).
         if isinstance(layer, Module):
             all_outputs = _reduce_list(
-                [
-                    (
-                        saved_layer[layer][device_id]
-                        if output_fn is None
-                        else output_fn(saved_layer[layer][device_id])
-                    )
-                    for device_id in key_list
-                ]
+                [_get_layer_output(layer, device_id) for device_id in key_list]
             )
         else:
             all_outputs = [
                 _reduce_list(
                     [
-                        (
-                            saved_layer[single_layer][device_id]
-                            if output_fn is None
-                            else output_fn(saved_layer[single_layer][device_id])
-                        )
+                        _get_layer_output(single_layer, device_id)
                         for device_id in key_list
                     ]
                 )
                 for single_layer in layer
             ]
-        # pyre-fixme[9]: all_layers has type `List[Module]`; used as
-        # `Union[List[Variable[ModuleOrModuleList <: [Module, List[Module]]]],
-        # Variable[ModuleOrModuleList <: [Module, List[Module]]]]`.
-        all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
+
         grad_inputs = tuple(
             layer_tensor
             for single_layer in all_layers
@@ -750,7 +764,13 @@ def compute_layer_gradients_and_eval(
                     output_fn(curr_saved_grad) for curr_saved_grad in curr_saved_grads
                 ]

-                all_grads.append(_reduce_list(curr_saved_grads))
+                reduced = _reduce_list(curr_saved_grads)
+                # When offloading to CPU, move gradient tensors after reduction
+                # (torch.cat) since reducing on GPU first is slightly more
+                # memory-efficient than moving individual tensors before reduction.
+                if offload_to_cpu:
+                    reduced = tuple(t.cpu() for t in reduced)
+                all_grads.append(reduced)

         layer_grads: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
         layer_grads = all_grads
```
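The comment in `_get_layer_output` above rests on one autograd fact: `torch.autograd.grad` reads from the graph-connected tensors in `saved_layer`, not from the detached CPU copies in `all_outputs`. A minimal sketch of that fact (illustrative names, not Captum code):

```python
import torch

# A leaf and an intermediate tensor standing in for a saved layer activation.
x = torch.randn(5, requires_grad=True)
activation = x * 3.0

# The detached CPU copy is what would be returned to the caller; it carries
# no autograd history and does not affect the graph.
offloaded = activation.detach().cpu()

# autograd still differentiates through `activation` itself, so the backward
# pass is unaffected by the offloaded copy.
(grad,) = torch.autograd.grad(activation.sum(), x)
print(grad)  # a tensor of 3.0s, regardless of the offload
```

This is why activations can be moved to CPU before the backward pass, while gradients (which only exist after the backward pass) are moved afterwards.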

captum/attr/_core/layer/layer_gradient_x_activation.py

Lines changed: 20 additions & 6 deletions

```diff
@@ -28,6 +28,7 @@ def __init__(
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
+        memory_efficient: bool = False,
     ) -> None:
         r"""
         Args:
@@ -62,10 +63,18 @@ def __init__(
                         is set to True, final sensitivity scores are being multiplied by
                         layer activations for inputs.

+            memory_efficient (bool, optional): If True, offloads intermediate
+                        activations and gradients to CPU during computation and
+                        uses in-place multiplication to reduce peak GPU memory
+                        usage. Useful when attributing across multiple target
+                        layers simultaneously.
+                        Default: False
+
         """
         LayerAttribution.__init__(self, forward_func, layer, device_ids)
         GradientAttribution.__init__(self, forward_func)
         self._multiply_by_inputs = multiply_by_inputs
+        self._memory_efficient = memory_efficient

     @property
     def multiplies_by_inputs(self) -> bool:
@@ -181,6 +190,7 @@ def attribute(
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_layer_input,
             grad_kwargs=grad_kwargs,
+            offload_to_cpu=self._memory_efficient,
         )
         if isinstance(self.layer, Module):
             return _format_output(
@@ -191,13 +201,17 @@ def attribute(
                 ),
             )
         else:
-            return [
-                _format_output(
-                    len(layer_evals[i]) > 1,
-                    self.multiply_gradient_acts(layer_gradients[i], layer_evals[i]),
+            results = []
+            for i in range(len(self.layer)):
+                grads_i = layer_gradients[i]
+                evals_i = layer_evals[i]
+                results.append(
+                    _format_output(
+                        len(evals_i) > 1,
+                        self.multiply_gradient_acts(grads_i, evals_i),
+                    )
                 )
-                for i in range(len(self.layer))
-            ]
+            return results

     def multiply_gradient_acts(
         self, gradients: Tuple[Tensor, ...], evals: Tuple[Tensor, ...]
```
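The gradient-times-activation computation that `multiply_gradient_acts` performs can be reproduced end to end in plain PyTorch with a forward hook, without the Captum wrapper. This is a hedged sketch on a toy model; the hook-based capture here is illustrative and not Captum's internal mechanism:

```python
import torch
import torch.nn as nn

# Toy model; attribute with respect to the first layer's output.
model = nn.Sequential(nn.Linear(10, 6), nn.ReLU(), nn.Linear(6, 2))

saved = {}
handle = model[0].register_forward_hook(
    lambda mod, inp, out: saved.__setitem__("act", out)
)
x = torch.randn(4, 10)
out = model(x)
handle.remove()

# Gradient of the target output w.r.t. the captured intermediate activation,
# then the elementwise product that gives the attribution.
act = saved["act"]
(grad,) = torch.autograd.grad(out[:, 0].sum(), act)
attribution = grad * act.detach()
# With memory_efficient=True, grad and act would have been offloaded to CPU
# before this multiplication; the result is identical, only peak GPU memory
# during the computation differs.
print(attribution.shape)  # torch.Size([4, 6])
```

This mirrors what `attribute` does per layer in the list branch above: one gradient tuple and one evaluation tuple per target layer, multiplied elementwise.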
