Skip to content

Commit fd99127

Browse files
author
Tianhao
committed
给wan加入fp8量化
1 parent b58b7c5 commit fd99127

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

wan/image2video.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def __init__(
9696
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
9797

9898
logging.info(f"Creating WanModel from {checkpoint_dir}")
99-
self.model = WanModel.from_pretrained(checkpoint_dir)
99+
#self.model = WanModel.from_pretrained(checkpoint_dir)
100+
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
100101
self.model.eval().requires_grad_(False)
101102

102103
if t5_fsdp or dit_fsdp or use_usp:
@@ -174,6 +175,7 @@ def generate(self,
174175
- H: Frame height (from max_area)
175176
- W: Frame width from max_area)
176177
"""
178+
offload_model = False
177179
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
178180

179181
F = frame_num
@@ -295,6 +297,18 @@ def noop_no_sync():
295297
if offload_model:
296298
torch.cuda.empty_cache()
297299

300+
301+
import gc
302+
del self.text_encoder
303+
del self.clip
304+
del self.vae
305+
gc.collect() # 立即触发垃圾回收
306+
torch.cuda.empty_cache() # 清空CUDA缓存
307+
torch.cuda.reset_peak_memory_stats()
308+
309+
start_mem = torch.cuda.memory_allocated()
310+
print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
311+
298312
self.model.to(self.device)
299313
for _, t in enumerate(tqdm(timesteps)):
300314
latent_model_input = [latent.to(self.device)]
@@ -329,6 +343,9 @@ def noop_no_sync():
329343
x0 = [latent.to(self.device)]
330344
del latent_model_input, timestep
331345

346+
peak_mem_bytes = torch.cuda.max_memory_allocated()
347+
print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
348+
332349
if offload_model:
333350
self.model.cpu()
334351
torch.cuda.empty_cache()

wan/modules/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def forward(self, x):
8080
Args:
8181
x(Tensor): Shape [B, L, C]
8282
"""
83+
if self.weight.dtype != torch.bfloat16:
84+
self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
8385
return self._norm(x.float()).type_as(x) * self.weight
8486

8587
def _norm(self, x):
@@ -290,6 +292,9 @@ def forward(
290292
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
291293
"""
292294
assert e.dtype == torch.float32
295+
if self.modulation.dtype != torch.float16:
296+
# 如果不是 fp16,则转换为 fp16
297+
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
293298
with amp.autocast(dtype=torch.float32):
294299
e = (self.modulation + e).chunk(6, dim=1)
295300
assert e[0].dtype == torch.float32
@@ -337,6 +342,9 @@ def forward(self, x, e):
337342
e(Tensor): Shape [B, C]
338343
"""
339344
assert e.dtype == torch.float32
345+
if self.modulation.dtype != torch.float16:
346+
# 如果不是 fp16,则转换为 fp16
347+
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
340348
with amp.autocast(dtype=torch.float32):
341349
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
342350
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))

wan/text2video.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
import torch.cuda.amp as amp
14+
import torch.nn as nn
1415
import torch.distributed as dist
1516
from tqdm import tqdm
1617

@@ -22,6 +23,19 @@
2223
get_sampling_sigmas, retrieve_timesteps)
2324
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
2425

26+
# def convert_linear_conv_to_fp8(module):
27+
# for name, child in module.named_children():
28+
# # 递归处理子模块
29+
# convert_linear_conv_to_fp8(child)
30+
31+
# # 判断是否为 Linear 或 Conv 层
32+
# if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
33+
# # 转换权重
34+
# if hasattr(child, 'weight') and child.weight is not None:
35+
# # 保留 Parameter 类型,仅修改数据
36+
# child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
37+
# # 可选:转换偏置(根据需求开启)
38+
2539

2640
class WanT2V:
2741

@@ -81,7 +95,9 @@ def __init__(
8195
device=self.device)
8296

8397
logging.info(f"Creating WanModel from {checkpoint_dir}")
84-
self.model = WanModel.from_pretrained(checkpoint_dir)
98+
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
99+
#self.model = WanModel.from_pretrained(checkpoint_dir )
100+
85101
self.model.eval().requires_grad_(False)
86102

87103
if use_usp:
@@ -102,7 +118,9 @@ def __init__(
102118
dist.barrier()
103119
if dit_fsdp:
104120
self.model = shard_fn(self.model)
121+
# convert_linear_conv_to_fp8(self.model)
105122
else:
123+
# convert_linear_conv_to_fp8(self.model)
106124
self.model.to(self.device)
107125

108126
self.sample_neg_prompt = config.sample_neg_prompt
@@ -152,6 +170,7 @@ def generate(self,
152170
- W: Frame width from size)
153171
"""
154172
# preprocess
173+
offload_model = False
155174
F = frame_num
156175
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
157176
size[1] // self.vae_stride[1],
@@ -225,6 +244,16 @@ def noop_no_sync():
225244

226245
arg_c = {'context': context, 'seq_len': seq_len}
227246
arg_null = {'context': context_null, 'seq_len': seq_len}
247+
248+
# import gc
249+
# del self.text_encoder
250+
# del self.vae
251+
# gc.collect() # 立即触发垃圾回收
252+
# torch.cuda.empty_cache() # 清空CUDA缓存
253+
# torch.cuda.reset_peak_memory_stats()
254+
255+
# start_mem = torch.cuda.memory_allocated()
256+
#print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
228257

229258
for _, t in enumerate(tqdm(timesteps)):
230259
latent_model_input = latents
@@ -248,6 +277,9 @@ def noop_no_sync():
248277
return_dict=False,
249278
generator=seed_g)[0]
250279
latents = [temp_x0.squeeze(0)]
280+
281+
# peak_mem_bytes = torch.cuda.max_memory_allocated()
282+
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
251283

252284
x0 = latents
253285
if offload_model:

0 commit comments

Comments
 (0)