-
Notifications
You must be signed in to change notification settings - Fork 283
Expand file tree
/
Copy pathconvert_gamecraft_vae.py
More file actions
95 lines (79 loc) · 2.92 KB
/
convert_gamecraft_vae.py
File metadata and controls
95 lines (79 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# SPDX-License-Identifier: Apache-2.0
"""
Convert GameCraft VAE weights to FastVideo safetensors format.
The official GameCraft VAE checkpoint (checkpoint-step-270000.ckpt or
pytorch_model.pt) stores its parameters under a "vae." key prefix. This script
extracts those parameters, strips the prefix, and writes them as safetensors
for FastVideo's GameCraftVAE.
Usage:
python scripts/checkpoint_conversion/convert_gamecraft_vae.py \
--input Hunyuan-GameCraft-1.0/weights/stdmodels/vae_3d/hyvae/checkpoint-step-270000.ckpt \
--output official_weights/hunyuan-gamecraft/vae
"""
from __future__ import annotations
import argparse
import json
import shutil
from pathlib import Path
import torch
from safetensors.torch import save_file
def _load_state_dict(input_path: Path) -> dict:
    """Load a checkpoint on CPU and unwrap a nested ``state_dict`` if present."""
    # weights_only=True refuses to unpickle arbitrary Python objects, so a
    # downloaded checkpoint cannot execute code while being loaded.
    ckpt = torch.load(input_path, map_location="cpu", weights_only=True)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        return ckpt["state_dict"]
    return ckpt


def _strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Return the entries whose key starts with *prefix*, with that prefix removed."""
    # Slice off the leading prefix only; str.replace() would also delete any
    # later occurrence of the prefix inside the key (e.g. "vae.a.vae.b").
    # safetensors rejects non-contiguous tensors, and .contiguous() is a no-op
    # for tensors that already are.
    return {
        key[len(prefix):]: tensor.contiguous()
        for key, tensor in state_dict.items()
        if key.startswith(prefix)
    }


def _write_config(config_src: Path, config_dst: Path) -> None:
    """Copy a JSON config, pointing ``_class_name`` at AutoencoderKLCausal3D."""
    config = json.loads(config_src.read_text(encoding="utf-8"))
    config["_class_name"] = "AutoencoderKLCausal3D"
    config_dst.write_text(json.dumps(config, indent=2), encoding="utf-8")


def convert_gamecraft_vae(
    input_path: Path,
    output_dir: Path,
    copy_config: bool = True,
) -> dict:
    """Convert official GameCraft VAE checkpoint to FastVideo format.

    Args:
        input_path: Checkpoint file (.ckpt or .pt). If it holds a
            ``"state_dict"`` entry, that nested dict is used.
        output_dir: Directory to write ``diffusion_pytorch_model.safetensors``
            (and optionally ``config.json``) into. Created if missing.
        copy_config: If True, copy ``config.json`` from the checkpoint's
            directory (when present) with ``_class_name`` rewritten.

    Returns:
        ``{"total": n}``, where ``n`` is the number of extracted parameters.

    Raises:
        ValueError: If the checkpoint contains no ``"vae."``-prefixed keys.
            This is raised before anything is written to ``output_dir``.
    """
    input_path = Path(input_path)
    output_dir = Path(output_dir)
    prefix = "vae."

    state_dict = _load_state_dict(input_path)
    vae_sd = _strip_prefix(state_dict, prefix)
    if not vae_sd:
        # Fail loudly instead of silently writing an empty weights file.
        raise ValueError(
            f"No parameters with prefix {prefix!r} found in {input_path}; "
            "is this a GameCraft VAE checkpoint?"
        )
    print(f"Extracted {len(vae_sd)} VAE parameters")

    output_dir.mkdir(parents=True, exist_ok=True)
    weights_path = output_dir / "diffusion_pytorch_model.safetensors"
    save_file(vae_sd, weights_path)
    print(f"Saved to {weights_path}")

    if copy_config:
        config_src = input_path.parent / "config.json"
        if config_src.exists():
            config_dst = output_dir / "config.json"
            _write_config(config_src, config_dst)
            print(f"Saved config to {config_dst}")
        else:
            print(f"Warning: config.json not found at {config_src}")

    return {"total": len(vae_sd)}
def main():
    """Parse command-line options and convert the requested VAE checkpoint.

    Raises:
        FileNotFoundError: If ``--input`` does not point to an existing path.
    """
    cli = argparse.ArgumentParser(
        description="Convert GameCraft VAE weights to FastVideo safetensors format."
    )
    # argparse also runs string defaults through ``type``, so both options
    # arrive as Path objects.
    cli.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Path to official VAE checkpoint (.ckpt or .pt)",
    )
    cli.add_argument(
        "--output",
        type=Path,
        default="official_weights/hunyuan-gamecraft/vae",
        help="Output directory for converted weights",
    )
    cli.add_argument(
        "--no-config",
        action="store_true",
        help="Don't copy config.json",
    )
    options = cli.parse_args()

    checkpoint = options.input
    if not checkpoint.exists():
        raise FileNotFoundError(f"Input not found: {checkpoint}")

    result = convert_gamecraft_vae(
        input_path=checkpoint,
        output_dir=options.output,
        copy_config=not options.no_config,
    )
    print(f"\nConversion complete! Total parameters: {result['total']}")


if __name__ == "__main__":
    main()