
Commit 2d9a75e

Authored and committed by Yexiong Lin

Add the support for fp8 t5
1 parent 3e6df49 commit 2d9a75e

7 files changed: +63 -17 lines

README.md (+6)

@@ -137,6 +137,12 @@ python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B
 
 You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
 
+Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
+
+```
+t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
+```
+
 > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 based on the performance.
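For orientation, this is roughly what the T5 block of `wan/configs/wan_t2v_1_3B.py` looks like after that edit. A sketch, assuming `t2v_1_3B` is the `EasyDict` config defined near the top of that file; the original bf16 entry is kept as a comment for easy rollback:

```python
# t5
# t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'  # original bf16 weights
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'  # FP8 T5 encoder weights
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
```

Note that the pipelines below key the FP8 loading path off this exact filename, so the downloaded checkpoint must keep the name `umt5-xxl-enc-fp8_e4m3fn.safetensors` inside the checkpoint directory.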

wan/configs/wan_i2v_14B.py (+1)

@@ -10,6 +10,7 @@
 i2v_14B.update(wan_shared_cfg)
 
 i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # clip

wan/configs/wan_t2v_14B.py (+1)

@@ -10,6 +10,7 @@
 
 # t5
 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae

wan/configs/wan_t2v_1_3B.py (+1)

@@ -10,6 +10,7 @@
 
 # t5
 t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae

wan/image2video.py (+11 -4)

@@ -80,13 +80,18 @@ def __init__(
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
             device=torch.device('cpu'),
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
             shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization,
         )
 
         self.vae_stride = config.vae_stride

@@ -266,13 +271,15 @@ def generate(self,
         # preprocess
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
             context_null = [t.to(self.device) for t in context_null]
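Why the encoder calls are now wrapped in `torch.autocast`: the FP8 checkpoint only changes how the weights are stored. Under a bfloat16 autocast context, eligible ops cast their floating-point arguments, including fp8-stored weights, to bf16 at call time, so the matmuls still run in bf16. A tiny stand-alone illustration of that behaviour, assuming a recent PyTorch build with `torch.float8_e4m3fn`; the layer and shapes are made up:

```python
import torch
import torch.nn as nn

# Keep a Linear layer's weight in fp8 storage, then run it under bf16 autocast.
lin = nn.Linear(8, 8, bias=False).requires_grad_(False)
lin.weight.data = lin.weight.data.to(torch.float8_e4m3fn)  # storage dtype only

x = torch.randn(2, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
    y = lin(x)  # the fp8 weight is upcast to bf16 for the matmul

print(y.dtype)  # torch.bfloat16
```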

wan/modules/t5.py (+30 -8)

@@ -9,6 +9,10 @@
 
 from .tokenizers import HuggingfaceTokenizer
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 __all__ = [
     'T5Model',
     'T5Encoder',

@@ -442,7 +446,7 @@ def _t5(name,
     model = model_cls(**kwargs)
 
     # set device
-    model = model.to(dtype=dtype, device=device)
+    # model = model.to(dtype=dtype, device=device)
 
     # init tokenizer
     if return_tokenizer:

@@ -479,21 +483,39 @@ def __init__(
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
+        quantization="disabled",
     ):
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
         self.checkpoint_path = checkpoint_path
         self.tokenizer_path = tokenizer_path
 
-        # init model
-        model = umt5_xxl(
-            encoder_only=True,
-            return_tokenizer=False,
-            dtype=dtype,
-            device=device).eval().requires_grad_(False)
+
         logging.info(f'loading {checkpoint_path}')
-        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        if quantization == "disabled":
+            # init model
+            model = umt5_xxl(
+                encoder_only=True,
+                return_tokenizer=False,
+                dtype=dtype,
+                device=device).eval().requires_grad_(False)
+            model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        elif quantization == "fp8_e4m3fn":
+            with init_empty_weights():
+                model = umt5_xxl(
+                    encoder_only=True,
+                    return_tokenizer=False,
+                    dtype=dtype,
+                    device=device).eval().requires_grad_(False)
+            cast_dtype = torch.float8_e4m3fn
+            state_dict = load_file(checkpoint_path, device="cpu")
+            params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
+            for name, param in model.named_parameters():
+                dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
+                set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
+            del state_dict
+
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
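With the new `quantization` argument, the FP8 path can also be exercised directly. A minimal sketch, not taken from the repo: `text_len=512` and bf16 as the compute dtype are assumptions based on the shared Wan config, and the paths mirror the `os.path.join` pattern the pipelines use:

```python
import os
import torch
from wan.modules.t5 import T5EncoderModel

checkpoint_dir = './Wan2.1-T2V-1.3B'  # assumed local folder holding the FP8 T5 weights

text_encoder = T5EncoderModel(
    text_len=512,  # assumption: matches the shared Wan config
    dtype=torch.bfloat16,
    device=torch.device('cpu'),
    checkpoint_path=os.path.join(checkpoint_dir, 'umt5-xxl-enc-fp8_e4m3fn.safetensors'),
    tokenizer_path=os.path.join(checkpoint_dir, 'google/umt5-xxl'),
    quantization="fp8_e4m3fn",  # selects the init_empty_weights + fp8 cast path above
)

# The weights are stored in fp8, so run the forward pass under bf16 autocast,
# exactly as the pipelines do in image2video.py and text2video.py.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
    context = text_encoder(['a cat playing the piano'], torch.device('cpu'))
```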

wan/text2video.py (+13 -5)

@@ -74,13 +74,19 @@ def __init__(
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
+
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
             device=torch.device('cpu'),
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
-            shard_fn=shard_fn if t5_fsdp else None)
+            shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization)
 
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size

@@ -221,13 +227,15 @@ def generate(self,
 
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
             context_null = [t.to(self.device) for t in context_null]

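Putting the pieces together: once the config points at the FP8 T5 file, the pipelines pick the fp8 loading path from the checkpoint filename alone, so no extra flag is needed for the text encoder. A hedged end-to-end sketch; `WAN_CONFIGS`, the `WanT2V` class, and the `generate` keyword follow the upstream `generate.py` flow, while the prompt and size values are illustrative:

```python
import wan
from wan.configs import WAN_CONFIGS

# Assumes the config edit from the README and that both FP8 checkpoints sit in
# ./Wan2.1-T2V-1.3B. The filename check in WanT2V.__init__ then passes
# quantization="fp8_e4m3fn" to the T5 encoder automatically.
cfg = WAN_CONFIGS['t2v-1.3B']
pipe = wan.WanT2V(config=cfg, checkpoint_dir='./Wan2.1-T2V-1.3B', device_id=0)
video = pipe.generate('a corgi surfing a small wave at sunset', size=(832, 480))
```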