
Enhancement Request: Gradio Interface #56

@MrEdwards007

Description

I wanted an interface that would make my life a little simpler and thought this would be a value add for others.
The enhancement lets the user enter text for text-to-video and image-to-video, in addition to all of the arguments used in TurboDiffusion. It shows progress, and the console output now streams into Gradio.

It performs a few checks to ensure valid combinations.
The program's output goes to a directory named "output" in the root of the program's directory.
It contains the video plus the complete settings required to reproduce the same output, including the seed; an abbreviated example is shown below.
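For illustration, a metadata file looks roughly like this (the timestamp, prompt, and seed here are just example values; the exact contents are whatever save_debug_metadata in the script writes):

# Generated: 2025-01-01 12:00:00
# Copy and paste the lines below to reproduce this video exactly:

export PYTHONPATH=turbodiffusion
export PYTORCH_ALLOC_CONF=expandable_segments:True
export TOKENIZERS_PARALLELISM=false

python turbodiffusion/inference/wan2.1_t2v_infer.py \
    --model "Wan2.1-1.3B" \
    --prompt "A stylish woman walks down a Tokyo street..." \
    --resolution "480p" \
    --seed "42" \
    --save_path "output/t2v_20250101_120000.mp4"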

I have not been able to test this at 720p, as I run out of memory on my 5090, but based on the behavior at 480p I believe it will work.

import os
import sys
import subprocess
import gradio as gr
import glob
import random
import time
import select
from datetime import datetime

# --- 1. System Setup ---
PROJECT_ROOT = "/path/to/your/Programs/TurboDiffusion"
os.chdir(PROJECT_ROOT)
os.system('clear' if os.name == 'posix' else 'cls')

CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

T2V_SCRIPT = "turbodiffusion/inference/wan2.1_t2v_infer.py"
I2V_SCRIPT = "turbodiffusion/inference/wan2.2_i2v_infer.py"

def get_gpu_status():
    """System-level GPU check."""
    try:
        res = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv,nounits,noheader"],
            encoding='utf-8'
        ).strip().split(',')
        name, used, total = (x.strip() for x in res)
        return f"🖥️ {name} | ⚡ VRAM: {used}MB / {total}MB"
    except (subprocess.SubprocessError, FileNotFoundError, ValueError):
        # nvidia-smi missing or its output unparseable; degrade gracefully
        return "🖥️ GPU Monitor Active"

def save_debug_metadata(video_path, script_rel, cmd_list):
    """Saves a fully executable reproduction script with env vars."""
    meta_path = video_path.replace(".mp4", "_metadata.txt")
    with open(meta_path, "w") as f:
        f.write(f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("# Copy and paste the lines below to reproduce this video exactly:\n\n")
        f.write("export PYTHONPATH=turbodiffusion\n")
        f.write("export PYTORCH_ALLOC_CONF=expandable_segments:True\n")
        f.write("export TOKENIZERS_PARALLELISM=false\n\n")
        f.write(f"python {script_rel} \\\n")
        # Skip the interpreter and script path; pair each --flag with the value that follows it
        args_only = cmd_list[2:]
        for i, arg in enumerate(args_only):
            if arg.startswith("--"):
                val = f'"{args_only[i+1]}"' if i+1 < len(args_only) and not args_only[i+1].startswith("--") else ""
                f.write(f"    {arg} {val} \\\n")

def sync_path(scale):
    fname = "TurboWan2.1-T2V-1.3B-480P-quant.pth" if "1.3B" in scale else "TurboWan2.1-T2V-14B-720P-quant.pth"
    return os.path.join(CHECKPOINT_DIR, fname)

# --- 2. Unified Generation Logic (With Safety Checks) ---

def run_gen(mode, prompt, model, dit_path, i2v_high, i2v_low, image, res, ratio, steps, seed, quant, attn, top_k, frames, sigma, norm, adapt, ode, pr=gr.Progress()):
    # --- PRE-FLIGHT SAFETY CHECK ---
    # Prevents users from crashing the script by mismatched configs
    error_msg = ""
    
    # 1. Quantization Check
    if mode == "T2V":
        if "quant" in dit_path.lower() and not quant:
            error_msg = "❌ CONFIG ERROR: You selected a quantized model (*quant.pth) but disabled the '8-bit' checkbox.\n\n👉 FIX: Re-enable 'Enable --quant_linear (8-bit)'."
        
        # [NEW] Attention Compatibility Check
        if attn == "original" and ("turbo" in dit_path.lower() or "quant" in dit_path.lower()):
            error_msg = "❌ COMPATIBILITY ERROR: You selected 'original' attention but are using a Turbo/Quantized checkpoint.\n\n👉 FIX: Switch Attention to 'sagesla' OR use a standard Wan2.1 checkpoint."

    else:
        # Check both high and low noise models for I2V
        if ("quant" in i2v_high.lower() or "quant" in i2v_low.lower()) and not quant:
            error_msg = "❌ CONFIG ERROR: One of your I2V models is quantized (*quant.pth) but you disabled the '8-bit' checkbox.\n\n👉 FIX: Re-enable 'Enable --quant_linear (8-bit)'."
            
        # [NEW] Attention Compatibility Check for I2V
        if attn == "original" and (("turbo" in i2v_high.lower() or "quant" in i2v_high.lower()) or ("turbo" in i2v_low.lower() or "quant" in i2v_low.lower())):
            error_msg = "❌ COMPATIBILITY ERROR: You selected 'original' attention but are using Turbo/Quantized checkpoints.\n\n👉 FIX: Switch Attention to 'sagesla' OR use standard Wan2.1 checkpoints."

    if error_msg:
        # Abort immediately, warn the user, and explain why.
        yield None, None, "❌ Config Error", "🛑 Aborted", error_msg
        return
    # -------------------------------

    actual_seed = random.randint(1, 1000000) if seed <= 0 else int(seed)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    start_time = time.time()
    
    pr(0, desc="🚀 Starting...")
    
    if mode == "T2V":
        save_path = os.path.join(OUTPUT_DIR, f"t2v_{timestamp}.mp4")
        script_rel = T2V_SCRIPT
        cmd = [sys.executable, os.path.join(PROJECT_ROOT, T2V_SCRIPT), "--model", model, "--dit_path", dit_path, "--prompt", prompt, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_samples", "1", "--num_frames", str(frames), "--sigma_max", str(sigma)]
    else:
        save_path = os.path.join(OUTPUT_DIR, f"i2v_{timestamp}.mp4")
        script_rel = I2V_SCRIPT
        cmd = [sys.executable, os.path.join(PROJECT_ROOT, I2V_SCRIPT), "--prompt", prompt, "--image_path", image, "--high_noise_model_path", i2v_high, "--low_noise_model_path", i2v_low, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k)]
        if adapt: cmd.append("--adaptive_resolution")
        if ode: cmd.append("--ode")

    if quant: cmd.append("--quant_linear")
    if norm: cmd.append("--default_norm")
    cmd.extend(["--save_path", save_path])

    save_debug_metadata(save_path, script_rel, cmd)
    
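    # Child process environment: PYTHONUNBUFFERED=1 makes the inference script
    # flush prints line-by-line so the live console updates in real time.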
    env = os.environ.copy()
    env["PYTHONPATH"] = os.path.join(PROJECT_ROOT, "turbodiffusion")
    env["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
    env["TOKENIZERS_PARALLELISM"] = "false"
    env["PYTHONUNBUFFERED"] = "1"

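    # stderr is merged into stdout so tqdm progress bars and tracebacks are
    # captured by the same reader loop below.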
    process = subprocess.Popen(cmd, cwd=PROJECT_ROOT, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    
    partial_log = ""
    last_ui_update = 0
    
    while True:
        if process.poll() is not None:
            rest = process.stdout.read()
            if rest: partial_log += rest
            break

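        # NOTE: select() on pipe file descriptors is POSIX-only (Linux/macOS);
        # a Windows port would need a background reader thread instead.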
        reads = [process.stdout.fileno()]
        ret = select.select(reads, [], [], 0.1)

        if ret[0]:
            line = process.stdout.readline()
            partial_log += line
            
            if "Loading DiT" in line: pr(0.1, desc="⚡ Loading weights...")
            if "Sampling:" in line:
                try:
                    # Grab the integer immediately before '%' in the tqdm line
                    pct = int(line.split('%')[0].split('|')[-1].split()[-1])
                    pr(0.2 + (pct/100 * 0.7), desc=f"🎬 Sampling: {pct}%")
                except (ValueError, IndexError): pass
            if "decoding" in line.lower(): pr(0.95, desc="🎥 Decoding VAE...")

        current_time = time.time()
        if current_time - last_ui_update > 0.25:
            last_ui_update = current_time
            elapsed = f"{int(current_time - start_time)}s"
            yield None, None, f"Seed: {actual_seed}", f"⏱️ Time: {elapsed}", partial_log

    history = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True)
    total_time = f"{int(time.time() - start_time)}s"
    
    yield save_path, history, f"✅ Done | Seed: {actual_seed}", f"🏁 Finished in {total_time}", partial_log

# --- 3. UI Layout ---
with gr.Blocks(title="TurboDiffusion Studio") as demo:
    with gr.Row():
        gr.HTML("<h2 style='margin: 10px 0;'>⚡ TurboDiffusion Studio</h2>")
        with gr.Column(scale=1):
            gpu_display = gr.Markdown(get_gpu_status())
    
    gr.Timer(2).tick(get_gpu_status, outputs=gpu_display)

    with gr.Tabs():
        with gr.Tab("Text-to-Video"):
            with gr.Row():
                with gr.Column(scale=4):
                    t2v_p = gr.Textbox(label="Prompt", lines=3, value="A stylish woman walks down a Tokyo street...")
                    with gr.Row():
                        t2v_m = gr.Radio(["Wan2.1-1.3B", "Wan2.1-14B"], label="Model", value="Wan2.1-1.3B")
                        t2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="480p")
                        t2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9")
                    t2v_dit = gr.Textbox(label="DiT Path", value=sync_path("Wan2.1-1.3B"), interactive=False)
                    t2v_btn = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    t2v_out = gr.Video(label="Result", height=320)
                    with gr.Row():
                        t2v_stat = gr.Textbox(label="Status", interactive=False, scale=2)
                        t2v_time = gr.Textbox(label="Timer", value="⏱️ Ready", interactive=False, scale=1)

        with gr.Tab("Image-to-Video"):
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Row():
                        i2v_img = gr.Image(label="Source", type="filepath", height=200)
                        i2v_p = gr.Textbox(label="Motion Prompt", lines=7)
                    with gr.Row():
                        i2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="720p")
                        i2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9")
                    with gr.Row():
                        i2v_adapt = gr.Checkbox(label="Adaptive Resolution", value=True)
                        i2v_ode = gr.Checkbox(label="Use ODE", value=False)
                    with gr.Accordion("I2V Path Overrides", open=False):
                        i2v_high = gr.Textbox(label="High-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-high-720P-quant.pth"))
                        i2v_low = gr.Textbox(label="Low-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-low-720P-quant.pth"))
                    i2v_btn = gr.Button("Animate Image", variant="primary")
                with gr.Column(scale=3):
                    i2v_out = gr.Video(label="Result", height=320)
                    with gr.Row():
                        i2v_stat_2 = gr.Textbox(label="Status", interactive=False, scale=2)
                        i2v_time_2 = gr.Textbox(label="Timer", value="⏱️ Ready", interactive=False, scale=1)

    console_out = gr.Textbox(label="Live CLI Console Output", lines=8, max_lines=8, interactive=False)

    with gr.Accordion("⚙️ Precision & Advanced Settings", open=False):
        with gr.Row():
            quant_opt = gr.Checkbox(label="Enable --quant_linear (8-bit)", value=True)
            steps_opt = gr.Slider(1, 4, value=4, step=1, label="Steps")
            seed_opt = gr.Number(label="Seed (0=Random)", value=0, precision=0)
        with gr.Row():
            top_k_opt = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="SLA Top-K")
            attn_opt = gr.Radio(["sagesla", "sla", "original"], label="Attention", value="sagesla")
            sigma_opt = gr.Number(label="Sigma Max", value=80)
            norm_opt = gr.Checkbox(label="Original Norms", value=False)
            frames_opt = gr.Slider(1, 120, value=77, step=1, label="Frames")

    history_gal = gr.Gallery(value=sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True), columns=6, height="auto")

    # --- 4. Logic Bindings ---
    t2v_m.change(fn=sync_path, inputs=t2v_m, outputs=t2v_dit)
    
    t2v_args = [gr.State("T2V"), t2v_p, t2v_m, t2v_dit, gr.State(""), gr.State(""), gr.State(""), t2v_res, t2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, sigma_opt, norm_opt, gr.State(False), gr.State(False)]
    t2v_btn.click(run_gen, t2v_args, [t2v_out, history_gal, t2v_stat, t2v_time, console_out], show_progress="hidden")

    # The first element is the mode flag ("I2V"); the source image itself is passed as the 'image' argument further down the list.
    i2v_args = [gr.State("I2V"), i2v_p, gr.State("Wan2.2-A14B"), gr.State(""), i2v_high, i2v_low, i2v_img, i2v_res, i2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, gr.State(200), norm_opt, i2v_adapt, i2v_ode]
    i2v_btn.click(run_gen, i2v_args, [i2v_out, history_gal, i2v_stat_2, i2v_time_2, console_out], show_progress="hidden")

if __name__ == "__main__":
    demo.launch(theme=gr.themes.Default(), allowed_paths=[OUTPUT_DIR])
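To try it: save the script in your TurboDiffusion checkout (the filename below is just an example), set PROJECT_ROOT to your local clone, and run it. Gradio will print a local URL (http://127.0.0.1:7860 by default):

python turbodiffusion_studio.py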