
🏃SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation

How to Run Inference

1. How to use SanaSprintPipeline with 🧨diffusers

Important: SanaSprintPipeline support in diffusers is still under construction (see the corresponding PR).

Install diffusers from source:

pip install git+https://github.com/huggingface/diffusers

Then run:

# Test SANA-Sprint with two-step inference
from diffusers import SanaSprintPipeline
import torch

pipeline = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16
)
pipeline.to("cuda:0")

prompt = "a tiny astronaut hatching from an egg on the moon"

image = pipeline(prompt=prompt, num_inference_steps=2).images[0]
image.save("test_out.png")
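SANA-Sprint can produce an image in `num_inference_steps=2` because it is distilled into a continuous-time consistency model: each step maps a noisy sample directly to a clean estimate, and extra steps re-noise that estimate and denoise it again. The sketch below illustrates this multistep consistency sampling loop in plain Python, assuming the TrigFlow parameterization x_t = cos(t)·x_0 + sin(t)·σ_d·z used by sCM-style distillation. `SIGMA_D`, the intermediate time `T_MID = 1.3`, and the "perfect" `toy_velocity_model` for a one-point data distribution are hypothetical stand-ins, not the SANA-Sprint or diffusers code.

```python
import math
import random

# Hypothetical constants for this toy sketch (not taken from the SANA-Sprint code).
SIGMA_D = 0.5            # assumed data standard deviation used by TrigFlow
T_MAX = math.pi / 2      # TrigFlow time runs over [0, pi/2]; t = pi/2 is pure noise
T_MID = 1.3              # assumed intermediate time for the second step


def noisy(x0, z, t):
    """TrigFlow forward process: x_t = cos(t) * x0 + sin(t) * sigma_d * z."""
    return [math.cos(t) * a + math.sin(t) * SIGMA_D * b for a, b in zip(x0, z)]


def toy_velocity_model(x_t, t, target):
    """Hypothetical 'perfect' network for a data distribution with one point, `target`.

    It recovers the noise z from x_t and returns the TrigFlow velocity target
    F = cos(t) * z - sin(t) * x0 / sigma_d. Input scaling by 1/sigma_d is omitted.
    """
    out = []
    for xt, a in zip(x_t, target):
        z = (xt - math.cos(t) * a) / (math.sin(t) * SIGMA_D)
        out.append(math.cos(t) * z - math.sin(t) * a / SIGMA_D)
    return out


def consistency_fn(x_t, t, model):
    """Consistency function: f(x_t, t) = cos(t) * x_t - sin(t) * sigma_d * F(x_t, t)."""
    v = model(x_t, t)
    return [math.cos(t) * xt - math.sin(t) * SIGMA_D * f for xt, f in zip(x_t, v)]


def sample(model, dim, times, rng):
    """Multistep consistency sampling: denoise, re-noise at the next time, repeat."""
    x = [SIGMA_D * rng.gauss(0.0, 1.0) for _ in range(dim)]  # pure noise at times[0]
    x0 = consistency_fn(x, times[0], model)                   # first jump to a clean estimate
    for t in times[1:]:
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        x = noisy(x0, z, t)               # re-noise the current estimate to time t
        x0 = consistency_fn(x, t, model)  # jump back to a clean estimate
    return x0


target = [0.3, -0.7, 1.1]
model = lambda x_t, t: toy_velocity_model(x_t, t, target)
two_step = sample(model, dim=3, times=[T_MAX, T_MID], rng=random.Random(0))
print(two_step)
```

With an exact model, both the one-step (`times=[T_MAX]`) and two-step outputs recover `target`; with a learned model, the second step gives it a chance to correct errors from the first.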

2. How to use SanaSprintPipeline in this repo

import torch
from app.sana_sprint_pipeline import SanaSprintPipeline
from torchvision.utils import save_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device).manual_seed(42)

sana = SanaSprintPipeline("configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml")
sana.from_pretrained("hf://Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth")

prompt = "a tiny astronaut hatching from an egg on the moon"  # a plain string; a trailing comma would make it a tuple

image = sana(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=2,
    generator=generator,
)
save_image(image, 'sana_sprint.png', nrow=1, normalize=True, value_range=(-1, 1))
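The `value_range=(-1, 1)` argument indicates that the pipeline returns pixel values in [-1, 1], so `save_image` is asked to clamp and rescale them before writing an 8-bit PNG. The pure-Python sketch below shows that arithmetic on a flat list of values; `to_uint8` is a hypothetical helper for illustration, not torchvision's implementation.

```python
def to_uint8(values, low=-1.0, high=1.0):
    """Map model outputs in [low, high] to 8-bit pixel values.

    Illustrates what `normalize=True, value_range=(low, high)` asks for:
    clamp to the range, rescale to [0, 1], then scale to [0, 255] with rounding.
    """
    scale = max(high - low, 1e-5)  # guard against a zero-width range
    pixels = []
    for v in values:
        v = min(max(v, low), high)                    # clamp to the declared value range
        unit = (v - low) / scale                      # rescale to [0, 1]
        p = int(min(max(unit * 255 + 0.5, 0), 255))   # round half up, keep in [0, 255]
        pixels.append(p)
    return pixels


print(to_uint8([-1.0, 0.0, 1.0]))
```

Values outside [-1, 1] are clamped rather than rescaled, so a few out-of-range outputs do not wash out the whole image.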

How to Train

bash train_scripts/train_scm_ladd.sh \
      configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml \
      --data.data_dir="[data/toy_data]" \
      --data.type=SanaWebDatasetMS \
      --model.multi_scale=true \
      --data.load_vae_feat=true \
      --train.train_batch_size=2
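The `--section.key=value` flags after the YAML path override individual fields of that config, for example `train.train_batch_size`. The sketch below shows the idea on a nested dict. `apply_overrides` and `parse_value` are hypothetical: the repo uses its own config loader, whose parsing rules (for instance, how list values like `"[data/toy_data]"` are read) may differ.

```python
def parse_value(raw):
    """Interpret a CLI string: [a,b] lists, booleans, integers, else keep the string."""
    if raw.startswith("[") and raw.endswith("]"):
        return [item.strip() for item in raw[1:-1].split(",") if item.strip()]
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    try:
        return int(raw)
    except ValueError:
        return raw


def apply_overrides(config, args):
    """Apply `--section.key=value` flags to a nested dict config (in place)."""
    for arg in args:
        if not arg.startswith("--") or "=" not in arg:
            raise ValueError(f"expected --section.key=value, got {arg!r}")
        dotted, raw = arg[2:].split("=", 1)
        *parents, leaf = dotted.split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return config


# The shell strips the quotes, so the script receives --data.data_dir=[data/toy_data].
cfg = {"data": {"data_dir": [], "type": "placeholder"}, "train": {"train_batch_size": 32}}
apply_overrides(cfg, [
    "--data.data_dir=[data/toy_data]",
    "--data.type=SanaWebDatasetMS",
    "--model.multi_scale=true",
    "--data.load_vae_feat=true",
    "--train.train_batch_size=2",
])
print(cfg)
```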

Convert the .pth checkpoint to diffusers safetensors

python scripts/convert_sana_to_diffusers.py \
      --orig_ckpt_path Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth \
      --model_type SanaSprint_1600M_P1_D20 \
      --scheduler_type scm \
      --dtype bf16 \
      --dump_path output/Sana_Sprint_1.6B_1024px_diffusers \
      --save_full_pipeline
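With `--save_full_pipeline`, the dump path is presumably a complete diffusers pipeline folder, whose root `model_index.json` names the pipeline class (`_class_name`) and lists each component sub-folder. The sketch below reads that file to confirm what was written before loading it; `load_model_index` and `summarize_model_index` are hypothetical helpers, and the exact components present depend on the converter.

```python
import json
from pathlib import Path


def load_model_index(pipeline_dir):
    """Read model_index.json from a saved diffusers pipeline folder."""
    return json.loads((Path(pipeline_dir) / "model_index.json").read_text())


def summarize_model_index(index):
    """Split a model_index.json dict into the pipeline class and its components.

    Keys starting with '_' (e.g. _class_name, _diffusers_version) are metadata;
    other keys map a sub-folder name to a [library, class_name] pair.
    """
    components = {
        name: spec[1]
        for name, spec in index.items()
        if not name.startswith("_") and isinstance(spec, list) and len(spec) == 2
    }
    return index.get("_class_name"), components


# Example with an in-memory index; for a real folder use:
# summarize_model_index(load_model_index("output/Sana_Sprint_1.6B_1024px_diffusers"))
example = {
    "_class_name": "SanaSprintPipeline",
    "_diffusers_version": "0.33.0",
    "transformer": ["diffusers", "SanaTransformer2DModel"],
    "vae": ["diffusers", "AutoencoderDC"],
}
cls_name, comps = summarize_model_index(example)
print(cls_name, comps)
```

If the summary looks complete, the local folder can be passed to `SanaSprintPipeline.from_pretrained` in place of the Hub ID used in the diffusers example.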

Performance

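| Methods (1024x1024) | Inference Steps | Throughput (samples/s) | Latency (s) | Params (B) | FID 👇 | CLIP 👆 | GenEval 👆 |
|---|---|---|---|---|---|---|---|
| Sana-Sprint-0.6B | 2 | 6.46 | 0.25 | 0.6 | 6.54 | 28.40 | 0.76 |
| Sana-Sprint-1.6B | 2 | 5.68 | 0.24 | 1.6 | 6.50 | 28.45 | 0.77 |

👇: lower is better; 👆: higher is better.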