11
11
12
12
import torch
13
13
import torch .cuda .amp as amp
14
+ import torch .nn as nn
14
15
import torch .distributed as dist
15
16
from tqdm import tqdm
16
17
22
23
get_sampling_sigmas , retrieve_timesteps )
23
24
from .utils .fm_solvers_unipc import FlowUniPCMultistepScheduler
24
25
26
+ # def convert_linear_conv_to_fp8(module):
27
+ # for name, child in module.named_children():
28
+ # # 递归处理子模块
29
+ # convert_linear_conv_to_fp8(child)
30
+
31
+ # # 判断是否为 Linear 或 Conv 层
32
+ # if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
33
+ # # 转换权重
34
+ # if hasattr(child, 'weight') and child.weight is not None:
35
+ # # 保留 Parameter 类型,仅修改数据
36
+ # child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
37
+ # # 可选:转换偏置(根据需求开启)
38
+
25
39
26
40
class WanT2V :
27
41
@@ -81,7 +95,9 @@ def __init__(
81
95
device = self .device )
82
96
83
97
logging .info (f"Creating WanModel from { checkpoint_dir } " )
84
- self .model = WanModel .from_pretrained (checkpoint_dir )
98
+ self .model = WanModel .from_pretrained (checkpoint_dir ,torch_dtype = torch .float8_e4m3fn )
99
+ #self.model = WanModel.from_pretrained(checkpoint_dir )
100
+
85
101
self .model .eval ().requires_grad_ (False )
86
102
87
103
if use_usp :
@@ -102,7 +118,9 @@ def __init__(
102
118
dist .barrier ()
103
119
if dit_fsdp :
104
120
self .model = shard_fn (self .model )
121
+ # convert_linear_conv_to_fp8(self.model)
105
122
else :
123
+ # convert_linear_conv_to_fp8(self.model)
106
124
self .model .to (self .device )
107
125
108
126
self .sample_neg_prompt = config .sample_neg_prompt
@@ -152,6 +170,7 @@ def generate(self,
152
170
- W: Frame width from size)
153
171
"""
154
172
# preprocess
173
+ offload_model = False
155
174
F = frame_num
156
175
target_shape = (self .vae .model .z_dim , (F - 1 ) // self .vae_stride [0 ] + 1 ,
157
176
size [1 ] // self .vae_stride [1 ],
@@ -225,6 +244,16 @@ def noop_no_sync():
225
244
226
245
arg_c = {'context' : context , 'seq_len' : seq_len }
227
246
arg_null = {'context' : context_null , 'seq_len' : seq_len }
247
+
248
+ # import gc
249
+ # del self.text_encoder
250
+ # del self.vae
251
+ # gc.collect() # 立即触发垃圾回收
252
+ # torch.cuda.empty_cache() # 清空CUDA缓存
253
+ # torch.cuda.reset_peak_memory_stats()
254
+
255
+ # start_mem = torch.cuda.memory_allocated()
256
+ #print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
228
257
229
258
for _ , t in enumerate (tqdm (timesteps )):
230
259
latent_model_input = latents
@@ -248,6 +277,9 @@ def noop_no_sync():
248
277
return_dict = False ,
249
278
generator = seed_g )[0 ]
250
279
latents = [temp_x0 .squeeze (0 )]
280
+
281
+ # peak_mem_bytes = torch.cuda.max_memory_allocated()
282
+ # print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
251
283
252
284
x0 = latents
253
285
if offload_model :
0 commit comments