-
Notifications
You must be signed in to change notification settings - Fork 131
Expand file tree
/
Copy pathmerge_lora_for_helios.py
More file actions
55 lines (46 loc) · 1.65 KB
/
merge_lora_for_helios.py
File metadata and controls
55 lines (46 loc) · 1.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Merge a trained LoRA checkpoint into a Helios transformer and save it.

Steps: load the base transformer, wrap it in the Helios pipeline, attach
the post-train LoRA, load the extra partial (non-LoRA) weights, fuse the
LoRA deltas into the base weights, and save the fused transformer as a
standalone checkpoint.
"""
import sys
from argparse import Namespace

sys.path.append("../")  # make the sibling `helios` package importable

from helios.modules.transformer_helios import HeliosTransformer3DModel
from helios.pipelines.pipeline_helios import HeliosPipeline
from helios.utils.utils_base import load_extra_components

# Architecture flags passed to the base transformer; presumably these must
# match the configuration the LoRA was trained against — TODO confirm.
transformer_additional_kwargs = {
    "has_multi_term_memory_patch": True,
    "zero_history_timestep": True,
    "guidance_cross_attn": True,
    "restrict_self_attn": False,
    "is_train_restrict_lora": False,
    "restrict_lora": False,
    "restrict_lora_rank": 128,
}

transformer = HeliosTransformer3DModel.from_pretrained(
    "1_formal_ckpts/ablation_stage3_2_mid-train_v4_e2500-ema",
    subfolder="transformer",
    transformer_additional_kwargs=transformer_additional_kwargs,
)

pipe = HeliosPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    transformer=transformer,
)

# Attach the post-train EMA LoRA at full strength.
pipe.load_lora_weights(
    "ablation_stage3_3_post-train-emergency_only-gan/checkpoint-2000/model_ema/pytorch_lora_weights.safetensors",
    adapter_name="default",
)
pipe.set_adapters(["default"], adapter_weights=[1.0])

# Minimal args object consumed by load_extra_components.
# NOTE(review): the original guarded this with `hasattr(args, "training_config")`
# on a freshly constructed Namespace, which is always False-attribute and thus
# the guard was always taken — replaced by a direct assignment.
args = Namespace()
args.training_config = Namespace()
args.training_config.is_enable_stage1 = True
args.training_config.restrict_self_attn = True
args.training_config.is_amplify_history = True
args.training_config.is_use_gan = True

# Load the partial transformer weights saved alongside the LoRA checkpoint
# (weights outside the LoRA adapter — semantics defined by load_extra_components).
load_extra_components(
    args,
    transformer,
    "ablation_stage3_3_post-train-emergency_only-gan/checkpoint-2000/model_ema/transformer_partial.pth",
)

# Bake the LoRA deltas into the base weights, then drop the adapter state so
# the saved checkpoint is a plain transformer with no PEFT bookkeeping.
pipe.fuse_lora()
pipe.unload_lora_weights()
pipe.transformer.save_pretrained(
    "1_formal_ckpts/ablation_stage3_3_post-train-emergency_only-gan_e2000-ema/transformer"
)