Integrate with huggingface #69

Open · wants to merge 2 commits into base: main
180 changes: 142 additions & 38 deletions cube3d/inference/engine.py
@@ -1,6 +1,10 @@
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast, CLIPTextConfig
from huggingface_hub import PyTorchModelHubMixin, HfApi
from typing import Union, Optional
from omegaconf import OmegaConf, DictConfig

from cube3d.inference.logits_postprocesses import process_logits
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
@@ -9,13 +13,32 @@
from cube3d.model.transformers.cache import Cache


class Engine:
class Engine(
nn.Module,
PyTorchModelHubMixin,
coders={
DictConfig: (
lambda x: OmegaConf.to_container(x),
lambda x: OmegaConf.create(x),
),
CLIPTextConfig: (
lambda x: x.to_diff_dict(),
lambda x: CLIPTextConfig(**x),
),
},
library_name="cube",
repo_url="https://github.com/Roblox/cube",
):
def __init__(
self,
config_path: str,
gpt_ckpt_path: str,
shape_ckpt_path: str,
device: torch.device,
config_path: Optional[str] = None,
gpt_ckpt_path: Optional[str] = None,
shape_ckpt_path: Optional[str] = None,
device: torch.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
cfg: DictConfig = None,
clip_config: CLIPTextConfig = None,
):
"""
Initializes the inference engine with the given configuration and checkpoint paths.
@@ -35,26 +58,31 @@ def __init__(
min_id (int): Minimum ID for the shape model codes.
max_id (int): Maximum ID for the shape model codes.
"""

self.cfg = load_config(config_path)
super().__init__()
if config_path is not None:
self.cfg = load_config(config_path)
else:
self.cfg = cfg
self.device = device

self.gpt_model = DualStreamRoformer(
parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model)
)
load_model_weights(
self.gpt_model,
gpt_ckpt_path,
)
if gpt_ckpt_path is not None:
load_model_weights(
self.gpt_model,
gpt_ckpt_path,
)
self.gpt_model = self.gpt_model.eval().to(self.device)

self.shape_model = OneDAutoEncoder(
parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model)
)
load_model_weights(
self.shape_model,
shape_ckpt_path,
)
if shape_ckpt_path is not None:
load_model_weights(
self.shape_model,
shape_ckpt_path,
)
self.shape_model = self.shape_model.eval().to(self.device)

# copy vq codebook to gpt
@@ -63,11 +91,15 @@ def __init__(
codebook = self.gpt_model.shape_proj(codebook).detach()
self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook

self.text_model = CLIPTextModelWithProjection.from_pretrained(
self.cfg.text_model_pretrained_model_name_or_path,
force_download=False,
device_map=self.device,
).eval()
if clip_config is not None:
self.text_model = CLIPTextModelWithProjection(clip_config)
self.text_model.to(self.device).eval()
else:
self.text_model = CLIPTextModelWithProjection.from_pretrained(
self.cfg.text_model_pretrained_model_name_or_path,
force_download=False,
device_map=self.device,
).eval()
self.text_tokenizer = CLIPTokenizerFast.from_pretrained(
self.cfg.text_model_pretrained_model_name_or_path
)
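Note on the new clip_config branch: constructing CLIPTextModelWithProjection directly from a CLIPTextConfig only builds the architecture, so the encoder weights presumably come from the Engine's own state dict once PyTorchModelHubMixin restores it. A minimal sketch of how a caller might obtain such a config up front (the checkpoint id below is an example, not necessarily the one the repo's YAML points at):

```python
from transformers import CLIPTextConfig, CLIPTextModelWithProjection

# Example checkpoint id (assumption); any CLIPTextModelWithProjection
# checkpoint exposes a CLIPTextConfig via .config.
clip = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
clip_config: CLIPTextConfig = clip.config

# Passing clip_config skips from_pretrained inside Engine.__init__; the text
# encoder is built from the config alone and its weights are expected to
# arrive via the engine state dict (e.g. when loading with from_pretrained).
# engine = Engine(cfg=cfg, clip_config=clip_config)  # cfg: an OmegaConf DictConfig
```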
@@ -194,7 +226,7 @@ def run_gpt(
embed.device,
)
with torch.autocast(self.device.type, dtype=torch.bfloat16):
for i in tqdm(range(self.max_new_tokens), desc=f"generating"):
for i in tqdm(range(self.max_new_tokens), desc="generating"):
curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
logits = self.gpt_model(
embed_buffer,
@@ -276,7 +308,7 @@ def t2s(
guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
If None, argmax selection is performed (deterministic generation). Otherwise, the smallest set of tokens with cumulative probability ≥ top_p is kept (stochastic generation).
Returns:
mesh_v_f: The generated 3D mesh vertices and faces.
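The top_p docstring above describes nucleus sampling; the actual filtering lives in process_logits, which this diff does not touch. As a generic illustration of the behaviour described (not the repository's implementation), top-p selection can be sketched as:

```python
import torch

def nucleus_sample(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Keep the smallest set of tokens whose cumulative probability reaches
    # top_p (the top-1 token is always kept), then sample from that set.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token if the mass accumulated *before* it already exceeds top_p.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice)
```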
@@ -286,14 +318,86 @@ def t2s(
mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
return mesh_v_f

def push_to_hub(
self,
repo_id,
*,
config=None,
commit_message="Push model using huggingface_hub.",
private=None,
token=None,
branch=None,
create_pr=None,
allow_patterns=None,
ignore_patterns=None,
delete_patterns=None,
model_card_kwargs=None,
):
api = HfApi(token=token)
repo_id = api.create_repo(
repo_id=repo_id, private=private, exist_ok=True
).repo_id
# set repo_id as an attribute if we're pushing to the hub
self._hub_mixin_repo_id = repo_id
out = super().push_to_hub(
repo_id,
config=config,
commit_message=commit_message,
private=private,
token=token,
branch=branch,
create_pr=create_pr,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
delete_patterns=delete_patterns,
model_card_kwargs=model_card_kwargs,
)
delattr(self, "_hub_mixin_repo_id")
return out

def save_pretrained(
self,
save_directory,
*,
config=None,
repo_id=None,
push_to_hub=False,
model_card_kwargs=None,
**push_to_hub_kwargs,
):
# update the text_model_pretrained_model_name_or_path parameter when pushing or saving
if hasattr(self, "_hub_mixin_repo_id"):
self._hub_mixin_config["cfg"][
"text_model_pretrained_model_name_or_path"
] = self._hub_mixin_repo_id
else:
self._hub_mixin_config["cfg"][
"text_model_pretrained_model_name_or_path"
] = save_directory
return super().save_pretrained(
save_directory,
config=config,
repo_id=repo_id,
push_to_hub=push_to_hub,
model_card_kwargs=model_card_kwargs,
**push_to_hub_kwargs,
)

def _save_pretrained(self, save_directory):
self.text_tokenizer.save_pretrained(save_directory)
return super()._save_pretrained(save_directory)


class EngineFast(Engine):
def __init__(
self,
config_path: str,
gpt_ckpt_path: str,
shape_ckpt_path: str,
device: torch.device,
config_path: Union[DictConfig, str],
gpt_ckpt_path: Optional[str] = None,
shape_ckpt_path: Optional[str] = None,
device: torch.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
clip_config: CLIPTextConfig = None,
):
"""
Initializes the inference engine with the given configuration and checkpoint paths.
@@ -304,11 +408,13 @@ def __init__(
device (torch.device): The device to run the inference on (e.g., CPU or CUDA).
"""

assert (
device.type == "cuda"
), "EngineFast is only supported on cuda devices, please use Engine on non-cuda devices"
assert device.type == "cuda", (
"EngineFast is only supported on cuda devices, please use Engine on non-cuda devices"
)

super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, device)
super().__init__(
config_path, gpt_ckpt_path, shape_ckpt_path, device, clip_config=clip_config
)

# CUDA Graph params
self.graph = torch.cuda.CUDAGraph()
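EngineFast keeps a torch.cuda.CUDAGraph that run_gpt later replays (self.graph.replay() below). For readers unfamiliar with the pattern, this is the generic capture/replay idiom it relies on; the snippet is an illustration, not the capture code from this file:

```python
import torch

model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.zeros(1, 16, device="cuda")

# Warm-up on a side stream is recommended before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one fixed-shape forward pass into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the static buffer, then rerun the recorded kernels.
static_input.copy_(torch.randn(1, 16, device="cuda"))
graph.replay()  # static_output now holds the result for the new input
```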
@@ -428,11 +534,11 @@ def _set_curr_pos_id(self, pos: int):
)

def run_gpt(
self,
prompts: list[str],
use_kv_cache: bool,
self,
prompts: list[str],
use_kv_cache: bool,
guidance_scale: float = 3.0,
top_p: float = None
top_p: float = None,
):
"""
Runs the GPT model to generate text based on the provided prompts.
@@ -479,9 +585,7 @@ def run_gpt(
next_embed = next_embed.repeat(2, 1, 1)
self.embed_buffer[:, input_seq_len, :].copy_(next_embed.squeeze(1))

for i in tqdm(
range(1, self.max_new_tokens), desc=f"generating"
):
for i in tqdm(range(1, self.max_new_tokens), desc="generating"):
self._set_curr_pos_id(i)
self.graph.replay()
