Skip to content

WebUI requirement #65

Open
Open
@darcyOly999

Description

@darcyOly999

A Gradio example is presented below:

app.py

import gradio as gr
from gradio_litmodel3d import LitModel3D
import torch
import trimesh
from cube3d.inference.engine import Engine, EngineFast
import os
import shutil
from typing import *


import numpy as np


# Upper bound for seeds offered in the UI: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Per-session scratch directories are created under a "tmp" folder that sits
# next to this script; make sure the parent folder exists at import time.
TMP_DIR = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'tmp')
os.makedirs(name=TMP_DIR, exist_ok=True)

config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"

# EngineFast is only supported on CUDA devices, so use it when CUDA is
# available and fall back to the generic Engine on MPS or CPU otherwise.
# The module-level name stays `engine_fast` because text_to_3d refers to it,
# even when it holds a plain Engine.
if torch.cuda.is_available():
    engine_fast = EngineFast(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        device=torch.device("cuda"),
    )
else:
    _fallback_device = torch.device(
        "mps" if torch.backends.mps.is_available() else "cpu"
    )
    engine_fast = Engine(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        device=_fallback_device,
    )


def start_session(req: gr.Request):
    """Create the scratch directory for a newly loaded browser session.

    The directory is ``TMP_DIR/<session_hash>``; an existing one is reused.
    """
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)


def end_session(req: gr.Request):
    """
    Delete the per-session scratch directory when the browser session ends.

    A missing directory is not treated as an error: it may never have been
    created (e.g. the load hook did not run) or may already have been removed.
    Other filesystem errors still propagate.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up.
        pass

def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Return the seed to use for a generation run.

    If ``randomize_seed`` is truthy, a fresh value is drawn from
    ``[0, MAX_SEED)``; otherwise ``seed`` is returned unchanged.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)


def text_to_3d(
    prompt: str,
    resolution_base: float,
    top_k: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Generate a 3D shape from a text prompt and export it as a GLB file.

    Args:
        prompt (str): Text description of the shape to generate.
        resolution_base (float): Forwarded unchanged to ``engine_fast.t2s``.
        top_k (int): Forwarded, as an int, to ``engine_fast.t2s``.
        req (gr.Request): The Gradio request. Its ``session_hash`` selects the
            per-session directory under ``TMP_DIR`` that the GLB is written to.

    Returns:
        Tuple[str, str]: The path to the exported GLB file, returned twice:
        once for the 3D viewer and once for the download button.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # start_session normally creates this directory. Recreate it in case that
    # hook did not run or the directory was removed, so the export can succeed.
    os.makedirs(user_dir, exist_ok=True)
    try:
        # Slider values may arrive as floats; pass k to the sampler as an int.
        mesh_v_f = engine_fast.t2s(
            [prompt],
            use_kv_cache=True,
            resolution_base=resolution_base,
            top_k=int(top_k),
        )
        # Result for the single prompt: element 0 is vertices, element 1 is faces.
        vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
        glb_path = os.path.join(user_dir, 'sample.glb')
        trimesh.Trimesh(vertices=vertices, faces=faces).export(glb_path)
    finally:
        # Release cached GPU memory even if generation or export fails.
        torch.cuda.empty_cache()
    return glb_path, glb_path



with gr.Blocks(delete_cache=(600, 600)) as demo:
    with gr.Row():
        # Left column: prompt input, generation settings and the run button.
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Prompts", id=0) as single_image_input_tab:
                    text_prompt = gr.Text(label="Text Prompt")

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1, interactive=True)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                # NOTE(review): this heading may not describe the settings below — confirm wording.
                gr.Markdown("Sparse Structure Generation")
                with gr.Row():
                    resolution_base = gr.Slider(0.0, 10.0, label="Resolution Base", value=8.0, step=0.1, interactive=True)
                    top_k = gr.Slider(1, 10, label="Top K", value=5, step=1, interactive=True)

            extract_glb_btn = gr.Button("Extract GLB", interactive=True)

        # Right column: 3D preview and download of the generated GLB.
        with gr.Column():
            model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    output_buf = gr.State()

    def _seeded_text_to_3d(
        prompt: str,
        resolution_base_value: float,
        top_k_value: int,
        seed_value: int,
        req: gr.Request,
    ) -> Tuple[str, str]:
        """Seed torch's global RNG with the UI seed, then run text_to_3d.

        NOTE(review): assumes the engine samples from torch's global RNG — confirm.
        """
        torch.manual_seed(int(seed_value))
        return text_to_3d(prompt, resolution_base_value, top_k_value, req)

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # 1) Resolve the seed (randomized or user-chosen) and show it in the slider,
    # 2) generate with that seed, 3) enable downloading only if generation
    #    succeeded (`.success` is skipped when the previous step raised).
    extract_glb_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).success(
        _seeded_text_to_3d,
        inputs=[text_prompt, resolution_base, top_k, seed],
        outputs=[model_output, download_glb],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    # Clearing the viewer disables the download button again.
    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )


# Launch the Gradio app
if __name__ == "__main__":
    # Defaults match the original deployment (all interfaces, port 80). Both
    # can be overridden with Gradio's standard environment variables, e.g. to
    # run on a high port without root privileges.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "80")),
    )

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions