 from torch import nn
 from vllm.logger import init_logger
 
+from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
 from vllm_omni.platforms import current_omni_platform
 
 from .base import OffloadBackend, OffloadConfig
 logger = init_logger(__name__)
 
 
-class SequentialOffloader:
-    """Sequential offloader: DiT and encoders take turns on GPU.
+class SequentialOffloadHook(ModelHook):
+    """Hook for sequential offloading with mutual exclusion.
 
-    Uses PyTorch's forward pre-hooks to automatically swap models:
-    - Before encoder runs: move DiT modules to CPU, move encoder to GPU
-    - Before DiT runs: move encoders to CPU, move active DiT to GPU
+    When a module's forward is called, this hook offloads target modules to CPU
+    and loads the current module to GPU.
     """
 
+    _HOOK_NAME = "sequential_offload"
+
     def __init__(
         self,
-        dits: list[nn.Module],
-        encoders: list[nn.Module],
+        offload_targets: list[nn.Module],
         device: torch.device,
         pin_memory: bool = True,
     ):
-        assert all(isinstance(m, nn.Module) for m in dits), "All dits must be nn.Module"
-        assert all(isinstance(m, nn.Module) for m in encoders), "All encoders must be nn.Module"
-        self.dits = dits
-        self.encoders = encoders
+        # Modules to offload to CPU before this module runs
+        self.offload_targets = offload_targets
         self.device = device
         self.pin_memory = pin_memory
-        self._handles: list = []
 
     def _to_cpu(self, module: nn.Module) -> None:
-        """Move module to CPU with optional memory pinning."""
-        # Skip if already on CPU
+        """Move module to CPU."""
         try:
             param = next(module.parameters())
-            if param.device.type == "cpu":
-                return
         except StopIteration:
             return
 
         previous_device = param.device
-        module.to("cpu", non_blocking=True)
+        # Skip if already on CPU
+        if previous_device.type == "cpu":
+            return
 
-        # Release allocator blocks when tensors leave the GPU.
-        if previous_device.type != "cpu":
-            torch.cuda.empty_cache()
+        module.to("cpu", non_blocking=True)
+        torch.cuda.empty_cache()
 
         if self.pin_memory:
             for p in module.parameters():
@@ -60,67 +56,109 @@ def _to_cpu(self, module: nn.Module) -> None:
 
     def _to_gpu(self, module: nn.Module) -> None:
         """Move module to GPU."""
-        # Skip if already on target device
         try:
+            # Skip if already on target device
             if next(module.parameters()).device == self.device:
                 return
         except StopIteration:
             return
 
         module.to(self.device, non_blocking=True)
 
-    def _dit_pre_hook(self, module: nn.Module, args: tuple) -> None:
-        """Before DiT forward: offload encoders, load DiT."""
-        for enc in self.encoders:
-            self._to_cpu(enc)
+    def pre_forward(self, module: nn.Module, *args, **kwargs) -> tuple[tuple, dict]:
+        # Offload target modules to CPU
+        for target in self.offload_targets:
+            self._to_cpu(target)
+
+        # Load current module to GPU
         self._to_gpu(module)
 
         current_omni_platform.synchronize()
 
-        logger.debug("Swapped: encoders -> CPU, DiT -> GPU")
-
-    def _encoder_pre_hook(self, module: nn.Module, args: tuple) -> None:
-        """Before encoder forward: offload DiT, load encoder."""
-        for dit_mod in self.dits:
-            self._to_cpu(dit_mod)
-        self._to_gpu(module)
+        logger.debug(
+            "Swapped: %s -> CPU, %s -> GPU",
+            [t.__class__.__name__ for t in self.offload_targets],
+            module.__class__.__name__,
+        )
 
-        current_omni_platform.synchronize()
+        return args, kwargs
+
+
+def apply_sequential_offload(
+    dit_modules: list[nn.Module],
+    encoder_modules: list[nn.Module],
+    device: torch.device,
+    pin_memory: bool = True,
+) -> None:
+    """Apply sequential offloading hooks to DiT and encoder modules.
+
+    Registers hooks on modules to implement mutual-exclusion GPU allocation:
+    - Before DiT runs, encoders are offloaded to CPU.
+    - Before encoders run, DiT is offloaded to CPU.
+
+    Args:
+        dit_modules: DiT/transformer modules to register hooks on
+        encoder_modules: Encoder modules to register hooks on
+        device: Target GPU device for loading
+        pin_memory: Whether to pin CPU memory for faster transfers
+
+    Example:
+        >>> apply_sequential_offload(
+        ...     dit_modules=[pipeline.transformer],
+        ...     encoder_modules=[pipeline.text_encoder, pipeline.vae],
+        ...     device=torch.device("cuda:0"),
+        ... )
+        >>> # The pipeline's modules now swap between CPU and GPU automatically
112+ """
113+ # Register hooks on DiT modules (offload encoders when DiT runs)
114+ for dit_mod in dit_modules :
115+ registry = HookRegistry .get_or_create (dit_mod )
116+ hook = SequentialOffloadHook (
117+ offload_targets = encoder_modules ,
118+ device = device ,
119+ pin_memory = pin_memory ,
120+ )
121+ registry .register_hook (SequentialOffloadHook ._HOOK_NAME , hook )
122+ logger .debug ("Registered offload hook for %s" , dit_mod .__class__ .__name__ )
123+
+    # Register hooks on encoders (offload DiTs when an encoder runs)
+    for enc in encoder_modules:
+        registry = HookRegistry.get_or_create(enc)
+        hook = SequentialOffloadHook(
+            offload_targets=dit_modules,
+            device=device,
+            pin_memory=pin_memory,
+        )
+        registry.register_hook(SequentialOffloadHook._HOOK_NAME, hook)
+        logger.debug("Registered offload hook for %s", enc.__class__.__name__)
 
-        logger.debug("Swapped: DiT -> CPU, encoder -> GPU")
 
-    def register(self) -> None:
-        """Register forward pre-hooks on DiT and encoders."""
-        # Hook on each DiT-like module
-        for dit_mod in self.dits:
-            h = dit_mod.register_forward_pre_hook(self._dit_pre_hook)
-            self._handles.append(h)
-            logger.debug("Registered offload hook for %s", dit_mod.__class__.__name__)
+def remove_sequential_offload(modules: list[nn.Module]) -> None:
+    """Remove sequential offloading hooks from modules.
 
-        # Hook on each encoder
-        for enc in self.encoders:
-            h = enc.register_forward_pre_hook(self._encoder_pre_hook)
-            self._handles.append(h)
-            logger.debug("Registered offload hook for %s", enc.__class__.__name__)
+    Args:
+        modules: Modules to remove hooks from
 
-    def remove(self) -> None:
-        """Remove all hooks."""
-        for h in self._handles:
-            h.remove()
-        self._handles = []
+    Example:
+        >>> all_modules = [*dit_modules, *encoder_modules]
+        >>> remove_sequential_offload(all_modules)
+    """
+    for module in modules:
+        registry: HookRegistry | None = getattr(module, "_hook_registry", None)
+        if registry is not None:
+            registry.remove_hook(SequentialOffloadHook._HOOK_NAME)
+            logger.debug("Removed offload hook from %s", module.__class__.__name__)
 
 
 class ModelLevelOffloadBackend(OffloadBackend):
     """Model-level (sequential) offloading backend.
 
-    Implements mutual-exclusion offloading between DiT transformers and encoders.
-    When encoders run, DiT is on CPU. When DiT runs, encoders are on CPU.
-    This allows running large models that don't fit entirely on GPU.
+    Uses SequentialOffloadHook registered via HookRegistry for automatic module swapping.
     """
 
     def __init__(self, config: OffloadConfig, device: torch.device):
         super().__init__(config, device)
-        self._sequential_offloader: SequentialOffloader | None = None
+        self._offload_modules: list[nn.Module] = []  # Track modules with hooks
 
     def enable(self, pipeline: nn.Module) -> None:
         if self.enabled:
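
Before the enable()/disable() changes in the later hunks, note that the hook plumbing above relies on HookRegistry and ModelHook from vllm_omni.diffusion.hooks, which this diff does not show. The sketch below is a hypothetical reconstruction of that interface, inferred only from the calls made in this change (get_or_create, register_hook, remove_hook, the _hook_registry attribute, and pre_forward returning (args, kwargs)); the real module may well differ.

# Hypothetical sketch of the hooks interface assumed by this diff, inferred from
# the calls above. Not the actual vllm_omni.diffusion.hooks implementation.
from torch import nn


class ModelHook:
    def pre_forward(self, module: nn.Module, *args, **kwargs) -> tuple[tuple, dict]:
        # Called before module.forward; may rewrite args/kwargs.
        return args, kwargs


class HookRegistry:
    def __init__(self, module: nn.Module):
        self._hooks: dict[str, ModelHook] = {}
        # One PyTorch forward pre-hook per module dispatches to all registered hooks.
        self._handle = module.register_forward_pre_hook(self._pre_hook, with_kwargs=True)

    @classmethod
    def get_or_create(cls, module: nn.Module) -> "HookRegistry":
        # Stash the registry on the module so repeated calls reuse it.
        registry = getattr(module, "_hook_registry", None)
        if registry is None:
            registry = cls(module)
            module._hook_registry = registry
        return registry

    def register_hook(self, name: str, hook: ModelHook) -> None:
        self._hooks[name] = hook

    def remove_hook(self, name: str) -> None:
        self._hooks.pop(name, None)

    def _pre_hook(self, module: nn.Module, args: tuple, kwargs: dict):
        # Run each hook's pre_forward and propagate any arg/kwarg changes.
        for hook in self._hooks.values():
            args, kwargs = hook.pre_forward(module, *args, **kwargs)
        return args, kwargs

With a pre-hook registered via register_forward_pre_hook(with_kwargs=True), each wrapped module runs its hooks' pre_forward before its own forward, which is what lets SequentialOffloadHook swap modules just in time.
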
@@ -147,22 +185,28 @@ def enable(self, pipeline: nn.Module) -> None:
             logger.debug("Failed to move VAE to GPU: %s", exc)
 
         # Initial state: keep DiT modules on CPU (encoders typically run first)
-        for dit_mod in modules.dits:
-            dit_mod.to("cpu")
-
-        torch.cuda.empty_cache()
-
-        if self.config.pin_cpu_memory:
-            for dit_mod in modules.dits:
-                for p in dit_mod.parameters():
-                    if p.data.device.type == "cpu" and not p.data.is_pinned():
-                        p.data = p.data.pin_memory()
-
-        # Register sequential offload hooks
-        self._sequential_offloader = SequentialOffloader(
-            modules.dits, modules.encoders, self.device, self.config.pin_cpu_memory
+        # TODO: This block appears unnecessary; remove it after testing.
+        # for dit_mod in modules.dits:
+        #     dit_mod.to("cpu")
+
+        # torch.cuda.empty_cache()
+
+        # if self.config.pin_cpu_memory:
+        #     for dit_mod in modules.dits:
+        #         for p in dit_mod.parameters():
+        #             if p.data.device.type == "cpu" and not p.data.is_pinned():
+        #                 p.data = p.data.pin_memory()
+
+        # Apply sequential offloading hooks
+        apply_sequential_offload(
+            dit_modules=modules.dits,
+            encoder_modules=modules.encoders,
+            device=self.device,
+            pin_memory=self.config.pin_cpu_memory,
         )
-        self._sequential_offloader.register()
+
+        # Track modules for cleanup
+        self._offload_modules = [*modules.dits, *modules.encoders]
 
         self.enabled = True
 
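
As an aside on the pin_memory / pin_cpu_memory flag threaded through the code above: pinning matters because only page-locked host tensors allow the non_blocking transfers used in _to_cpu and _to_gpu to actually overlap with GPU work. A minimal standalone illustration (not part of this diff) of that behavior:

# Standalone illustration of why pinned CPU memory matters for non_blocking copies.
import torch

t_pageable = torch.randn(1024, 1024)              # ordinary pageable CPU tensor
t_pinned = torch.randn(1024, 1024).pin_memory()   # page-locked CPU tensor

# Both calls accept non_blocking=True, but only the pinned copy can truly be
# asynchronous; the pageable copy falls back to a synchronous transfer path.
a = t_pageable.to("cuda", non_blocking=True)
b = t_pinned.to("cuda", non_blocking=True)
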
@@ -176,9 +220,8 @@ def disable(self) -> None:
         if not self.enabled:
             return
 
-        if self._sequential_offloader is not None:
-            self._sequential_offloader.remove()
-            self._sequential_offloader = None
+        remove_sequential_offload(self._offload_modules)
 
+        self._offload_modules.clear()
         self.enabled = False
         logger.info("Model-level offloading disabled")
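
Putting the pieces together, a typical enable/disable cycle with the reworked backend might look like the sketch below. The pipeline object, its call signature, and the OffloadConfig construction are illustrative assumptions for the example, not something this diff defines.

# Illustrative usage sketch; pipeline attributes and OffloadConfig fields are assumed.
import torch

config = OffloadConfig(pin_cpu_memory=True)   # assumed constructor arguments
backend = ModelLevelOffloadBackend(config, device=torch.device("cuda:0"))

backend.enable(pipeline)   # registers SequentialOffloadHook on DiT and encoder modules
image = pipeline(prompt="a photo of a cat")   # modules swap CPU<->GPU as their forwards run
backend.disable()          # removes the hooks and clears the tracked module list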