Skip to content

Add Apple Silicon (MPS) support #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions scripts/image_process.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,7 @@ def find_bounding_box(gray_image):
x, y, w, h = cv2.boundingRect(max_contour)
return x, y, w, h

def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1):
def load_image(img_path: str, bg_color: np.ndarray = None, rmbg_net=None, padding_ratio: float = 0.1, device: str = "cuda"):
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
if img is None:
return f"invalid image path {img_path}"
Expand Down Expand Up @@ -72,13 +72,13 @@ def rmbg(image: torch.Tensor) -> torch.Tensor:
else:
return f"invalid image: channels {num_channels}"

rgb_image_gpu = torch.from_numpy(rgb_image).cuda().float().permute(2, 0, 1) / 255.
rgb_image_gpu = torch.from_numpy(rgb_image).to(device).float().permute(2, 0, 1) / 255.
if alpha is None:
resize_transform = transforms.Resize((384, 384), antialias=True)
rgb_image_resized = resize_transform(rgb_image_gpu)
normalize_image = rgb_image_resized * 2 - 1

mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).cuda()
mean_color = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
resize_transform = transforms.Resize((1024, 1024), antialias=True)
rgb_image_resized = resize_transform(rgb_image_gpu)
max_value = rgb_image_resized.flatten().max()
Expand All @@ -105,7 +105,7 @@ def rmbg(image: torch.Tensor) -> torch.Tensor:
cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200)
cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8)
alpha = cleaned_alpha * 255
alpha_gpu = torch.from_numpy(cleaned_alpha).cuda().float().unsqueeze(0)
alpha_gpu = torch.from_numpy(cleaned_alpha).to(device).float().unsqueeze(0)
x, y, w, h = find_bounding_box(alpha)

# If alpha is provided, the bounds of all foreground are used
Expand All @@ -125,7 +125,7 @@ def rmbg(image: torch.Tensor) -> torch.Tensor:
raise ValueError(f"input image too small")

bg_gray = bg_color[0]
bg_color = torch.from_numpy(bg_color).float().cuda().repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
bg_color = torch.from_numpy(bg_color).float().to(device).repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu)
padding_size = [0] * 6
if w > h:
Expand All @@ -140,9 +140,9 @@ def rmbg(image: torch.Tensor) -> torch.Tensor:

return padded_tensor

def prepare_image(image_path, bg_color, rmbg_net=None):
def prepare_image(image_path: str, bg_color: np.ndarray, rmbg_net=None, device: str = "cuda"):
if os.path.isfile(image_path):
img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net)
img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net, device=device)
img_np = img_tensor.permute(1,2,0).cpu().numpy()
img_pil = Image.fromarray((img_np*255).astype(np.uint8))

Expand Down
29 changes: 25 additions & 4 deletions scripts/inference_triposg.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,13 +28,24 @@ def run_triposg(
num_inference_steps: int = 50,
guidance_scale: float = 7.0,
faces: int = -1,
device: str = "cuda",
use_flash_decoder: bool = True,
) -> trimesh.Scene:

img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net, device=device)

effective_use_flash_decoder = use_flash_decoder
if device == 'mps' and use_flash_decoder:
try:
import diso
print("Note: Using flash_decoder on MPS. If 'diso' library is not fully compatible, issues might occur or performance might vary.")
except ImportError:
print("Warning: 'diso' library not found. 'flash_decoder' cannot be used. Falling back to hierarchical_extract_geometry.")
effective_use_flash_decoder = False

outputs = pipe(
image=img_pil,
generator=torch.Generator(device=pipe.device).manual_seed(seed),
generator=torch.Generator(device=device).manual_seed(seed),
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).samples[0]
Expand Down Expand Up @@ -66,8 +77,15 @@ def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
return mesh

if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
if torch.backends.mps.is_available():
device = "mps"
dtype = torch.float32
elif torch.cuda.is_available():
device = "cuda"
dtype = torch.float16
else:
device = "cpu"
dtype = torch.float32

parser = argparse.ArgumentParser()
parser.add_argument("--image-input", type=str, required=True)
Expand All @@ -76,6 +94,7 @@ def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
parser.add_argument("--num-inference-steps", type=int, default=50)
parser.add_argument("--guidance-scale", type=float, default=7.0)
parser.add_argument("--faces", type=int, default=-1)
parser.add_argument("--use-flash-decoder", action=argparse.BooleanOptionalAction, default=True)
args = parser.parse_args()

# download pretrained weights
Expand All @@ -100,5 +119,7 @@ def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
faces=args.faces,
device=device,
use_flash_decoder=args.use_flash_decoder,
).export(args.output_path)
print(f"Mesh saved to {args.output_path}")
12 changes: 9 additions & 3 deletions scripts/inference_triposg_scribble.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,7 +32,7 @@ def run_triposg_scribble(
outputs = pipe(
image=img_pil,
prompt=prompt,
generator=torch.Generator(device=pipe.device).manual_seed(seed),
generator=torch.Generator(device=pipe.device if hasattr(pipe, 'device') else "cpu").manual_seed(seed),
num_inference_steps=num_inference_steps,
guidance_scale=0, # this is a CFG-distilled model
attention_kwargs={"cross_attention_scale": prompt_confidence, "cross_attention_2_scale": scribble_confidence},
Expand All @@ -44,8 +44,14 @@ def run_triposg_scribble(


if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

dtype = torch.float16 if device != "cpu" else torch.float32

parser = argparse.ArgumentParser()
parser.add_argument("--image-input", type=str, required=True)
Expand Down
13 changes: 10 additions & 3 deletions scripts/inference_vae.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,14 +18,21 @@ def load_surface(data_path, num_pc=204800):
ind = rng.choice(surface.shape[0], num_pc, replace=False)
surface = torch.FloatTensor(surface[ind])
normal = torch.FloatTensor(normal[ind])
surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
surface = torch.cat([surface, normal], dim=-1).unsqueeze(0)

return surface


if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

dtype = torch.float16 if device != "cpu" else torch.float32

parser = argparse.ArgumentParser()
parser.add_argument("--surface-input", type=str, required=True)
args = parser.parse_args()
Expand Down
40 changes: 29 additions & 11 deletions triposg/inference_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,11 +4,15 @@
import scipy.ndimage
from skimage import measure
from einops import repeat
from diso import DiffDMC
import torch.nn.functional as F

from triposg.utils.typing import *

try:
from diso import DiffDMC
except ImportError:
DiffDMC = None

def generate_dense_grid_points_gpu(bbox_min: torch.Tensor,
bbox_max: torch.Tensor,
octree_depth: int,
Expand Down Expand Up @@ -98,7 +102,7 @@ def find_candidates_band(occupancy_grid: torch.Tensor, band_threshold: float, n_
return core_mesh_coords

def expand_edge_region_fast(edge_coords, grid_size):
expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device='cuda', dtype=torch.float16, requires_grad=False)
expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device=edge_coords.device, dtype=torch.float16, requires_grad=False)
expanded_tensor[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1
if grid_size < 512:
kernel_size = 5
Expand Down Expand Up @@ -186,7 +190,10 @@ def hierarchical_extract_geometry(geometric_func: Callable,
# breakpoint()
high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values
grid_logits = high_res_occupancy
torch.cuda.empty_cache()
if device.type == 'cuda':
torch.cuda.empty_cache()
elif device.type == 'mps':
torch.mps.empty_cache()
mesh_v_f = []
try:
print("final grids shape = ", grid_logits.shape)
Expand All @@ -195,7 +202,10 @@ def hierarchical_extract_geometry(geometric_func: Callable,
mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
except Exception as e:
print(e)
torch.cuda.empty_cache()
if device.type == 'cuda':
torch.cuda.empty_cache()
elif device.type == 'mps':
torch.mps.empty_cache()
mesh_v_f = (None, None)

return [mesh_v_f]
Expand Down Expand Up @@ -463,17 +473,25 @@ def flash_extract_geometry(
grid_logits = grid_logits[0]
try:
print("final grids shape = ", grid_logits.shape)
dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
sdf = -grid_logits / octree_resolution
sdf = sdf.to(torch.float32).contiguous()
vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
vertices = vertices.detach().cpu().numpy()
faces = faces.detach().cpu().numpy()[:, ::-1]
if grid_logits.device.type == 'mps':
print("Warning: DiffDMC (diso library) in flash_extract_geometry might not be compatible with MPS. Using skimage.measure.marching_cubes on CPU as a fallback for this specific call if DiffDMC fails or is unavailable.")
grid_logits_cpu = grid_logits.float().cpu().numpy()
vertices, faces, _, _ = measure.marching_cubes(grid_logits_cpu, mc_level, method="lewiner")
else:
dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
sdf = -grid_logits / octree_resolution
sdf = sdf.to(torch.float32).contiguous()
vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
vertices = vertices.detach().cpu().numpy()
faces = faces.detach().cpu().numpy()[:, ::-1]
vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
except Exception as e:
print(e)
torch.cuda.empty_cache()
if latents.device.type == 'cuda':
torch.cuda.empty_cache()
elif latents.device.type == 'mps':
torch.mps.empty_cache()
mesh_v_f = (None, None)

return [mesh_v_f]
2 changes: 1 addition & 1 deletion triposg/models/autoencoders/autoencoder_kl_triposg.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -158,7 +158,7 @@ def query_geometry(
):
logits = model_fn(queries, sample)
if grad:
with torch.autocast(device_type="cuda", dtype=torch.float32):
with torch.autocast(device_type=queries.device.type, dtype=torch.float32, enabled=queries.device.type != 'cpu'):
if self.grad_type == "numerical":
interval = self.grad_interval
grad_value = []
Expand Down
17 changes: 14 additions & 3 deletions triposg/pipelines/pipeline_triposg.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -295,11 +295,22 @@ def __call__(


# 7. decoder mesh
if not use_flash_decoder:
effective_use_flash_decoder = use_flash_decoder
if self.device.type == 'mps' and use_flash_decoder:
try:
import diso # type: ignore
# If diso imports, we assume it might work, but it's experimental on MPS.
# A more robust check would involve testing a small diso operation.
logger.warn("Using flash_decoder on MPS. The 'diso' library's compatibility with MPS is not fully guaranteed. If issues arise, consider setting use_flash_decoder=False.")
except ImportError:
logger.warn("'diso' library not found. 'flash_decoder' cannot be used. Falling back to hierarchical_extract_geometry.")
effective_use_flash_decoder = False

if not effective_use_flash_decoder:
geometric_func = lambda x: self.vae.decode(latents, sampled_points=x).sample
output = hierarchical_extract_geometry(
geometric_func,
device,
self.device,
bounds=bounds,
dense_octree_depth=dense_octree_depth,
hierarchical_octree_depth=hierarchical_octree_depth,
Expand All @@ -312,7 +323,7 @@ def __call__(
bounds=bounds,
octree_depth=flash_octree_depth,
)
meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output]
meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output if mesh_v_f[0] is not None and mesh_v_f[1] is not None]

# Offload all models
self.maybe_free_model_hooks()
Expand Down
6 changes: 3 additions & 3 deletions triposg/pipelines/pipeline_triposg_scribble.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -203,7 +203,7 @@ def __call__(
dense_octree_depth: int = 8,
hierarchical_octree_depth: int = 9,
flash_octree_depth: int = 9,
use_flash_decoder: bool = True,
use_flash_decoder: bool = False, # Defaulting to False due to boundary problems and for MPS compatibility (diso library)
return_dict: bool = True,
):
self._guidance_scale = guidance_scale
Expand Down Expand Up @@ -250,7 +250,7 @@ def __call__(
num_tokens,
num_channels_latents,
image_embeds.dtype,
device,
self.device,
generator,
latents,
)
Expand Down Expand Up @@ -327,7 +327,7 @@ def __call__(
bounds=bounds,
octree_depth=flash_octree_depth,
)
meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output]
meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output if mesh_v_f[0] is not None and mesh_v_f[1] is not None]

# Offload all models
self.maybe_free_model_hooks()
Expand Down