diff --git a/.gitignore b/.gitignore index 64ab246..9f22e20 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,4 @@ src/backend/models/* src/backend/models/*.pt src/backend/models/*.onnx src/backend/models/midas_cache/ +src/backend/checkpoints/ diff --git a/Makefile b/Makefile index b054375..d926114 100644 --- a/Makefile +++ b/Makefile @@ -255,3 +255,11 @@ download-depth-anything: cd src/backend && uv run python ../../scripts/download_models.py \ --models depth-anything \ --output-dir $(MODELS_DIR) + +download-depth-pro: + @echo "Downloading Depth Pro model..." + @mkdir -p $(MODELS_DIR) + cd src/backend && uv sync --extra inference + cd src/backend && uv run python ../../scripts/download_models.py \ + --models depth-pro \ + --output-dir $(MODELS_DIR) diff --git a/README.md b/README.md index 7377f5b..86a2e46 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ make download-models-onnx make download-yolo make download-midas make download-depth-anything +make download-depth-pro # Export models to ONNX make export-yolo-onnx @@ -89,11 +90,18 @@ To start the analyzer service with ONNX backend: DETECTOR_BACKEND=onnx DEPTH_BACKEND=onnx make run-analyzer-local ``` + To start the analyzer service with Depth Anything V2 backend: ```bash DEPTH_BACKEND=depth_anything_v2 make run-analyzer-local ``` +To start the analyzer service with Apple's ML Depth Pro backend: +```bash +DEPTH_BACKEND=depth_pro make run-analyzer-local +``` +*Note: Depth Pro model weights are approx. 
1.8 GB and will be downloaded automatically (or via `make download-depth-pro`).* + Example production usage with custom model type: ```bash # Set model type via environment variable diff --git a/scripts/download_models.py b/scripts/download_models.py index 05d6347..2d16ca8 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -45,6 +45,7 @@ export_yolo_to_onnx, DEFAULT_MIDAS_MODEL, DEFAULT_MIDAS_REPO, + ensure_depth_pro_model_available, ) except ImportError as e: logger.error("Failed to import backend modules: %s", e) @@ -120,8 +121,9 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--models", type=str, + default="yolo,midas", - help="Comma-separated list of models to process (yolo, midas, depth-anything)", + help="Comma-separated list of models to process (yolo, midas, depth-anything, depth-pro)", ) return parser.parse_args() @@ -213,10 +215,19 @@ def main() -> None: ensure_depth_anything_model_available( model_name=da_model, + cache_dir=da_cache ) + # --- Depth Pro Processing --- + if "depth-pro" in models_to_process: + logger.info("\n--- Processing Depth Pro ---") + + dp_cache = config.DEPTH_PRO_CACHE_DIR + ensure_depth_pro_model_available(cache_dir=dp_cache) + + logger.info("\n--- Done ---") logger.info("Models available at: %s", output_dir) if midas_cache_final: diff --git a/src/backend/Dockerfile.analyzer b/src/backend/Dockerfile.analyzer index b3228fa..47ce471 100644 --- a/src/backend/Dockerfile.analyzer +++ b/src/backend/Dockerfile.analyzer @@ -9,10 +9,11 @@ WORKDIR /app COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv -# Install system dependencies for opencv +# Install system dependencies for opencv and git for depth_pro RUN apt-get update && apt-get install -y \ libgl1 \ libglib2.0-0 \ + git \ && rm -rf /var/lib/apt/lists/* COPY pyproject.toml uv.lock ./ diff --git a/src/backend/common/config.py b/src/backend/common/config.py index 604af95..98e3cee 100644 --- a/src/backend/common/config.py +++ 
b/src/backend/common/config.py @@ -57,6 +57,10 @@ class Config: DEPTH_ANYTHING_CACHE_DIR: Path = Path( os.getenv("DEPTH_ANYTHING_CACHE_DIR", "models/depth_anything_cache") ).resolve() + DEPTH_PRO_MODEL: str = os.getenv("DEPTH_PRO_MODEL", "depth_pro") + DEPTH_PRO_CACHE_DIR: Path = Path( + os.getenv("DEPTH_PRO_CACHE_DIR", "models/depth_pro_cache") + ).resolve() MIDAS_ONNX_MODEL_PATH: Path = Path( os.getenv("MIDAS_ONNX_MODEL_PATH", "models/midas_small.onnx") ).resolve() diff --git a/src/backend/common/core/depth.py b/src/backend/common/core/depth.py index 3416229..317e742 100644 --- a/src/backend/common/core/depth.py +++ b/src/backend/common/core/depth.py @@ -259,9 +259,17 @@ def _predict_depth_map( register_depth_backend("torch", MiDasDepthEstimator) register_depth_backend("onnx", OnnxMiDasDepthEstimator) + try: from common.core.depth_anything import DepthAnythingV2Estimator register_depth_backend("depth_anything_v2", DepthAnythingV2Estimator) except ImportError: pass + +try: + from common.core.depth_pro import DepthProEstimator + + register_depth_backend("depth_pro", DepthProEstimator) +except ImportError as e: + logger.debug(f"Could not register 'depth_pro' backend: {e}") diff --git a/src/backend/common/core/depth_pro.py b/src/backend/common/core/depth_pro.py new file mode 100644 index 0000000..66d12fc --- /dev/null +++ b/src/backend/common/core/depth_pro.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: 2025 robot-visual-perception +# +# SPDX-License-Identifier: MIT +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +from PIL import Image + +from common.config import config +from common.core.contracts import DepthEstimator, Detection +from common.core.depth_utils import resize_to_frame + +import depth_pro # type: ignore + +logger = logging.getLogger(__name__) + + +class DepthProEstimator(DepthEstimator): + """Depth estimator backed by Apple's ML Depth Pro.""" + + def 
__init__( + self, + cache_directory: Optional[Path] = None, + model_name: str = config.DEPTH_PRO_MODEL, + ) -> None: + # depth_pro import is now strict at module level. + + self.region_size = config.REGION_SIZE + self.scale_factor = config.SCALE_FACTOR + self.update_freq = config.UPDATE_FREQ + + self.update_id = -1 + self.last_depths: list[float] = [] + + self.cache_directory = cache_directory or config.DEPTH_PRO_CACHE_DIR + self.model_name = model_name + + logger.info("Loading Depth Pro model...") + + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + + try: + self.model, self.transform = depth_pro.create_model_and_transforms( + device=self.device, aspect_ratio=1.0 + ) + self.model.eval() + except Exception as e: + logger.error(f"Failed to initialize Depth Pro: {e}") + raise + + def estimate_distance_m( + self, frame_rgb: np.ndarray, dets: list[Detection] + ) -> list[float]: + """Estimate distance in meters for each detection based on depth map.""" + self.update_id += 1 + if self.update_id % self.update_freq != 0 and len(self.last_depths) == len( + dets + ): + return self.last_depths + + h, w, _ = frame_rgb.shape + depth_map = self._predict_depth_map(frame_rgb, (h, w)) + + distances = self._distances_from_depth_map(depth_map, dets) + self.last_depths = distances + return distances + + def _predict_depth_map( + self, frame_rgb: np.ndarray, output_shape: tuple[int, int] + ) -> np.ndarray: + image_pil = Image.fromarray(frame_rgb) + image_tensor = self.transform(image_pil) + + with torch.no_grad(): + prediction = self.model.infer(image_tensor, f_px=None) + + depth = prediction["depth"] + + if isinstance(depth, torch.Tensor): + depth = depth.cpu().numpy() + + return resize_to_frame(depth, output_shape) + + def _distances_from_depth_map( + self, + depth_map: np.ndarray, + dets: list[Detection], + ) -> list[float]: + dists = [] + h, w = depth_map.shape[:2] + + for det in dets: + x1 = max(0, int(det.x1)) + y1 = max(0, 
int(det.y1)) + x2 = min(w, int(det.x2)) + y2 = min(h, int(det.y2)) + + if x2 <= x1 or y2 <= y1: + dists.append(0.0) + continue + + region = depth_map[y1:y2, x1:x2] + + # Use median depth in the box as the object distance + dist_m = float(np.median(region)) + dists.append(dist_m) + + return dists diff --git a/src/backend/common/core/model_downloader.py b/src/backend/common/core/model_downloader.py index d4509ed..7d45cfa 100644 --- a/src/backend/common/core/model_downloader.py +++ b/src/backend/common/core/model_downloader.py @@ -20,6 +20,11 @@ AutoImageProcessor = None  # type: ignore AutoModelForDepthEstimation = None  # type: ignore +try: + import depth_pro  # type: ignore +except ImportError: + depth_pro = None + logger = logging.getLogger(__name__) # Constants @@ -290,3 +295,94 @@ error_msg = f"Failed to load Depth Anything model {model_name}: {e}" logger.error(error_msg) raise RuntimeError(error_msg) from e + + +def ensure_depth_pro_model_available( + cache_dir: Optional[Path] = None, +) -> Path: + """Ensure the Depth Pro checkpoint is downloaded and cached. + + This function also initializes the model once to trigger any internal + downloads or verifications provided by the 'depth_pro' library. + + Args: + cache_dir: Directory in which to store the downloaded checkpoint + + Returns: + Path to the cache directory that holds the downloaded checkpoint + """ + if cache_dir is None: + cache_dir = ( + Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + )  # torch.hub's default checkpoint directory + + # Ensure import + if depth_pro is None: + raise ImportError( + "depth_pro is not available. Please install it via `uv sync --extra inference`" + ) + + try: + # Resolve the expected checkpoint location. + # depth_pro loads weights via `config.checkpoint_uri`, which defaults to + # 'checkpoints/depth_pro.pt' relative to the current working directory, + # so the downloaded file is placed (or linked) there below.
+ + # NOTE: checkpoint resolution in depth_pro: + # the library loads its weights through `config.checkpoint_uri` (see its + # `load(config.checkpoint_uri)` call), and that URI defaults to + # `checkpoints/depth_pro.pt` relative to the current working directory. + # Rather than patching the library's internal config, we download the + # weights into our own cache directory and then mirror them at the + # default location so `create_model_and_transforms` can find them. + + # Download the checkpoint into the cache directory first. + + cache_dir = Path(str(cache_dir)).resolve() + cache_dir.mkdir(parents=True, exist_ok=True) + + checkpoint_name = "depth_pro.pt" + checkpoint_path = cache_dir / checkpoint_name + + url = "https://ml-site.cdn-apple.com/models/depth-pro/depth_pro.pt" + + if not checkpoint_path.exists(): + logger.info("Downloading Depth Pro weights to %s...", checkpoint_path) + torch.hub.download_url_to_file(url, str(checkpoint_path), progress=True) + else: + logger.info("Depth Pro weights found at %s", checkpoint_path) + + # Expose the cached file to depth_pro. + # Overriding `checkpoint_uri` at runtime would be fragile: the module is + # already imported and its config object is an internal detail. + # Instead, materialise the file at the path the library reads by default: + # + # `./checkpoints/depth_pro.pt`, relative to the current working directory. + # + # A symlink avoids duplicating the large checkpoint on disk; where + # symlinks are not permitted we fall back to a copy. In both cases the + # cache directory remains the canonical store, so repeated runs reuse + # the already-downloaded file. + + cwd_checkpoints = Path.cwd() / "checkpoints" + cwd_checkpoints.mkdir(exist_ok=True) + cwd_target = cwd_checkpoints / "depth_pro.pt" + + if not cwd_target.exists(): + # Symlink or copy + try: + cwd_target.symlink_to(checkpoint_path) + logger.info("Symlinked checkpoint to %s", cwd_target) + except OSError: + # Fallback to copy if symlink fails (e.g.
windows without privs) + shutil.copy2(checkpoint_path, cwd_target) + logger.info("Copied checkpoint to %s", cwd_target) + + # Now instantiate + depth_pro.create_model_and_transforms() + + logger.info("Depth Pro model is ready.") + return cache_dir + except Exception as e: + logger.error(f"Failed to load Depth Pro model: {e}") + raise RuntimeError(f"Depth Pro initialization failed: {e}") from e diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index 59d53d4..4c152b9 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -38,6 +38,7 @@ inference = [ "ultralytics==8.3.58", "timm==1.0.22", "transformers==4.49.0", # for Depth Anything V2 + "depth_pro @ git+https://github.com/apple/ml-depth-pro.git", ] onnx-tools = [ diff --git a/src/backend/uv.lock b/src/backend/uv.lock index 929401c..3091a66 100644 --- a/src/backend/uv.lock +++ b/src/backend/uv.lock @@ -301,6 +301,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, ] +[[package]] +name = "depth-pro" +version = "0.1" +source = { git = "https://github.com/apple/ml-depth-pro.git#9efe5c1def37a26c5367a71df664b18e1306c708" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pillow-heif" }, + { name = "timm" }, + { name = "torch" }, + { name = "torchvision" }, +] + [[package]] name = "dnspython" version = "2.8.0" @@ -1111,6 +1124,7 @@ dev = [ { name = "ruff" }, ] inference = [ + { name = "depth-pro" }, { name = "httpx" }, { name = "timm" }, { name = "transformers" }, @@ -1136,6 +1150,7 @@ requires-dist = [ { name = "aioice", specifier = "==0.10.1" }, { name = "aiortc", specifier = "==1.14.0" }, { name = "av", specifier = "==16.0.1" }, + { name = "depth-pro", marker = "extra == 'inference'", git = 
"https://github.com/apple/ml-depth-pro.git" }, { name = "fastapi", specifier = "==0.115.10" }, { name = "httpx", marker = "extra == 'inference'", specifier = "==0.27.2" }, { name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" }, @@ -1220,6 +1235,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/7e/f896623c3c635a90537ac093c6a618ebe1a90d87206e42309cb5d98a1b9e/pillow-12.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5", size = 6997850, upload-time = "2025-10-15T18:24:11.495Z" }, ] +[[package]] +name = "pillow-heif" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pillow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/65/77284daf2a8a2849b9040889bd8e1b845e693ed97973a28ba2122b8922ad/pillow_heif-1.1.1.tar.gz", hash = "sha256:f60e8c8a8928556104cec4fff39d43caa1da105625bdb53b11ce3c89d09b6bde", size = 18271952, upload-time = "2025-09-30T16:42:24.485Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/13/dd7908c39ea368abd2d25e4fa3eef97a29cc5446c9ba0d47b1fe13564f79/pillow_heif-1.1.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:57c140368c7ddefa20ecb9b737b4af2d2d5ea0806d1d59be4c525e6a73e6aa72", size = 4696596, upload-time = "2025-09-30T16:41:14.083Z" }, + { url = "https://files.pythonhosted.org/packages/93/a2/dbcbfd4264d19ce5b1776327f25633cc20b2b2840fc602d85341dcfce782/pillow_heif-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0b6ae0b002ecc7873273ded99aaffa567f7806f4bc57ee1eff7ab5fe1f70e5e7", size = 3451082, upload-time = "2025-09-30T16:41:15.597Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7b/2488882acf9756c8d22108e1828232cdd216a3d333a1824cd41eee102632/pillow_heif-1.1.1-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:39ea2fe8878e722bdfaf30a9b711629c3a4b8a0627b70a833f7381cbd3ef8e87", size = 5774792, upload-time = 
"2025-09-30T16:41:16.771Z" }, + { url = "https://files.pythonhosted.org/packages/97/aa/c048f7e337ef40a86a2501d264a0f430ab8772c35306c907b1e00ddf5099/pillow_heif-1.1.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a01644c3c4bc576437c05e1ece4b89814fc381684f5d7926850e01d6e9b6502", size = 5505379, upload-time = "2025-09-30T16:41:18.252Z" }, + { url = "https://files.pythonhosted.org/packages/a0/6a/28ca3dbfdd1bf2e0aaebcb38b2c375ab76ee647588a8a91200f0a0c3cb5b/pillow_heif-1.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5ebe3b7f707b984c8886f367697531d004967b7d8949a34645c7bc1c6a888fe6", size = 6810454, upload-time = "2025-09-30T16:41:19.618Z" }, + { url = "https://files.pythonhosted.org/packages/64/87/a1909c6c8514b9cc451633e92d9c1088268ea0e913deabd7bb5740a5abe7/pillow_heif-1.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c8c8e1a561877006a5a0b654392e614c879d9e4db89d0786a94fe9f5773bcacb", size = 6432367, upload-time = "2025-09-30T16:41:21.396Z" }, + { url = "https://files.pythonhosted.org/packages/99/48/fa2407203087be5424d514c40a816eee517450e98772c1f00fe846b87a8b/pillow_heif-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:3d296f874bea4dd17bab7309b843a766834d2b5df53c591eaf3f7cdc91a4c1a3", size = 5422266, upload-time = "2025-09-30T16:41:23.107Z" }, + { url = "https://files.pythonhosted.org/packages/17/7e/e7182fd74e911993ac3d4522ce43af439888baff14d8bc75fe9ee5a95580/pillow_heif-1.1.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0ea9c72f5cbe1b35229be883797eb7f113d2e7353dc21a66fd813a33d95a16b3", size = 4685305, upload-time = "2025-09-30T16:42:15.61Z" }, + { url = "https://files.pythonhosted.org/packages/aa/92/181b49961411b89c857cbb984030aa6ab0886a059be574b0f7f402e098cf/pillow_heif-1.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:16f83a7e0ad3aa4209ae592db2842d35faab21b44d269fb3b1145e07ecbecebc", size = 3447553, upload-time = "2025-09-30T16:42:16.96Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/c5/2ce061f60d52a1603c0e8634409a480c8ff799379ab0822b8e9c1d9a78bd/pillow_heif-1.1.1-pp311-pypy311_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7fc8273124fe96d83fd6dee9476a5b58b6338cb41ffe97581fc2e8f17c97864c", size = 5733421, upload-time = "2025-09-30T16:42:18.737Z" }, + { url = "https://files.pythonhosted.org/packages/49/9a/9e6cfc339b2de5cb19e7762f89b59e6ef15fb41219f7e382bb2d507245bb/pillow_heif-1.1.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ca64d2e83b28ae7f194640e1c6d5d842de8f061845a4fd700a4ab7efb9df15f9", size = 5461117, upload-time = "2025-09-30T16:42:20.48Z" }, + { url = "https://files.pythonhosted.org/packages/20/c5/1912f3b9220a91ef449a710bce1a3128a633b44d86a17ef58fb376403bfd/pillow_heif-1.1.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:7520b37f183f5339c9a0dbdd4cae468cc7d7f191fff26fd18d8d96cf69089994", size = 5422656, upload-time = "2025-09-30T16:42:22.39Z" }, +] + [[package]] name = "pluggy" version = "1.6.0"