[question] how to get teacher model for Turbo-VAE-DC? #10

@eisneim

I'm trying to copy the DCAE class from https://github.com/hpcaitech/Open-Sora/tree/main into Turbo-VAED, using the checkpoint from https://huggingface.co/hpcai-tech/Open-Sora-v2-Video-DC-AE/tree/main:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
frame_count = 17 # Typical for a video VAE: factor * N + 1
dtype = torch.float16

model_config = "configs/Turbo-VAED-DC.json"
resume_from_checkpoint = "pretrained/Turbo-VAED-DC.pth"
teacher_model_name = "DCAE"
# downloaded from huggingface:  https://huggingface.co/hpcai-tech/Open-Sora-v2-Video-DC-AE/tree/main
dcae_state_dict = "path_to_/Turbo-VAED/pretrained/F32T4C128_AE.pt"
from dc_ae.models.dc_ae import dc_ae_f32, DCAE
dcae_config = dc_ae_f32("dc-ae-f32t4c128", dcae_state_dict)
teacher_model = DCAE(dcae_config).to(device, dtype=dtype)


model = AutoencoderKLTurboVAED.from_config(
    config=load_json_to_dict(model_config)
)
checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
model.decoder.load_state_dict(checkpoint, strict=False)
model = model.to(device, dtype=dtype)

When I run it, I get:

The temporal dimension must be a multiple of the downsampling factor....

I was wondering whether I'm using the correct DC-AE model.
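
For reference, the frame-count comment in the snippet and the error message describe two different conventions, and the arithmetic can be checked in isolation. The sketch below is only an illustration: the helper names are hypothetical, and the temporal factor of 4 is an assumption read off the "T4" in the checkpoint name `F32T4C128_AE.pt`, not something confirmed by either repository.

```python
def is_factor_n_plus_one(frames: int, factor: int) -> bool:
    """True when frames == factor * N + 1 for some integer N >= 0."""
    return frames >= 1 and (frames - 1) % factor == 0


def is_multiple_of_factor(frames: int, factor: int) -> bool:
    """True when frames is a positive multiple of factor."""
    return frames > 0 and frames % factor == 0


# Assumption: temporal downsampling factor taken from "T4" in the checkpoint name.
temporal_factor = 4
frame_count = 17  # value used in the snippet above

print(is_factor_n_plus_one(frame_count, temporal_factor))   # True  (17 = 4 * 4 + 1)
print(is_multiple_of_factor(frame_count, temporal_factor))  # False (17 % 4 == 1)
```

Under that assumption, 17 frames fits the "factor * N + 1" convention but is not a multiple of the factor, which would be consistent with the error text; 16 or 20 frames would pass the multiple-of-factor check. Whether the model that raises the error actually expects that convention is not established here.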
