
Commit a086f0d

Update model-level offload hooks

Signed-off-by: yuanheng <jonathan.zhaoyh@gmail.com>
1 parent: 26b8a6b

2 files changed: +123 −74 lines

vllm_omni/diffusion/offloader/__init__.py
7 additions & 1 deletion

@@ -9,7 +9,11 @@
 
 from .base import OffloadBackend, OffloadConfig, OffloadStrategy
 from .layerwise_backend import LayerWiseOffloadBackend
-from .sequential_backend import ModelLevelOffloadBackend
+from .sequential_backend import (
+    ModelLevelOffloadBackend,
+    apply_sequential_offload,
+    remove_sequential_offload,
+)
 
 logger = init_logger(__name__)
 
@@ -19,6 +23,8 @@
     "OffloadStrategy",
    "LayerWiseOffloadBackend",
     "ModelLevelOffloadBackend",
+    "apply_sequential_offload",
+    "remove_sequential_offload",
     "get_offload_backend",
 ]
 
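With these re-exports in place, callers can drive model-level offloading directly from the package namespace. A minimal sketch of that usage; the nn.Linear stand-ins below substitute for a real pipeline's transformer, text encoder, and VAE, which this commit only references through its docstring example:

    import torch
    from torch import nn

    from vllm_omni.diffusion.offloader import (
        apply_sequential_offload,
        remove_sequential_offload,
    )

    # Stand-ins for real pipeline components; any nn.Module works.
    transformer = nn.Linear(8, 8)   # plays the role of the DiT
    text_encoder = nn.Linear(8, 8)  # plays the role of an encoder
    vae = nn.Linear(8, 8)           # plays the role of the VAE

    # Register mutual-exclusion hooks: encoders leave the GPU before the DiT
    # runs, and the DiT leaves the GPU before an encoder runs.
    apply_sequential_offload(
        dit_modules=[transformer],
        encoder_modules=[text_encoder, vae],
        device=torch.device("cuda:0"),
    )

    # ... run encoders, then the DiT; each pre_forward swaps the other group to CPU ...

    remove_sequential_offload([transformer, text_encoder, vae])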

vllm_omni/diffusion/offloader/sequential_backend.py
116 additions & 73 deletions

@@ -5,6 +5,7 @@
 from torch import nn
 from vllm.logger import init_logger
 
+from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
 from vllm_omni.platforms import current_omni_platform
 
 from .base import OffloadBackend, OffloadConfig
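The new import brings in HookRegistry and ModelHook from vllm_omni.diffusion.hooks, a module that is not part of this commit. The sketch below is only an assumption about the minimal interface the rest of the diff relies on (get_or_create, register_hook, remove_hook, a _hook_registry attribute, and a pre_forward callback that can rewrite forward arguments); the actual implementation in vllm_omni may differ.

    from torch import nn


    class ModelHook:
        """Base hook; pre_forward may rewrite the (args, kwargs) passed to forward."""

        def pre_forward(self, module: nn.Module, *args, **kwargs) -> tuple[tuple, dict]:
            return args, kwargs


    class HookRegistry:
        """Per-module hook container stored on `module._hook_registry`."""

        def __init__(self, module: nn.Module):
            self._hooks: dict[str, ModelHook] = {}
            self._orig_forward = module.forward

            def wrapped_forward(*args, **kwargs):
                # Run every registered hook's pre_forward, then the original forward.
                for hook in self._hooks.values():
                    args, kwargs = hook.pre_forward(module, *args, **kwargs)
                return self._orig_forward(*args, **kwargs)

            module.forward = wrapped_forward

        @classmethod
        def get_or_create(cls, module: nn.Module) -> "HookRegistry":
            registry = getattr(module, "_hook_registry", None)
            if registry is None:
                registry = cls(module)
                module._hook_registry = registry
            return registry

        def register_hook(self, name: str, hook: ModelHook) -> None:
            self._hooks[name] = hook

        def remove_hook(self, name: str) -> None:
            self._hooks.pop(name, None)

Under this reading, creating a registry wraps module.forward once, and every registered hook's pre_forward runs before the original forward.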
@@ -13,45 +14,40 @@
 logger = init_logger(__name__)
 
 
-class SequentialOffloader:
-    """Sequential offloader: DiT and encoders take turns on GPU.
+class SequentialOffloadHook(ModelHook):
+    """Hook for sequential offloading with mutual exclusion.
 
-    Uses PyTorch's forward pre-hooks to automatically swap models:
-    - Before encoder runs: move DiT modules to CPU, move encoder to GPU
-    - Before DiT runs: move encoders to CPU, move active DiT to GPU
+    When a module's forward is called, this hook offloads target modules to CPU
+    and loads the current module to GPU.
     """
 
+    _HOOK_NAME = "sequential_offload"
+
     def __init__(
         self,
-        dits: list[nn.Module],
-        encoders: list[nn.Module],
+        offload_targets: list[nn.Module],
         device: torch.device,
         pin_memory: bool = True,
     ):
-        assert all(isinstance(m, nn.Module) for m in dits), "All dits must be nn.Module"
-        assert all(isinstance(m, nn.Module) for m in encoders), "All encoders must be nn.Module"
-        self.dits = dits
-        self.encoders = encoders
+        # Modules to offload to CPU before this module runs
+        self.offload_targets = offload_targets
         self.device = device
         self.pin_memory = pin_memory
-        self._handles: list = []
 
     def _to_cpu(self, module: nn.Module) -> None:
-        """Move module to CPU with optional memory pinning."""
-        # Skip if already on CPU
+        """Move module to CPU."""
         try:
             param = next(module.parameters())
-            if param.device.type == "cpu":
-                return
         except StopIteration:
             return
 
         previous_device = param.device
-        module.to("cpu", non_blocking=True)
+        # Skip if already on CPU
+        if previous_device.type == "cpu":
+            return
 
-        # Release allocator blocks when tensors leave the GPU.
-        if previous_device.type != "cpu":
-            torch.cuda.empty_cache()
+        module.to("cpu", non_blocking=True)
+        torch.cuda.empty_cache()
 
         if self.pin_memory:
             for p in module.parameters():
@@ -60,67 +56,109 @@ def _to_cpu(self, module: nn.Module) -> None:
 
     def _to_gpu(self, module: nn.Module) -> None:
         """Move module to GPU."""
-        # Skip if already on target device
         try:
+            # Skip if already on target device
             if next(module.parameters()).device == self.device:
                 return
         except StopIteration:
             return
 
         module.to(self.device, non_blocking=True)
 
-    def _dit_pre_hook(self, module: nn.Module, args: tuple) -> None:
-        """Before DiT forward: offload encoders, load DiT."""
-        for enc in self.encoders:
-            self._to_cpu(enc)
+    def pre_forward(self, module: nn.Module, *args, **kwargs) -> tuple[tuple, dict]:
+        # Offload target modules to CPU
+        for target in self.offload_targets:
+            self._to_cpu(target)
+
+        # Load current module to GPU
         self._to_gpu(module)
 
         current_omni_platform.synchronize()
 
-        logger.debug("Swapped: encoders -> CPU, DiT -> GPU")
-
-    def _encoder_pre_hook(self, module: nn.Module, args: tuple) -> None:
-        """Before encoder forward: offload DiT, load encoder."""
-        for dit_mod in self.dits:
-            self._to_cpu(dit_mod)
-        self._to_gpu(module)
+        logger.debug(
+            "Swapped: %s -> CPU, %s -> GPU",
+            [t.__class__.__name__ for t in self.offload_targets],
+            module.__class__.__name__,
+        )
 
-        current_omni_platform.synchronize()
+        return args, kwargs
+
+
+def apply_sequential_offload(
+    dit_modules: list[nn.Module],
+    encoder_modules: list[nn.Module],
+    device: torch.device,
+    pin_memory: bool = True,
+) -> None:
+    """Apply sequential offloading hooks to DiT and encoder modules.
+
+    Registers hooks on modules to implement mutual-exclusion GPU allocation.
+    - Before DiT runs, encoders are offloaded to CPU.
+    - Before encoders run, DiT is offloaded to CPU.
+
+    Args:
+        dit_modules: DiT/transformer modules to register hooks on
+        encoder_modules: Encoder modules to register hooks on
+        device: Target GPU device for loading
+        pin_memory: Whether to pin CPU memory for faster transfers
+
+    Example:
+        >>> apply_sequential_offload(
+        ...     dit_modules=[pipeline.transformer],
+        ...     encoder_modules=[pipeline.text_encoder, pipeline.vae],
+        ...     device=torch.device("cuda:0"),
+        ... )
+        >>> # Modules of pipeline now automatically swap between CPU and GPU
+    """
+    # Register hooks on DiT modules (offload encoders when DiT runs)
+    for dit_mod in dit_modules:
+        registry = HookRegistry.get_or_create(dit_mod)
+        hook = SequentialOffloadHook(
+            offload_targets=encoder_modules,
+            device=device,
+            pin_memory=pin_memory,
+        )
+        registry.register_hook(SequentialOffloadHook._HOOK_NAME, hook)
+        logger.debug("Registered offload hook for %s", dit_mod.__class__.__name__)
+
+    # Register hooks on encoders (offload DiTs when encoder runs)
+    for enc in encoder_modules:
+        registry = HookRegistry.get_or_create(enc)
+        hook = SequentialOffloadHook(
+            offload_targets=dit_modules,
+            device=device,
+            pin_memory=pin_memory,
+        )
+        registry.register_hook(SequentialOffloadHook._HOOK_NAME, hook)
+        logger.debug("Registered offload hook for %s", enc.__class__.__name__)
 
-        logger.debug("Swapped: DiT -> CPU, encoder -> GPU")
 
-    def register(self) -> None:
-        """Register forward pre-hooks on DiT and encoders."""
-        # Hook on each DiT-like module
-        for dit_mod in self.dits:
-            h = dit_mod.register_forward_pre_hook(self._dit_pre_hook)
-            self._handles.append(h)
-            logger.debug("Registered offload hook for %s", dit_mod.__class__.__name__)
+def remove_sequential_offload(modules: list[nn.Module]) -> None:
+    """Remove sequential offloading hooks from modules.
 
-        # Hook on each encoder
-        for enc in self.encoders:
-            h = enc.register_forward_pre_hook(self._encoder_pre_hook)
-            self._handles.append(h)
-            logger.debug("Registered offload hook for %s", enc.__class__.__name__)
+    Args:
+        modules: Modules to remove hooks from
 
-    def remove(self) -> None:
-        """Remove all hooks."""
-        for h in self._handles:
-            h.remove()
-        self._handles = []
+    Example:
+        >>> all_modules = [*dit_modules, *encoder_modules]
+        >>> remove_sequential_offload(all_modules)
+    """
+    for module in modules:
+        registry: HookRegistry | None = getattr(module, "_hook_registry", None)
+        if registry is not None:
+            registry.remove_hook(SequentialOffloadHook._HOOK_NAME)
+            logger.debug("Removed offload hook from %s", module.__class__.__name__)
 
 
 class ModelLevelOffloadBackend(OffloadBackend):
     """Model-level (sequential) offloading backend.
 
-    Implements mutual-exclusion offloading between DiT transformers and encoders.
-    When encoders run, DiT is on CPU. When DiT runs, encoders are on CPU.
-    This allows running large models that don't fit entirely on GPU.
+    Uses SequentialOffloadHook registered via HookRegistry for automatic module swapping.
     """
 
     def __init__(self, config: OffloadConfig, device: torch.device):
         super().__init__(config, device)
-        self._sequential_offloader: SequentialOffloader | None = None
+        self._offload_modules: list[nn.Module] = []  # Track modules with hooks
 
     def enable(self, pipeline: nn.Module) -> None:
         if self.enabled:
@@ -147,22 +185,28 @@ def enable(self, pipeline: nn.Module) -> None:
             logger.debug("Failed to move VAE to GPU: %s", exc)
 
         # Initial state: keep DiT modules on CPU (encoders typically run first)
-        for dit_mod in modules.dits:
-            dit_mod.to("cpu")
-
-        torch.cuda.empty_cache()
-
-        if self.config.pin_cpu_memory:
-            for dit_mod in modules.dits:
-                for p in dit_mod.parameters():
-                    if p.data.device.type == "cpu" and not p.data.is_pinned():
-                        p.data = p.data.pin_memory()
-
-        # Register sequential offload hooks
-        self._sequential_offloader = SequentialOffloader(
-            modules.dits, modules.encoders, self.device, self.config.pin_cpu_memory
+        # TODO: This part seems to be unnecessary, remove it after testing
+        # for dit_mod in modules.dits:
+        #     dit_mod.to("cpu")
+
+        # torch.cuda.empty_cache()
+
+        # if self.config.pin_cpu_memory:
+        #     for dit_mod in modules.dits:
+        #         for p in dit_mod.parameters():
+        #             if p.data.device.type == "cpu" and not p.data.is_pinned():
+        #                 p.data = p.data.pin_memory()
+
+        # Apply sequential offloading hooks
+        apply_sequential_offload(
+            dit_modules=modules.dits,
+            encoder_modules=modules.encoders,
+            device=self.device,
+            pin_memory=self.config.pin_cpu_memory,
         )
-        self._sequential_offloader.register()
+
+        # Track modules for cleanup
+        self._offload_modules = [*modules.dits, *modules.encoders]
 
         self.enabled = True
 
@@ -176,9 +220,8 @@ def disable(self) -> None:
         if not self.enabled:
             return
 
-        if self._sequential_offloader is not None:
-            self._sequential_offloader.remove()
-            self._sequential_offloader = None
+        remove_sequential_offload(self._offload_modules)
 
+        self._offload_modules.clear()
         self.enabled = False
         logger.info("Model-level offloading disabled")
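Taken together, enable() and disable() now delegate entirely to the module-level helpers. A rough end-to-end sketch of the backend flow; the OffloadConfig constructor arguments beyond pin_cpu_memory and the pipeline object are assumptions, since neither is shown in this commit:

    import torch

    from vllm_omni.diffusion.offloader import ModelLevelOffloadBackend, OffloadConfig

    # Assumed wiring: OffloadConfig fields other than pin_cpu_memory are not shown
    # in this commit, and `pipeline` stands for an already-built diffusion pipeline.
    config = OffloadConfig(pin_cpu_memory=True)
    backend = ModelLevelOffloadBackend(config, device=torch.device("cuda:0"))

    backend.enable(pipeline)  # registers SequentialOffloadHook on DiT and encoder modules
    images = pipeline(prompt="a photo of a cat")  # module groups swap on and off the GPU as they run
    backend.disable()  # removes the hooks and clears the tracked module list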
