Open
Description
A Gradio example is presented below:
app.py
import gradio as gr
from gradio_litmodel3d import LitModel3D
import torch
import trimesh
from cube3d.inference.engine import Engine, EngineFast
import os
import shutil
from typing import *
import numpy as np
# Largest value representable by a signed 32-bit integer; used as the seed upper bound.
MAX_SEED = np.iinfo(np.int32).max

# Scratch space for generated files: a 'tmp' folder next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
TMP_DIR = os.path.join(_APP_DIR, 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Model configuration and checkpoint files handed to the inference engine.
config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"

# EngineFast is only supported on CUDA devices; replace with Engine otherwise.
_cuda_device = torch.device("cuda")
engine_fast = EngineFast(
    config_path,
    gpt_ckpt_path,
    shape_ckpt_path,
    device=_cuda_device,
)
def start_session(req: gr.Request):
    """
    Create the scratch directory for a session.

    The directory is ``TMP_DIR/<session_hash>``, where the hash is taken from
    ``req``. An already existing directory is left untouched.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """
    Delete the scratch directory of a session.

    The directory is ``TMP_DIR/<session_hash>``, where the hash is taken from
    ``req``. A directory that does not exist is silently ignored, so cleanup
    never fails just because nothing was created (or it was already removed).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up: the directory was never created or is already gone.
        pass
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Choose the seed to use.

    Returns ``seed`` unchanged unless ``randomize_seed`` is truthy, in which
    case a value is drawn with ``np.random.randint(0, MAX_SEED)``.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
def text_to_3d(
    prompt: str,
    resolution_base: float,
    top_k: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Generate a mesh from a text prompt and export it as a GLB file.

    Args:
        prompt (str): The text prompt handed to the engine.
        resolution_base (float): Forwarded to ``engine_fast.t2s`` as ``resolution_base``.
        top_k (int): Forwarded to ``engine_fast.t2s`` as ``top_k``.
        req (gr.Request): The current request; its ``session_hash`` selects the
            per-session output directory.

    Returns:
        Tuple[str, str]: The path to the exported GLB file, returned twice
        (one value per output component).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Do not rely on the session directory having been created beforehand;
    # exporting into a missing directory would fail.
    os.makedirs(user_dir, exist_ok=True)

    meshes = engine_fast.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_k=top_k,
    )
    # One prompt was submitted, so only the first result is used.
    vertices, faces = meshes[0][0], meshes[0][1]

    glb_path = os.path.join(user_dir, 'sample.glb')
    trimesh.Trimesh(vertices=vertices, faces=faces).export(glb_path)

    torch.cuda.empty_cache()
    return glb_path, glb_path
# NOTE(review): the nesting below was reconstructed from a paste that had lost
# its indentation -- confirm the layout against the running app.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    with gr.Row():
        # Input column: prompt, generation settings and the trigger button.
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Prompts", id=0) as single_image_input_tab:
                    text_prompt = gr.Text(label="Text Prompt")

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(
                    0, MAX_SEED,
                    label="Seed", value=0, step=1, interactive=True,
                )
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Sparse Structure Generation")
                with gr.Row():
                    resolution_base = gr.Slider(
                        0.0, 10.0,
                        label="Resolution Base", value=8.0, step=0.1, interactive=True,
                    )
                    top_k = gr.Slider(
                        1, 10,
                        label="Top K", value=5, step=1, interactive=True,
                    )

            extract_glb_btn = gr.Button("Extract GLB", interactive=True)

        # Output column: 3D preview plus the download button.
        with gr.Column():
            model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    output_buf = gr.State()

    # Session lifecycle: create the scratch directory on load, remove it on unload.
    demo.load(start_session)
    demo.unload(end_session)

    # Generate the model on click, then make the download button clickable.
    extract_glb_btn.click(
        text_to_3d,
        inputs=[text_prompt, resolution_base, top_k],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    # Clearing the preview disables the download button again.
    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=80,
    )