I wanted an interface that would make my life a little simpler and thought it would be a value-add for others.
This enhancement adds a Gradio UI where the user can enter a prompt for text-to-video or image-to-video, with access to all of the arguments used by TurboDiffusion. It shows progress, and the console output is now streamed into Gradio.
It performs a few checks to ensure valid option combinations.
The output goes into a directory called "output" in the root of the program's directory.
Each run saves the video plus the complete settings required to reproduce the same output, including the seed.
I have not been able to test this at 720p, as I run out of memory on my 5090, but based on the behavior at 480p I believe it will work.
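For reference, the settings file saved next to each video is itself a runnable shell snippet. A hypothetical example of its contents (values illustrative, not from a real run):

    # Generated: 2025-01-01 12:00:00
    # Copy and paste the lines below to reproduce this video exactly:
    export PYTHONPATH=turbodiffusion
    export PYTORCH_ALLOC_CONF=expandable_segments:True
    export TOKENIZERS_PARALLELISM=false
    python turbodiffusion/inference/wan2.1_t2v_infer.py \
      --model "Wan2.1-1.3B" \
      --prompt "..." \
      --resolution "480p" \
      --seed "12345" \
      ...

The full script: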
import os
import sys
import subprocess
import gradio as gr
import glob
import random
import time
import select
from datetime import datetime
# --- 1. System Setup ---
PROJECT_ROOT = "/path/to/your/Programs/TurboDiffusion"
os.chdir(PROJECT_ROOT)
os.system('clear' if os.name == 'posix' else 'cls')
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)
T2V_SCRIPT = "turbodiffusion/inference/wan2.1_t2v_infer.py"
I2V_SCRIPT = "turbodiffusion/inference/wan2.2_i2v_infer.py"
def get_gpu_status():
    """System-level GPU check via nvidia-smi."""
    try:
        res = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv,nounits,noheader"],
            encoding='utf-8'
        ).strip().split(',')
        res = [s.strip() for s in res]  # nvidia-smi pads CSV fields with spaces
        return f"🖥️ {res[0]} | ⚡ VRAM: {res[1]}MB / {res[2]}MB"
    except Exception:
        return "🖥️ GPU Monitor Active"
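# Example return value (illustrative): "🖥️ NVIDIA GeForce RTX 5090 | ⚡ VRAM: 1234MB / 32607MB"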
def save_debug_metadata(video_path, script_rel, cmd_list):
    """Saves a fully executable reproduction script with env vars."""
    meta_path = video_path.replace(".mp4", "_metadata.txt")
    with open(meta_path, "w") as f:
        f.write(f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("# Copy and paste the lines below to reproduce this video exactly:\n\n")
        f.write("export PYTHONPATH=turbodiffusion\n")
        f.write("export PYTORCH_ALLOC_CONF=expandable_segments:True\n")
        f.write("export TOKENIZERS_PARALLELISM=false\n\n")
        f.write(f"python {script_rel} \\\n")
        args_only = cmd_list[2:]  # skip the interpreter and script path
        for i, arg in enumerate(args_only):
            if arg.startswith("--"):
                val = f'"{args_only[i+1]}"' if i + 1 < len(args_only) and not args_only[i+1].startswith("--") else ""
                f.write(f"  {arg} {val} \\\n")
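# Illustrative call (hypothetical paths): save_debug_metadata("output/t2v_20250101_120000.mp4", T2V_SCRIPT, cmd)
# would write "output/t2v_20250101_120000_metadata.txt" alongside the video.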
def sync_path(scale):
    fname = "TurboWan2.1-T2V-1.3B-480P-quant.pth" if "1.3B" in scale else "TurboWan2.1-T2V-14B-720P-quant.pth"
    return os.path.join(CHECKPOINT_DIR, fname)
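# e.g. sync_path("Wan2.1-1.3B") -> os.path.join(CHECKPOINT_DIR, "TurboWan2.1-T2V-1.3B-480P-quant.pth")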
# --- 2. Unified Generation Logic (With Safety Checks) ---
def run_gen(mode, prompt, model, dit_path, i2v_high, i2v_low, image, res, ratio, steps, seed, quant, attn, top_k, frames, sigma, norm, adapt, ode, pr=gr.Progress()):
    # --- PRE-FLIGHT SAFETY CHECK ---
    # Prevents users from crashing the script with mismatched configs
    error_msg = ""
    if mode == "T2V":
        # 1. Quantization check
        if "quant" in dit_path.lower() and not quant:
            error_msg = "❌ CONFIG ERROR: You selected a quantized model (*quant.pth) but disabled the '8-bit' checkbox.\n\n👉 FIX: Re-enable 'Enable --quant_linear (8-bit)'."
        # 2. Attention compatibility check
        if attn == "original" and ("turbo" in dit_path.lower() or "quant" in dit_path.lower()):
            error_msg = "❌ COMPATIBILITY ERROR: You selected 'original' attention but are using a Turbo/Quantized checkpoint.\n\n👉 FIX: Switch Attention to 'sagesla' OR use a standard Wan2.1 checkpoint."
    else:
        # Check both the high- and low-noise models for I2V
        if ("quant" in i2v_high.lower() or "quant" in i2v_low.lower()) and not quant:
            error_msg = "❌ CONFIG ERROR: One of your I2V models is quantized (*quant.pth) but you disabled the '8-bit' checkbox.\n\n👉 FIX: Re-enable 'Enable --quant_linear (8-bit)'."
        # Attention compatibility check for I2V
        if attn == "original" and ("turbo" in i2v_high.lower() or "quant" in i2v_high.lower() or "turbo" in i2v_low.lower() or "quant" in i2v_low.lower()):
            error_msg = "❌ COMPATIBILITY ERROR: You selected 'original' attention but are using Turbo/Quantized checkpoints.\n\n👉 FIX: Switch Attention to 'sagesla' OR use standard Wan2.1 checkpoints."
    if error_msg:
        # Abort immediately, warn the user, and explain why.
        yield None, None, "❌ Config Error", "🛑 Aborted", error_msg
        return
    # -------------------------------
    actual_seed = random.randint(1, 1000000) if seed <= 0 else int(seed)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    start_time = time.time()
    pr(0, desc="🚀 Starting...")
    if mode == "T2V":
        save_path = os.path.join(OUTPUT_DIR, f"t2v_{timestamp}.mp4")
        script_rel = T2V_SCRIPT
        cmd = [sys.executable, os.path.join(PROJECT_ROOT, T2V_SCRIPT), "--model", model, "--dit_path", dit_path, "--prompt", prompt, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_samples", "1", "--num_frames", str(frames), "--sigma_max", str(sigma)]
    else:
        save_path = os.path.join(OUTPUT_DIR, f"i2v_{timestamp}.mp4")
        script_rel = I2V_SCRIPT
        cmd = [sys.executable, os.path.join(PROJECT_ROOT, I2V_SCRIPT), "--prompt", prompt, "--image_path", image, "--high_noise_model_path", i2v_high, "--low_noise_model_path", i2v_low, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k)]
    if adapt: cmd.append("--adaptive_resolution")
    if ode: cmd.append("--ode")
    if quant: cmd.append("--quant_linear")
    if norm: cmd.append("--default_norm")
    cmd.extend(["--save_path", save_path])
    save_debug_metadata(save_path, script_rel, cmd)
    env = os.environ.copy()
    env["PYTHONPATH"] = os.path.join(PROJECT_ROOT, "turbodiffusion")
    env["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
    env["TOKENIZERS_PARALLELISM"] = "false"
    env["PYTHONUNBUFFERED"] = "1"  # unbuffered child output so the live console updates line by line
    process = subprocess.Popen(cmd, cwd=PROJECT_ROOT, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    partial_log = ""
    last_ui_update = 0
    while True:
        if process.poll() is not None:
            rest = process.stdout.read()
            if rest: partial_log += rest
            break
        # Non-blocking read so the UI keeps refreshing while the model runs
        reads = [process.stdout.fileno()]
        ret = select.select(reads, [], [], 0.1)
        if ret[0]:
            line = process.stdout.readline()
            partial_log += line
            if "Loading DiT" in line: pr(0.1, desc="⚡ Loading weights...")
            if "Sampling:" in line:
                try:
                    pct = int(line.split('%')[0].split('|')[-1].strip())
                    pr(0.2 + (pct / 100 * 0.7), desc=f"🎬 Sampling: {pct}%")
                except (ValueError, IndexError):
                    pass
            if "decoding" in line.lower(): pr(0.95, desc="🎥 Decoding VAE...")
        current_time = time.time()
        if current_time - last_ui_update > 0.25:
            last_ui_update = current_time
            elapsed = f"{int(current_time - start_time)}s"
            yield None, None, f"Seed: {actual_seed}", f"⏱️ Time: {elapsed}", partial_log
    history = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True)
    total_time = f"{int(time.time() - start_time)}s"
    yield save_path, history, f"✅ Done | Seed: {actual_seed}", f"🏁 Finished in {total_time}", partial_log
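# Note: run_gen is a generator. Each yield pushes a (video, gallery, status, timer,
# console) tuple into the five bound output components, which is what drives the
# live timer and console updates while the subprocess runs.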
# --- 3. UI Layout ---
with gr.Blocks(title="TurboDiffusion Studio", theme=gr.themes.Default()) as demo:
    with gr.Row():
        gr.HTML("<h2 style='margin: 10px 0;'>⚡ TurboDiffusion Studio</h2>")
        with gr.Column(scale=1):
            gpu_display = gr.Markdown(get_gpu_status())
            gr.Timer(2).tick(get_gpu_status, outputs=gpu_display)
    with gr.Tabs():
        with gr.Tab("Text-to-Video"):
            with gr.Row():
                with gr.Column(scale=4):
                    t2v_p = gr.Textbox(label="Prompt", lines=3, value="A stylish woman walks down a Tokyo street...")
                    with gr.Row():
                        t2v_m = gr.Radio(["Wan2.1-1.3B", "Wan2.1-14B"], label="Model", value="Wan2.1-1.3B")
                        t2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="480p")
                        t2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9")
                    t2v_dit = gr.Textbox(label="DiT Path", value=sync_path("Wan2.1-1.3B"), interactive=False)
                    t2v_btn = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    t2v_out = gr.Video(label="Result", height=320)
                    with gr.Row():
                        t2v_stat = gr.Textbox(label="Status", interactive=False, scale=2)
                        t2v_time = gr.Textbox(label="Timer", value="⏱️ Ready", interactive=False, scale=1)
with gr.Tab("Image-to-Video"):
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
i2v_img = gr.Image(label="Source", type="filepath", height=200)
i2v_p = gr.Textbox(label="Motion Prompt", lines=7)
with gr.Row():
i2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="720p")
i2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9")
with gr.Row():
i2v_adapt = gr.Checkbox(label="Adaptive Resolution", value=True)
i2v_ode = gr.Checkbox(label="Use ODE", value=False)
with gr.Accordion("I2V Path Overrides", open=False):
i2v_high = gr.Textbox(label="High-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-high-720P-quant.pth"))
i2v_low = gr.Textbox(label="Low-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-low-720P-quant.pth"))
i2v_btn = gr.Button("Animate Image", variant="primary")
with gr.Column(scale=3):
i2v_out = gr.Video(label="Result", height=320)
with gr.Row():
i2v_stat_2 = gr.Textbox(label="Status", interactive=False, scale=2)
i2v_time_2 = gr.Textbox(label="Timer", value="⏱️ Ready", interactive=False, scale=1)
    console_out = gr.Textbox(label="Live CLI Console Output", lines=8, max_lines=8, interactive=False)
with gr.Accordion("⚙️ Precision & Advanced Settings", open=False):
with gr.Row():
quant_opt = gr.Checkbox(label="Enable --quant_linear (8-bit)", value=True)
steps_opt = gr.Slider(1, 4, value=4, step=1, label="Steps")
seed_opt = gr.Number(label="Seed (0=Random)", value=0, precision=0)
with gr.Row():
top_k_opt = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="SLA Top-K")
attn_opt = gr.Radio(["sagesla", "sla", "original"], label="Attention", value="sagesla")
sigma_opt = gr.Number(label="Sigma Max", value=80)
norm_opt = gr.Checkbox(label="Original Norms", value=False)
frames_opt = gr.Slider(1, 120, value=77, step=1, label="Frames")
history_gal = gr.Gallery(value=sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), reverse=True), columns=6, height="auto")
    # --- 4. Logic Bindings ---
    # The args lists map positionally onto run_gen's signature; gr.State supplies fixed values.
    t2v_m.change(fn=sync_path, inputs=t2v_m, outputs=t2v_dit)
    t2v_args = [gr.State("T2V"), t2v_p, t2v_m, t2v_dit, gr.State(""), gr.State(""), gr.State(""), t2v_res, t2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, sigma_opt, norm_opt, gr.State(False), gr.State(False)]
    t2v_btn.click(run_gen, t2v_args, [t2v_out, history_gal, t2v_stat, t2v_time, console_out], show_progress="hidden")
    i2v_args = [gr.State("I2V"), i2v_p, gr.State("Wan2.2-A14B"), gr.State(""), i2v_high, i2v_low, i2v_img, i2v_res, i2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, gr.State(200), norm_opt, i2v_adapt, i2v_ode]
    i2v_btn.click(run_gen, i2v_args, [i2v_out, history_gal, i2v_stat_2, i2v_time_2, console_out], show_progress="hidden")
if __name__ == "__main__":
    demo.launch(allowed_paths=[OUTPUT_DIR])
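To try it, save the script in the root of the TurboDiffusion checkout (the filename is your choice, e.g. turbodiffusion_studio.py), edit PROJECT_ROOT to match your install, and launch it from the same Python environment you use for inference:

    python turbodiffusion_studio.py

Note that the console streaming uses select() on the subprocess pipe, so it assumes a POSIX system.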
