Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 70 additions & 28 deletions nodes/nodes_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
# Normalize with ImageNet stats (manual, no torchvision dependency)
normalized_images = imagenet_normalize(images_pt)

# Free images_pt early — we'll reconstruct RGB output from original `images` later
del images_pt

# Prepare for model: add view dimension [B, N, 3, H, W] where N=1
normalized_images = normalized_images.unsqueeze(1)

Expand All @@ -123,12 +126,15 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
else:
logger.warning("Model does not support camera conditioning. Camera params ignored.")

pbar = ProgressBar(B)
depth_out = []
conf_out = []
sky_out = []
ray_origin_out = []
ray_dir_out = []
# Pre-allocate contiguous output tensors to avoid memory fragmentation
# from thousands of small tensor allocations in the loop
depth_out = torch.zeros(B, 1, model_H, model_W)
conf_out = torch.zeros(B, 1, model_H, model_W)
sky_out = torch.zeros(B, 1, model_H, model_W)
# Ray tensors are allocated lazily while processing the first frame, because
# the model may output rays at a different resolution than the input
ray_origin_out = None
ray_dir_out = None
extrinsics_list = []
intrinsics_list = []
gaussians_list = []
Expand All @@ -138,6 +144,7 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
if infer_gs:
logger.info("Model supports 3D Gaussians - will output raw Gaussians")

pbar = ProgressBar(B)
for i in range(B):
img = normalized_images[i:i+1].to(device, dtype=dtype)

Expand Down Expand Up @@ -198,9 +205,10 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
else:
conf = torch.ones_like(conf)

depth_out.append(depth_processed.cpu())
conf_out.append(conf.cpu())
sky_out.append(sky.cpu())
# Write directly into pre-allocated tensors (squeeze batch dim from model output)
depth_out[i] = depth_processed.squeeze(0).cpu()
conf_out[i] = conf.squeeze(0).cpu()
sky_out[i] = sky.squeeze(0).cpu()

# Extract ray maps (if available)
ray = None
Expand All @@ -211,13 +219,21 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params

if ray is not None and torch.is_tensor(ray):
ray = ray.squeeze(0).squeeze(0) # [6, H, W]
ray_origin = ray[:3]
ray_dir = ray[3:6]
ray_origin_out.append(ray_origin.cpu())
ray_dir_out.append(ray_dir.cpu())
ray_origin = ray[:3].cpu()
ray_dir = ray[3:6].cpu()
# Lazily allocate on first frame with actual ray dimensions
if ray_origin_out is None:
ray_H, ray_W = ray_origin.shape[1], ray_origin.shape[2]
ray_origin_out = torch.zeros(B, 3, ray_H, ray_W)
ray_dir_out = torch.zeros(B, 3, ray_H, ray_W)
ray_origin_out[i] = ray_origin
ray_dir_out[i] = ray_dir
else:
ray_origin_out.append(torch.zeros(3, depth.shape[-2], depth.shape[-1]))
ray_dir_out.append(torch.zeros(3, depth.shape[-2], depth.shape[-1]))
# Lazily allocate with depth dimensions as fallback
if ray_origin_out is None:
d_H, d_W = depth.shape[-2], depth.shape[-1]
ray_origin_out = torch.zeros(B, 3, d_H, d_W)
ray_dir_out = torch.zeros(B, 3, d_H, d_W)

# Extract camera parameters (if available)
extr = None
Expand Down Expand Up @@ -256,23 +272,36 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params

pbar.update(1)

# Free normalized_images now that the loop is done
del normalized_images

# Process outputs based on normalization mode
normalize_depth_output = (normalization_mode != "Raw")

depth_final = process_tensor_to_image(depth_out, orig_H, orig_W,
normalize_output=normalize_depth_output,
skip_resize=keep_model_size)
del depth_out
conf_final = process_tensor_to_image(conf_out, orig_H, orig_W,
normalize_output=True,
skip_resize=keep_model_size)
del conf_out
sky_final = process_tensor_to_mask(sky_out, orig_H, orig_W, skip_resize=keep_model_size)
del sky_out
# Fallback if rays were never allocated (no frames processed)
if ray_origin_out is None:
ray_origin_out = torch.zeros(B, 3, model_H, model_W)
ray_dir_out = torch.zeros(B, 3, model_H, model_W)
ray_origin_final = cls._process_ray_to_image(ray_origin_out, orig_H, orig_W,
normalize=True, skip_resize=keep_model_size)
del ray_origin_out
ray_dir_final = cls._process_ray_to_image(ray_dir_out, orig_H, orig_W,
normalize=True, skip_resize=keep_model_size)
del ray_dir_out

# Process resized RGB image to match depth output dimensions
rgb_resized = images_pt.permute(0, 2, 3, 1).float().cpu()
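# Reconstruct the RGB output from the original `images` input (already on CPU)
# instead of keeping images_pt alive for the entire inference loop.
# NOTE(review): `.float()` returns the same tensor when `images` is already
# float32, so the in-place clamp_ below may modify the caller's input — confirm.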
rgb_resized = images.float() # [B, H, W, C] already on CPU
if not keep_model_size:
final_H = (orig_H // 2) * 2
final_W = (orig_W // 2) * 2
Expand All @@ -282,13 +311,20 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
size=(final_H, final_W),
mode="bilinear"
).permute(0, 2, 3, 1)
rgb_resized = torch.clamp(rgb_resized, 0, 1)
else:
# Resize to model dimensions when keeping model size
if rgb_resized.shape[1] != model_H or rgb_resized.shape[2] != model_W:
rgb_resized = F.interpolate(
rgb_resized.permute(0, 3, 1, 2),
size=(model_H, model_W),
mode="bilinear"
).permute(0, 2, 3, 1)
rgb_resized.clamp_(0, 1)

# Scale intrinsics if we resized back to original dimensions
if not keep_model_size:
final_H = (orig_H // 2) * 2
final_W = (orig_W // 2) * 2
model_H, model_W = images_pt.shape[2], images_pt.shape[3]

if final_H != model_H or final_W != model_W:
scale_h = final_H / model_H
Expand All @@ -311,7 +347,7 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
if extrinsics_list and extrinsics_list[0] is not None:
extrinsics_tensor = torch.stack([e.squeeze() for e in extrinsics_list if e is not None], dim=0)
else:
extrinsics_tensor = torch.eye(4).unsqueeze(0).expand(len(depth_out), -1, -1)
extrinsics_tensor = torch.eye(4).unsqueeze(0).expand(B, -1, -1)

if intrinsics_list and intrinsics_list[0] is not None:
# Convert 3x3 intrinsics to 4x4 homogeneous (compatible with Sharp)
Expand All @@ -327,7 +363,7 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
intr_tensors.append(k)
intrinsics_tensor = torch.stack(intr_tensors, dim=0)
else:
intrinsics_tensor = torch.eye(4).unsqueeze(0).expand(len(depth_out), -1, -1)
intrinsics_tensor = torch.eye(4).unsqueeze(0).expand(B, -1, -1)

# Save Gaussians to PLY file if available (Giant model only)
gaussian_ply_path = ""
Expand All @@ -353,9 +389,16 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params
extrinsics_str, intrinsics_str, sky_final, extrinsics_tensor, intrinsics_tensor, gaussian_ply_path)

@staticmethod
def _process_ray_to_image(ray_list, orig_H, orig_W, normalize=True, skip_resize=False):
"""Convert list of ray tensors to ComfyUI IMAGE format."""
out = torch.cat([r.unsqueeze(0) for r in ray_list], dim=0)
def _process_ray_to_image(ray_input, orig_H, orig_W, normalize=True, skip_resize=False):
"""Convert ray tensors to ComfyUI IMAGE format.

Args:
ray_input: Pre-allocated tensor [B, 3, H, W] or list of tensors [3, H, W]
"""
if isinstance(ray_input, list):
out = torch.cat([r.unsqueeze(0) for r in ray_input], dim=0)
else:
out = ray_input

if normalize:
for i in range(out.shape[0]):
Expand All @@ -365,7 +408,7 @@ def _process_ray_to_image(ray_list, orig_H, orig_W, normalize=True, skip_resize=
if ray_max > ray_min:
out[i] = (ray_batch - ray_min) / (ray_max - ray_min)
else:
out[i] = torch.zeros_like(ray_batch)
out[i].zero_()

out = out.permute(0, 2, 3, 1).float()

Expand All @@ -381,9 +424,8 @@ def _process_ray_to_image(ray_list, orig_H, orig_W, normalize=True, skip_resize=
).permute(0, 2, 3, 1)

if normalize:
return torch.clamp(out, 0, 1)
else:
return out
out.clamp_(0, 1)
return out



29 changes: 17 additions & 12 deletions nodes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def check_model_capabilities(model):
}


def process_tensor_to_image(tensor_list, orig_H, orig_W, normalize_output=False, skip_resize=False):
"""Convert list of depth/conf tensors to ComfyUI IMAGE format.
def process_tensor_to_image(tensor_input, orig_H, orig_W, normalize_output=False, skip_resize=False):
"""Convert depth/conf tensors to ComfyUI IMAGE format.

Args:
tensor_list: List of tensors with shape [1, H, W] or [H, W]
tensor_input: Pre-allocated tensor [B, 1, H, W] or list of tensors with shape [1, H, W] or [H, W]
orig_H: Original image height
orig_W: Original image width
normalize_output: If True, clamp output to 0-1 range
Expand All @@ -107,8 +107,10 @@ def process_tensor_to_image(tensor_list, orig_H, orig_W, normalize_output=False,
Returns:
Tensor with shape [B, H, W, 3] in ComfyUI IMAGE format
"""
# Concatenate all tensors
out = torch.cat(tensor_list, dim=0) # [B, 1, H, W] or [B, H, W]
if isinstance(tensor_input, list):
out = torch.cat(tensor_input, dim=0)
else:
out = tensor_input

# Ensure 4D: [B, 1, H, W]
if out.dim() == 3:
Expand All @@ -131,24 +133,26 @@ def process_tensor_to_image(tensor_list, orig_H, orig_W, normalize_output=False,
).permute(0, 2, 3, 1)

if normalize_output:
return torch.clamp(out, 0, 1)
out.clamp_(0, 1)
return out


def process_tensor_to_mask(tensor_list, orig_H, orig_W, skip_resize=False):
"""Convert list of tensors to ComfyUI MASK format.
def process_tensor_to_mask(tensor_input, orig_H, orig_W, skip_resize=False):
"""Convert tensors to ComfyUI MASK format.

Args:
tensor_list: List of tensors with shape [1, H, W] or [H, W]
tensor_input: Pre-allocated tensor [B, 1, H, W] or list of tensors with shape [1, H, W] or [H, W]
orig_H: Original image height
orig_W: Original image width
skip_resize: If True, keep model's native output size instead of resizing back

Returns:
Tensor with shape [B, H, W] in ComfyUI MASK format
"""
# Concatenate all tensors
out = torch.cat(tensor_list, dim=0) # [B, 1, H, W] or [B, H, W]
if isinstance(tensor_input, list):
out = torch.cat(tensor_input, dim=0)
else:
out = tensor_input

# Ensure 3D: [B, H, W]
if out.dim() == 4:
Expand All @@ -168,7 +172,8 @@ def process_tensor_to_mask(tensor_list, orig_H, orig_W, skip_resize=False):
mode="bilinear"
).squeeze(1) # Back to [B, H, W]

return torch.clamp(out, 0, 1)
out.clamp_(0, 1)
return out


def resize_to_patch_multiple(images_pt, patch_size=DEFAULT_PATCH_SIZE, method="resize"):
Expand Down