diff --git a/nodes/nodes_inference.py b/nodes/nodes_inference.py index aeae0764..568afea0 100644 --- a/nodes/nodes_inference.py +++ b/nodes/nodes_inference.py @@ -106,6 +106,9 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params # Normalize with ImageNet stats (manual, no torchvision dependency) normalized_images = imagenet_normalize(images_pt) + # Free images_pt early — we'll reconstruct RGB output from original `images` later + del images_pt + # Prepare for model: add view dimension [B, N, 3, H, W] where N=1 normalized_images = normalized_images.unsqueeze(1) @@ -123,12 +126,15 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params else: logger.warning("Model does not support camera conditioning. Camera params ignored.") - pbar = ProgressBar(B) - depth_out = [] - conf_out = [] - sky_out = [] - ray_origin_out = [] - ray_dir_out = [] + # Pre-allocate contiguous output tensors to avoid memory fragmentation + # from thousands of small tensor allocations in the loop + depth_out = torch.zeros(B, 1, model_H, model_W) + conf_out = torch.zeros(B, 1, model_H, model_W) + sky_out = torch.zeros(B, 1, model_H, model_W) + # Ray tensors are lazily allocated after first frame since model may + # output rays at a different resolution than the input + ray_origin_out = None + ray_dir_out = None extrinsics_list = [] intrinsics_list = [] gaussians_list = [] @@ -138,6 +144,7 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params if infer_gs: logger.info("Model supports 3D Gaussians - will output raw Gaussians") + pbar = ProgressBar(B) for i in range(B): img = normalized_images[i:i+1].to(device, dtype=dtype) @@ -198,9 +205,10 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params else: conf = torch.ones_like(conf) - depth_out.append(depth_processed.cpu()) - conf_out.append(conf.cpu()) - sky_out.append(sky.cpu()) + # Write directly into pre-allocated tensors (squeeze batch dim from model output) + depth_out[i] = depth_processed.squeeze(0).cpu() + conf_out[i] = conf.squeeze(0).cpu() + sky_out[i] = sky.squeeze(0).cpu() # Extract ray maps (if available) ray = None @@ -211,13 +219,21 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params if ray is not None and torch.is_tensor(ray): ray = ray.squeeze(0).squeeze(0) # [6, H, W] - ray_origin = ray[:3] - ray_dir = ray[3:6] - ray_origin_out.append(ray_origin.cpu()) - ray_dir_out.append(ray_dir.cpu()) + ray_origin = ray[:3].cpu() + ray_dir = ray[3:6].cpu() + # Lazily allocate on first frame with actual ray dimensions + if ray_origin_out is None: + ray_H, ray_W = ray_origin.shape[1], ray_origin.shape[2] + ray_origin_out = torch.zeros(B, 3, ray_H, ray_W) + ray_dir_out = torch.zeros(B, 3, ray_H, ray_W) + ray_origin_out[i] = ray_origin + ray_dir_out[i] = ray_dir else: - ray_origin_out.append(torch.zeros(3, depth.shape[-2], depth.shape[-1])) - ray_dir_out.append(torch.zeros(3, depth.shape[-2], depth.shape[-1])) + # Lazily allocate with depth dimensions as fallback + if ray_origin_out is None: + d_H, d_W = depth.shape[-2], depth.shape[-1] + ray_origin_out = torch.zeros(B, 3, d_H, d_W) + ray_dir_out = torch.zeros(B, 3, d_H, d_W) # Extract camera parameters (if available) extr = None @@ -256,23 +272,36 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params pbar.update(1) + # Free normalized_images now that the loop is done + del normalized_images + # Process outputs based on normalization mode normalize_depth_output = (normalization_mode != "Raw") depth_final = process_tensor_to_image(depth_out, orig_H, orig_W, normalize_output=normalize_depth_output, skip_resize=keep_model_size) + del depth_out conf_final = process_tensor_to_image(conf_out, orig_H, orig_W, normalize_output=True, skip_resize=keep_model_size) + del conf_out sky_final = process_tensor_to_mask(sky_out, orig_H, orig_W, skip_resize=keep_model_size) + del sky_out + # Fallback if rays were never allocated (no frames processed) + if ray_origin_out is None: + ray_origin_out = torch.zeros(B, 3, model_H, model_W) + ray_dir_out = torch.zeros(B, 3, model_H, model_W) ray_origin_final = cls._process_ray_to_image(ray_origin_out, orig_H, orig_W, normalize=True, skip_resize=keep_model_size) + del ray_origin_out ray_dir_final = cls._process_ray_to_image(ray_dir_out, orig_H, orig_W, normalize=True, skip_resize=keep_model_size) + del ray_dir_out - # Process resized RGB image to match depth output dimensions - rgb_resized = images_pt.permute(0, 2, 3, 1).float().cpu() + # Reconstruct RGB output from original images input (already on CPU) + # instead of keeping images_pt alive for the entire inference loop + rgb_resized = images.float() # [B, H, W, C] already on CPU if not keep_model_size: final_H = (orig_H // 2) * 2 final_W = (orig_W // 2) * 2 @@ -282,13 +311,20 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params size=(final_H, final_W), mode="bilinear" ).permute(0, 2, 3, 1) - rgb_resized = torch.clamp(rgb_resized, 0, 1) + else: + # Resize to model dimensions when keeping model size + if rgb_resized.shape[1] != model_H or rgb_resized.shape[2] != model_W: + rgb_resized = F.interpolate( + rgb_resized.permute(0, 3, 1, 2), + size=(model_H, model_W), + mode="bilinear" + ).permute(0, 2, 3, 1) + rgb_resized.clamp_(0, 1) # Scale intrinsics if we resized back to original dimensions if not keep_model_size: final_H = (orig_H // 2) * 2 final_W = (orig_W // 2) * 2 - model_H, model_W = images_pt.shape[2], images_pt.shape[3] if final_H != model_H or final_W != model_W: scale_h = final_H / model_H @@ -311,7 +347,7 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params if extrinsics_list and extrinsics_list[0] is not None: extrinsics_tensor = torch.stack([e.squeeze() for e in extrinsics_list if e is not None], dim=0) else: - extrinsics_tensor = torch.eye(4).unsqueeze(0).expand(len(depth_out), -1, -1) + extrinsics_tensor = torch.eye(4).unsqueeze(0).expand(B, -1, -1) if intrinsics_list and intrinsics_list[0] is not None: # Convert 3x3 intrinsics to 4x4 homogeneous (compatible with Sharp) @@ -327,7 +363,7 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params intr_tensors.append(k) intrinsics_tensor = torch.stack(intr_tensors, dim=0) else: - intrinsics_tensor = torch.eye(4).unsqueeze(0).expand(len(depth_out), -1, -1) + intrinsics_tensor = torch.eye(4).unsqueeze(0).expand(B, -1, -1) # Save Gaussians to PLY file if available (Giant model only) gaussian_ply_path = "" @@ -353,9 +389,16 @@ def execute(cls, da3_model, images, normalization_mode="V2-Style", camera_params extrinsics_str, intrinsics_str, sky_final, extrinsics_tensor, intrinsics_tensor, gaussian_ply_path) @staticmethod - def _process_ray_to_image(ray_list, orig_H, orig_W, normalize=True, skip_resize=False): - """Convert list of ray tensors to ComfyUI IMAGE format.""" - out = torch.cat([r.unsqueeze(0) for r in ray_list], dim=0) + def _process_ray_to_image(ray_input, orig_H, orig_W, normalize=True, skip_resize=False): + """Convert ray tensors to ComfyUI IMAGE format. + + Args: + ray_input: Pre-allocated tensor [B, 3, H, W] or list of tensors [3, H, W] + """ + if isinstance(ray_input, list): + out = torch.cat([r.unsqueeze(0) for r in ray_input], dim=0) + else: + out = ray_input if normalize: for i in range(out.shape[0]): @@ -365,7 +408,7 @@ def _process_ray_to_image(ray_list, orig_H, orig_W, normalize=True, skip_resize= if ray_max > ray_min: out[i] = (ray_batch - ray_min) / (ray_max - ray_min) else: - out[i] = torch.zeros_like(ray_batch) + out[i].zero_() out = out.permute(0, 2, 3, 1).float() @@ -381,9 +424,8 @@ def _process_ray_to_image(ray_list, orig_H, orig_W, normalize=True, skip_resize= ).permute(0, 2, 3, 1) if normalize: - return torch.clamp(out, 0, 1) - else: - return out + out.clamp_(0, 1) + return out diff --git a/nodes/utils.py b/nodes/utils.py index fdaec4da..df49608b 100644 --- a/nodes/utils.py +++ b/nodes/utils.py @@ -94,11 +94,11 @@ def check_model_capabilities(model): } -def process_tensor_to_image(tensor_list, orig_H, orig_W, normalize_output=False, skip_resize=False): - """Convert list of depth/conf tensors to ComfyUI IMAGE format. +def process_tensor_to_image(tensor_input, orig_H, orig_W, normalize_output=False, skip_resize=False): + """Convert depth/conf tensors to ComfyUI IMAGE format. Args: - tensor_list: List of tensors with shape [1, H, W] or [H, W] + tensor_input: Pre-allocated tensor [B, 1, H, W] or list of tensors with shape [1, H, W] or [H, W] orig_H: Original image height orig_W: Original image width normalize_output: If True, clamp output to 0-1 range @@ -107,8 +107,10 @@ def process_tensor_to_image(tensor_list, orig_H, orig_W, normalize_output=False, Returns: Tensor with shape [B, H, W, 3] in ComfyUI IMAGE format """ - # Concatenate all tensors - out = torch.cat(tensor_list, dim=0) # [B, 1, H, W] or [B, H, W] + if isinstance(tensor_input, list): + out = torch.cat(tensor_input, dim=0) + else: + out = tensor_input # Ensure 4D: [B, 1, H, W] if out.dim() == 3: @@ -131,15 +133,15 @@ def process_tensor_to_image(tensor_list, orig_H, orig_W, normalize_output=False, ).permute(0, 2, 3, 1) if normalize_output: - return torch.clamp(out, 0, 1) + out.clamp_(0, 1) return out -def process_tensor_to_mask(tensor_list, orig_H, orig_W, skip_resize=False): - """Convert list of tensors to ComfyUI MASK format. +def process_tensor_to_mask(tensor_input, orig_H, orig_W, skip_resize=False): + """Convert tensors to ComfyUI MASK format. Args: - tensor_list: List of tensors with shape [1, H, W] or [H, W] + tensor_input: Pre-allocated tensor [B, 1, H, W] or list of tensors with shape [1, H, W] or [H, W] orig_H: Original image height orig_W: Original image width skip_resize: If True, keep model's native output size instead of resizing back @@ -147,8 +149,10 @@ def process_tensor_to_mask(tensor_list, orig_H, orig_W, skip_resize=False): Returns: Tensor with shape [B, H, W] in ComfyUI MASK format """ - # Concatenate all tensors - out = torch.cat(tensor_list, dim=0) # [B, 1, H, W] or [B, H, W] + if isinstance(tensor_input, list): + out = torch.cat(tensor_input, dim=0) + else: + out = tensor_input # Ensure 3D: [B, H, W] if out.dim() == 4: @@ -168,7 +172,8 @@ def process_tensor_to_mask(tensor_list, orig_H, orig_W, skip_resize=False): mode="bilinear" ).squeeze(1) # Back to [B, H, W] - return torch.clamp(out, 0, 1) + out.clamp_(0, 1) + return out def resize_to_patch_multiple(images_pt, patch_size=DEFAULT_PATCH_SIZE, method="resize"):