diff --git a/label_studio_ml/examples/yolo_sam2_tracker/.dockerignore b/label_studio_ml/examples/yolo_sam2_tracker/.dockerignore new file mode 100644 index 000000000..894795019 --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/.dockerignore @@ -0,0 +1,18 @@ +# Exclude everything +_wsgi.py + +# Include Dockerfile and docker-compose for reference (optional, decide based on your use case) +!Dockerfile +!docker-compose.yml + +# Include Python application files +!*.py + +# Include requirements files +!requirements*.txt + +# Include script +!*.sh + +# Exclude specific requirements if necessary +# requirements-test.txt (Uncomment if you decide to exclude this) diff --git a/label_studio_ml/examples/yolo_sam2_tracker/Dockerfile b/label_studio_ml/examples/yolo_sam2_tracker/Dockerfile new file mode 100644 index 000000000..5b844bf68 --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/Dockerfile @@ -0,0 +1,82 @@ +FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime + +ARG DEBIAN_FRONTEND=noninteractive +ARG TEST_ENV + +WORKDIR /app + +# Update Conda +RUN conda update conda -y + +# Install system dependencies +RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \ + --mount=type=cache,target="/var/lib/apt/lists",sharing=locked \ + apt-get -y update \ + && apt-get install -y git wget g++ freeglut3-dev build-essential \ + libx11-dev libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev \ + libfreeimage-dev ffmpeg libsm6 libxext6 libffi-dev python3-dev \ + python3-pip gcc + +# Environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_CACHE_DIR=/.cache \ + PORT=9090 \ + WORKERS=2 \ + THREADS=4 \ + CUDA_HOME=/usr/local/cuda \ + TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0" \ + SEGMENT_ANYTHING_2_REPO_PATH=/segment-anything-2 \ + PYTHONPATH=/app + +# Install Python dependencies +COPY requirements-base.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements-base.txt + +COPY requirements.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements.txt + +# Install segment-anything-2 +RUN cd / && git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/sam2.git +WORKDIR /sam2 +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -e . +RUN cd checkpoints && ./download_ckpts.sh + +# Clone and install label-studio-ml-backend +WORKDIR /app +RUN git clone https://github.com/HumanSignal/label-studio-ml-backend.git \ + && cd label-studio-ml-backend/ \ + && pip install -e . + +# Return to app working directory +WORKDIR /app + +# Install test dependencies (optional) +COPY requirements-test.txt . +RUN if [ "$TEST_ENV" = "true" ]; then \ + pip install -r requirements-test.txt; \ + else \ + echo "Skipping test dependencies installation"; \ + fi + +# Download YOLO models (only if the image exists) +RUN /bin/sh -c 'if [ -f /app/tests/car.jpg ]; then \ + yolo predict model=/app/models/yolov8m.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n-cls.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n-seg.pt source=/app/tests/car.jpg; \ + else \ + echo "Image not found, skipping YOLO model tests"; \ + fi' + +# Copy app files +COPY . 
./
+
+# Ensure the script has executable permissions
+RUN chmod +x /app/start.sh
+
+# Default command
+CMD ["/app/start.sh"]
diff --git a/label_studio_ml/examples/yolo_sam2_tracker/README.md b/label_studio_ml/examples/yolo_sam2_tracker/README.md
new file mode 100644
index 000000000..0bcfbca74
--- /dev/null
+++ b/label_studio_ml/examples/yolo_sam2_tracker/README.md
@@ -0,0 +1,58 @@
+This guide describes the simplest way to start using this ML backend (a YOLO + SAM2 multi-object video tracker) with Label Studio.
+
+## Running with Docker (Recommended)
+
+1. Start the Machine Learning backend on `http://localhost:9090` with the prebuilt image:
+
+```bash
+docker-compose up
+```
+
+2. Validate that the backend is running:
+
+```bash
+$ curl http://localhost:9090/
+{"status":"UP"}
+```
+
+3. Connect to the backend from Label Studio running on the same host: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as the URL.
+
+
+## Building from source (Advanced)
+
+To build the ML backend from source, clone the repository and build the Docker image:
+
+```bash
+docker-compose build
+```
+
+## Running without Docker (Advanced)
+
+To run the ML backend without Docker, clone the repository and install all dependencies using pip:
+
+```bash
+python -m venv ml-backend
+source ml-backend/bin/activate
+pip install -r requirements.txt
+```
+
+Then you can start the ML backend:
+
+```bash
+label-studio-ml start ./dir_with_your_model
+```
+
+# Configuration
+Parameters can be set in `docker-compose.yml` before running the container.
+
+
+The following common parameters are available:
+- `BASIC_AUTH_USER` - specify the basic auth user for the model server
+- `BASIC_AUTH_PASS` - specify the basic auth password for the model server
+- `LOG_LEVEL` - set the log level for the model server
+- `WORKERS` - specify the number of workers for the model server
+- `THREADS` - specify the number of threads for the model server
+
+# Customization
+
+The ML backend can be customized by adding your own models and logic inside the `./dir_with_your_model` directory.
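+
+# Labeling config
+
+The tracker only attaches to control tags it can match. As a sketch inferred from the attributes read in `control_models/video_rectangle_with_yolo_sam2_tracker.py` (tag names and label values below are placeholders to adapt to your project), a matching labeling configuration could look like this:
+
+```xml
+<View>
+  <!-- Video object tag: the control model requires a Video object -->
+  <Video name="video" value="$video"/>
+  <!-- model_sam_tracker="true" is what routes this control to the YOLO + SAM2 tracker;
+       the numeric attributes mirror the defaults read in get_model_configs() -->
+  <VideoRectangle name="box" toName="video" model_sam_tracker="true"
+                  model_conf="0.25" model_iou="0.70"
+                  frames_to_track="100" fps="24"/>
+  <!-- A connected Labels tag with the same toName supplies the label map -->
+  <Labels name="videoLabels" toName="video">
+    <Label value="person"/>
+  </Labels>
+</View>
+```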
\ No newline at end of file diff --git a/label_studio_ml/examples/yolo_sam2_tracker/__init__.py b/label_studio_ml/examples/yolo_sam2_tracker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/label_studio_ml/examples/yolo_sam2_tracker/_wsgi.py b/label_studio_ml/examples/yolo_sam2_tracker/_wsgi.py new file mode 100644 index 000000000..20115a61f --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/_wsgi.py @@ -0,0 +1,124 @@ +import argparse +import json +import logging.config +import os + +# Set a default log level if LOG_LEVEL is not defined +log_level = os.getenv("LOG_LEVEL", "INFO") + +logging.config.dictConfig({ + "version": 1, + "disable_existing_loggers": False, # Prevent overriding existing loggers + "formatters": { + "standard": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s" + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": log_level, + "stream": "ext://sys.stdout", + "formatter": "standard" + } + }, + "root": { + "level": log_level, + "handlers": [ + "console" + ], + "propagate": True + } +}) + +from label_studio_ml.api import init_app +from model import YoloSamMultiObjectTracking + + +_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json') + + +def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH): + if not os.path.exists(config_path): + return dict() + with open(config_path) as f: + config = json.load(f) + assert isinstance(config, dict) + return config + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Label studio') + parser.add_argument( + '-p', '--port', dest='port', type=int, default=9090, + help='Server port') + parser.add_argument( + '--host', dest='host', type=str, default='0.0.0.0', + help='Server host') + parser.add_argument( + '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='), + help='Additional LabelStudioMLBase model initialization kwargs') + parser.add_argument( + '-d', '--debug', dest='debug', action='store_true', + help='Switch debug mode') + parser.add_argument( + '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=log_level, + help='Logging level') + parser.add_argument( + '--model-dir', dest='model_dir', default=os.path.dirname(__file__), + help='Directory where models are stored (relative to the project directory)') + parser.add_argument( + '--check', dest='check', action='store_true', + help='Validate model instance before launching server') + parser.add_argument('--basic-auth-user', + default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None), + help='Basic auth user') + + parser.add_argument('--basic-auth-pass', + default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None), + help='Basic auth pass') + + args = parser.parse_args() + + # setup logging level + if args.log_level: + logging.root.setLevel(args.log_level) + + def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + + def parse_kwargs(): + param = dict() + for k, v in args.kwargs: + if v.isdigit(): + param[k] = int(v) + elif v == 'True' or v == 'true': + param[k] = True + elif v == 'False' or v == 'false': + param[k] = False + elif isfloat(v): + param[k] = float(v) + else: + param[k] = v + return param + + kwargs = get_kwargs_from_config() + + if args.kwargs: + kwargs.update(parse_kwargs()) + + if args.check: + print('Check "' + YoloSamMultiObjectTracking.__name__ + '" instance creation..') + model = 
YoloSamMultiObjectTracking(**kwargs) + + app = init_app(model_class=YoloSamMultiObjectTracking, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass) + + app.run(host=args.host, port=args.port, debug=args.debug) + +else: + # for uWSGI use + app = init_app(model_class=YoloSamMultiObjectTracking) diff --git a/label_studio_ml/examples/yolo_sam2_tracker/control_models/__init__.py b/label_studio_ml/examples/yolo_sam2_tracker/control_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/label_studio_ml/examples/yolo_sam2_tracker/control_models/base.py b/label_studio_ml/examples/yolo_sam2_tracker/control_models/base.py new file mode 100644 index 000000000..fbc3e1cc4 --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/control_models/base.py @@ -0,0 +1,201 @@ +import os +import logging + +from pydantic import BaseModel +from typing import Optional, List, Dict, ClassVar +from ultralytics import YOLO + +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.utils import DATA_UNDEFINED_NAME +from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path +from label_studio_sdk.label_interface.control_tags import ControlTag +from label_studio_sdk.label_interface import LabelInterface + + +# use matplotlib plots for debug +DEBUG_PLOT = os.getenv("DEBUG_PLOT", "false").lower() in ["1", "true"] +MODEL_SCORE_THRESHOLD = float(os.getenv("MODEL_SCORE_THRESHOLD", 0.5)) +DEFAULT_MODEL_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models") +MODEL_ROOT = os.getenv("MODEL_ROOT", DEFAULT_MODEL_ROOT) +os.makedirs(MODEL_ROOT, exist_ok=True) +# if true, allow to use custom model path from the control tag in the labeling config +ALLOW_CUSTOM_MODEL_PATH = os.getenv("ALLOW_CUSTOM_MODEL_PATH", "true").lower() in [ + "1", + "true", +] + +# Global cache for YOLO models +_model_cache = {} +logger = logging.getLogger(__name__) + + +def get_bool(attr, attr_name, default="false"): + return attr.get(attr_name, default).lower() in ["1", "true", "yes"] + + +class ControlModel(BaseModel): + """ + Represents a control tag in Label Studio, which is associated with a specific type of labeling task + and is used to generate predictions using a YOLO model. + + Attributes: + type (str): Type of the control, e.g., RectangleLabels, Choices, etc. + control (ControlTag): The actual control element from the Label Studio configuration. + from_name (str): The name of the control tag, used to link the control to the data. + to_name (str): The name of the data field that this control is associated with. + value (str): The value name from the object that this control operates on, e.g., an image or text field. + model (object): The model instance (e.g., YOLO) used to generate predictions for this control. + model_path (str): Path to the YOLO model file. + model_score_threshold (float): Threshold for prediction scores; predictions below this value will be ignored. + label_map (Optional[Dict[str, str]]): A mapping of model labels to Label Studio labels. + """ + + type: ClassVar[str] + control: ControlTag + from_name: str + to_name: str + value: str + model: YOLO + model_path: ClassVar[str] + model_score_threshold: float = 0.5 + label_map: Optional[Dict[str, str]] = {} + label_studio_ml_backend: LabelStudioMLBase + project_id: Optional[str] = None + + def __init__(self, **data): + super().__init__(**data) + + @classmethod + def is_control_matched(cls, control) -> bool: + """Check if the control tag matches the model type. 
+ Args: + control (ControlTag): The control tag from the Label Studio Interface. + """ + raise NotImplementedError("This method should be overridden in derived classes") + + @staticmethod + def get_from_name_for_label_map( + label_interface: LabelInterface, target_name: str + ) -> str: + """Get the 'from_name' attribute for the label map building.""" + return target_name + + @classmethod + def create(cls, mlbackend: LabelStudioMLBase, control: ControlTag): + """Factory method to create an instance of a specific control model class. + Args: + mlbackend (LabelStudioMLBase): The ML backend instance. + control (ControlTag): The control tag from the Label Studio Interface. + """ + from_name = control.name + to_name = control.to_name[0] + value = control.objects[0].value_name + + # if skip is true, don't process this control + if get_bool(control.attr, "model_skip", "false"): + logger.info( + f"Skipping control tag '{control.tag}' with name '{from_name}', model_skip=true found" + ) + return None + # read threshold attribute from the control tag, e.g.: + model_score_threshold = float( + control.attr.get("model_score_threshold") + or control.attr.get( + "score_threshold" + ) # not recommended option, use `model_score_threshold` + or MODEL_SCORE_THRESHOLD + ) + # read `model_path` attribute from the control tag + model_path = ( + ALLOW_CUSTOM_MODEL_PATH and control.attr.get("model_path") + ) or cls.model_path + + model = cls.get_cached_model(model_path) + model_names = model.names.values() # class names from the model + # from_name for label mapping can be differed from control.name (e.g. VideoRectangle) + label_map_from_name = cls.get_from_name_for_label_map( + mlbackend.label_interface, from_name + ) + label_map = mlbackend.build_label_map(label_map_from_name, model_names) + + return cls( + control=control, + from_name=from_name, + to_name=to_name, + value=value, + model=model, + model_score_threshold=model_score_threshold, + label_map=label_map, + label_studio_ml_backend=mlbackend, + project_id=mlbackend.project_id, + ) + + @classmethod + def load_yolo_model(cls, filename) -> YOLO: + """Load YOLO model from the file.""" + path = os.path.join(MODEL_ROOT, filename) + logger.info(f"Loading yolo model: {path}") + model = YOLO(path) + logger.info(f"Model {path} names:\n{model.names}") + return model + + @classmethod + def get_cached_model(cls, path: str) -> YOLO: + if path not in _model_cache: + _model_cache[path] = cls.load_yolo_model(path) + return _model_cache[path] + + def debug_plot(self, image): + if not DEBUG_PLOT: + return + + import matplotlib.pyplot as plt + + plt.figure(figsize=(10, 10)) + plt.imshow(image[..., ::-1]) + plt.axis("off") + plt.title(self.type) + plt.show() + + def predict_regions(self, path) -> List[Dict]: + """Predict regions in the image using the YOLO model. 
+ Args: + path (str): Path to the file with media + """ + raise NotImplementedError("This method should be overridden in derived classes") + + def fit(self, event, data, **kwargs): + """Fit the model.""" + logger.warning("The fit method is not implemented for this control model") + return False + + def get_path(self, task): + task_path = task["data"].get(self.value) or task["data"].get( + DATA_UNDEFINED_NAME + ) + if task_path is None: + raise ValueError( + f"Can't load path using key '{self.value}' from task {task}" + ) + if not isinstance(task_path, str): + raise ValueError(f"Path should be a string, but got {task_path}") + + # try path as local file or try to load it from Label Studio instance/download via http + path = ( + task_path + if os.path.exists(task_path) + else get_local_path(task_path, task_id=task.get("id")) + ) + logger.debug(f"load_image: {task_path} => {path}") + return path + + def __str__(self): + """Return a string with full representation of the control tag.""" + return ( + f"{self.type} from_name={self.from_name}, " + f"label_map={self.label_map}, model_score_threshold={self.model_score_threshold}" + ) + + class Config: + arbitrary_types_allowed = True + protected_namespaces = ("__.*__", "_.*") # Excludes 'model_' diff --git a/label_studio_ml/examples/yolo_sam2_tracker/control_models/video_rectangle.py b/label_studio_ml/examples/yolo_sam2_tracker/control_models/video_rectangle.py new file mode 100644 index 000000000..00290198b --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/control_models/video_rectangle.py @@ -0,0 +1,227 @@ +import os +import cv2 +import logging +import yaml +import hashlib + +from collections import defaultdict +from control_models.base import ControlModel, MODEL_ROOT +from label_studio_sdk.label_interface.control_tags import ControlTag +from typing import List, Dict, Union + + +logger = logging.getLogger(__name__) + + +class VideoRectangleModel(ControlModel): + """ + Class representing a RectangleLabels (bounding boxes) control tag for YOLO model. + """ + + type = "VideoRectangle" + model_path = "yolov8n.pt" + + @classmethod + def is_control_matched(cls, control: ControlTag) -> bool: + # check object tag type + if control.objects[0].tag != "Video": + return False + # check control type VideoRectangle + return control.tag == cls.type + + @staticmethod + def get_from_name_for_label_map(label_interface, target_name) -> str: + """VideoRectangle doesn't have labels inside, and we should find a connected Labels tag + and return its name as a source for the label map. 
+ """ + target: ControlTag = label_interface.get_control(target_name) + if not target: + raise ValueError(f'Control tag with name "{target_name}" not found') + + for connected in label_interface.controls: + if connected.tag == "Labels" and connected.to_name == target.to_name: + return connected.name + + logger.error("VideoRectangle detected, but no connected 'Labels' tag found") + + @staticmethod + def get_video_duration(path): + if not os.path.exists(path): + raise ValueError(f"Video file not found: {path}") + video = cv2.VideoCapture(path) + fps = video.get(cv2.CAP_PROP_FPS) + frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = frame_count / fps + logger.info( + f"Video duration: {duration} seconds, {frame_count} frames, {fps} fps" + ) + return frame_count, duration + + def predict_regions(self, path) -> List[Dict]: + # bounding box parameters + # https://docs.ultralytics.com/modes/track/?h=track#tracking-arguments + conf = float(self.control.attr.get("model_conf", 0.25)) + iou = float(self.control.attr.get("model_iou", 0.70)) + + # tracking parameters + # https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers + tracker_name = self.control.attr.get( + "model_tracker", "botsort" + ) # or 'bytetrack' + original = f"{MODEL_ROOT}/{tracker_name}.yaml" + tmp_yaml = self.update_tracker_params(original, prefix=tracker_name + "_") + tracker = tmp_yaml if tmp_yaml else original + + # run model track + try: + results = self.model.track( + path, conf=conf, iou=iou, tracker=tracker, stream=True + ) + finally: + # clean temporary file + if tmp_yaml and os.path.exists(tmp_yaml): + os.remove(tmp_yaml) + + # convert model results to label studio regions + return self.create_video_rectangles(results, path) + + def create_video_rectangles(self, results, path): + """Create regions of video rectangles from the yolo tracker results""" + frames_count, duration = self.get_video_duration(path) + model_names = self.model.names + logger.debug( + f"create_video_rectangles: {self.from_name}, {frames_count} frames" + ) + + tracks = defaultdict(list) + track_labels = dict() + frame = -1 + for result in results: + frame += 1 + data = result.boxes + if not data.is_track: + continue + + for i, track_id in enumerate(data.id.tolist()): + score = float(data.conf[i]) + x, y, w, h = data.xywhn[i].tolist() + # get label + model_label = model_names[int(data.cls[i])] + if model_label not in self.label_map: + continue + output_label = self.label_map[model_label] + track_labels[track_id] = output_label + + box = { + "frame": frame + 1, + "enabled": True, + "rotation": 0, + "x": (x - w / 2) * 100, + "y": (y - h / 2) * 100, + "width": w * 100, + "height": h * 100, + "time": (frame + 1) * (duration / frames_count), + "score": score, + } + tracks[track_id].append(box) + + regions = [] + for track_id in tracks: + sequence = tracks[track_id] + sequence = self.process_lifespans_enabled(sequence) + + label = track_labels[track_id] + region = { + "from_name": self.from_name, + "to_name": self.to_name, + "type": "videorectangle", + "value": { + "framesCount": frames_count, + "duration": duration, + "sequence": sequence, + "labels": [label], + }, + "score": max([frame_info["score"] for frame_info in sequence]), + "origin": "manual", + } + regions.append(region) + + return regions + + @staticmethod + def process_lifespans_enabled(sequence: List[Dict]) -> List[Dict]: + """This function detects gaps in the sequence of bboxes + and disables lifespan line for the gaps assigning "enabled": False + to the last 
bboxes in the whole span sequence. + """ + prev = None + for i, box in enumerate(sequence): + if prev is None: + prev = sequence[i] + continue + if box["frame"] - prev["frame"] > 1: + sequence[i - 1]["enabled"] = False + prev = sequence[i] + + # the last frame enabled is false to turn off lifespan line + sequence[-1]["enabled"] = False + return sequence + + @staticmethod + def generate_hash_filename(extension=".yaml"): + """Store yaml configs as temporary files just for one model.track() run""" + hash_name = hashlib.sha256(os.urandom(16)).hexdigest() + os.makedirs(f"{MODEL_ROOT}/tmp/", exist_ok=True) + return f"{MODEL_ROOT}/tmp/{hash_name}{extension}" + + def update_tracker_params(self, yaml_path: str, prefix: str) -> Union[str, None]: + """Update tracker parameters in the yaml file with the attributes from the ControlTag, + e.g. + or + Args: + yaml_path: Path to the original yaml file. + prefix: Prefix for attributes of control tag to extract + Returns: + The file path for new yaml with updated parameters + """ + # check if there are any custom parameters in the labeling config + for attr_name, attr_value in self.control.attr.items(): + if attr_name.startswith(prefix): + break + else: + # no custom parameters, exit + return None + + # Load the original yaml file + with open(yaml_path, "r") as file: + config = yaml.safe_load(file) + + # Extract parameters with prefix from ControlTag + for attr_name, attr_value in self.control.attr.items(): + if attr_name.startswith(prefix): + # Remove prefix and update the corresponding yaml key + key = attr_name[len(prefix) :] + + # Convert value to the appropriate type (bool, int, float, etc.) + if isinstance(config[key], bool): + attr_value = attr_value.lower() == "true" + elif isinstance(config[key], int): + attr_value = int(attr_value) + elif isinstance(config[key], float): + attr_value = float(attr_value) + + config[key] = attr_value + + # Generate a new filename with a random hash + new_yaml_filename = self.generate_hash_filename() + + # Save the updated config to a new yaml file + with open(new_yaml_filename, "w") as file: + yaml.dump(config, file) + + # Return the new filename + return new_yaml_filename + + +# pre-load and cache default model at startup +VideoRectangleModel.get_cached_model(VideoRectangleModel.model_path) diff --git a/label_studio_ml/examples/yolo_sam2_tracker/control_models/video_rectangle_with_yolo_sam2_tracker.py b/label_studio_ml/examples/yolo_sam2_tracker/control_models/video_rectangle_with_yolo_sam2_tracker.py new file mode 100644 index 000000000..df7209500 --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/control_models/video_rectangle_with_yolo_sam2_tracker.py @@ -0,0 +1,1044 @@ +import base64 +import glob +import logging +import os +import pathlib +import sys +import tempfile +from collections import defaultdict +# YOLO + SAM2 related imports +from dataclasses import dataclass, field +from typing import List, Dict, Set +from typing import Literal, cast +from typing import Optional + +import av +import cv2 +import numpy as np +import torch +from control_models.video_rectangle import VideoRectangleModel +from pycocotools import mask as coco_mask +from ultralytics import YOLO + +from label_studio_sdk.label_interface.control_tags import ControlTag + +from label_studio_ml.examples.yolo.control_models.base import get_bool +from label_studio_ml.response import ModelResponse + +# read the environment variables and set the paths just before importing the sam2 module +SEGMENT_ANYTHING_2_REPO_PATH = 
os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'sam2') +sys.path.append(SEGMENT_ANYTHING_2_REPO_PATH) +from sam2.build_sam import build_sam2_video_predictor + + +# Global cache for YOLO models +_model_cache = {} +logger = logging.getLogger(__name__) + +DEVICE = os.getenv('DEVICE', 'cuda') +SAM2_MODEL_CONFIG = os.getenv('MODEL_CONFIG', './configs/sam2.1/sam2.1_hiera_l.yaml') +SAM2_MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_large.pt') +PROMPT_TYPE = cast(Literal["box", "point"], os.getenv('PROMPT_TYPE', 'box')) +ANNOTATION_WORKAROUND = os.getenv('ANNOTATION_WORKAROUND', False) +DEBUG = os.getenv('DEBUG', False) +LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY', '') +SAM2YOLOBOX_THRESHOLD = float(os.getenv('SAM2YOLOBOX_THRESHOLD', 0.6)) +LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO') +VIDEO_FRAME_RATE = int(os.getenv('VIDEO_FRAME_RATE', 24)) + +# Set the log level +logging.basicConfig(level=LOG_LEVEL) + +if DEVICE == 'cuda': + # use bfloat16 for the entire notebook + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + + if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +# build path to the model checkpoint +SAM2_MODEL_CHECKPOINT_PATH = str(pathlib.Path(__file__).parent / SEGMENT_ANYTHING_2_REPO_PATH / "checkpoints" / SAM2_MODEL_CHECKPOINT) +logger.info(f'Model checkpoint: {SAM2_MODEL_CHECKPOINT}') +logger.info(f'Model config: {SAM2_MODEL_CONFIG}') +SAM2_PREDICTOR = build_sam2_video_predictor(SAM2_MODEL_CONFIG, SAM2_MODEL_CHECKPOINT_PATH) + + +# manage cache for inference state +# TODO: make it process-safe and implement cache invalidation +_predictor_state_key = '' +_inference_state = None + +def get_inference_state(video_dir): + """ + Get the inference state for the video directory. 
If the video directory is different from the previous one, + :param video_dir: + :return: + """ + global _predictor_state_key, _inference_state + if _predictor_state_key != video_dir: + _predictor_state_key = video_dir + _inference_state = SAM2_PREDICTOR.init_state(video_path=video_dir) + return _inference_state + + +class ImageFolderSource: + def __init__(self, folder_path: str, sorting_rule: Optional[callable] = None): + self.folder_path = folder_path + # Supported image extensions + image_extensions = ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tiff") + self.image_paths = [] + for ext in image_extensions: + self.image_paths.extend(glob.glob(os.path.join(folder_path, ext))) + if not self.image_paths: + raise IOError(f"No images found in folder: {folder_path}") + # Apply sorting rule + # self.image_paths.sort(key=sorting_rule if sorting_rule else lambda x: x) + + self.image_paths = sorted(self.image_paths, key=sorting_rule) + + self.frame_count = len(self.image_paths) + + def get_frame(self, frame_index: int): + if frame_index < 0 or frame_index >= self.frame_count: + raise IndexError("Frame index out of range") + image_path = self.image_paths[frame_index] + frame = cv2.imread(image_path) + if frame is None: + raise IOError(f"Failed to read image at path: {image_path}") + return frame + + def get_frame_count(self) -> int: + return self.frame_count + + def release(self): + pass # No resource to release for image folders + + def get_image_paths(self, start_index: int, end_index: int) -> List[str]: + """ + Returns a list of image paths between start_index (inclusive) and end_index (exclusive). + """ + if start_index < 0 or end_index > self.frame_count or start_index >= end_index: + raise ValueError("Invalid index range") + return self.image_paths[start_index:end_index] + + +@dataclass +class Mask: + """ + Represents a mask with an encoded string and shape. + """ + + encoded: str # Encoded string as per pycocotools + shape: List[int] # [height, width] + + def to_json(self): + json_encoded = self.encoded.copy() + json_encoded["counts"] = base64.b64encode(self.encoded["counts"]).decode( + "utf-8" + ) + + return {"encoded": json_encoded, "shape": self.shape} + + @classmethod + def from_json(cls, data): + json_encoded = data["encoded"].copy() + json_encoded["counts"] = base64.b64decode(json_encoded["counts"]) + + return cls(encoded=json_encoded, shape=data["shape"]) + + +@dataclass +class ObjectDetection: + """ + Stores the bounding box and class of an object. + """ + + detection_id: int + # Normalization is based on the image resolution (width, height) + xyxyn: List[ + float + ] # Normalized bounding box coordinates [x_min, y_min, x_max, y_max] + object_class: str # Class of the object, e.g., 'human' or 'vehicle' + + def to_json(self): + return { + "detection_id": self.detection_id, + "xyxyn": self.xyxyn, + "object_class": self.object_class, + } + + @classmethod + def from_json(cls, data): + return cls( + detection_id=data["detection_id"], + xyxyn=data["xyxyn"], + object_class=data["object_class"], + ) + + +@dataclass +class ObjectTracking: + """ + Stores the object tracking information. 
+ """ + + tracking_id: int + start_frame: int + duration_frames: int + masks: Dict[int, Mask] = field(default_factory=dict) # Frame ID to Mask + original_detection_id: Dict[int, Optional[int]] = field( + default_factory=dict + ) # Frame ID to detection ID + + def to_json(self): + return { + "tracking_id": self.tracking_id, + "start_frame": self.start_frame, + "duration_frames": self.duration_frames, + "masks": { + str(frame_id): mask.to_json() for frame_id, mask in self.masks.items() + }, + "original_detection_id": { + str(frame_id): det_id + for frame_id, det_id in self.original_detection_id.items() + }, + } + + @classmethod + def from_json(cls, data): + return cls( + tracking_id=data["tracking_id"], + start_frame=data["start_frame"], + duration_frames=data["duration_frames"], + masks={ + int(frame_id): Mask.from_json(mask_data) + for frame_id, mask_data in data.get("masks", {}).items() + }, + original_detection_id={ + int(frame_id): det_id + for frame_id, det_id in data.get("original_detection_id", {}).items() + }, + ) + + +@dataclass +class SingleVideoAnnotatorState: + """ + State of the video annotator focusing on a single video. + """ + + frame_object_detections: Dict[int, List[ObjectDetection]] = field( + default_factory=dict + ) # Frame ID to list of detections + object_trackings: Dict[int, ObjectTracking] = field( + default_factory=dict + ) # Tracking ID to ObjectTracking + + num_assigned_detections: int = 0 # Number of assigned detections + num_assigned_trackings: int = 0 # Number of assigned trackings + + def to_json(self): + return { + "frame_object_detections": { + str(frame_id): [detection.to_json() for detection in detections] + for frame_id, detections in self.frame_object_detections.items() + }, + "object_trackings": { + str(tracking_id): tracking.to_json() + for tracking_id, tracking in self.object_trackings.items() + }, + "num_assigned_detections": self.num_assigned_detections, + "num_assigned_trackings": self.num_assigned_trackings, + } + + @classmethod + def from_json(cls, data): + frame_object_detections = { + int(frame_id): [ + ObjectDetection.from_json(det_data) for det_data in detections + ] + for frame_id, detections in data.get("frame_object_detections", {}).items() + } + object_trackings = { + int(tracking_id): ObjectTracking.from_json(tracking_data) + for tracking_id, tracking_data in data.get("object_trackings", {}).items() + } + num_assigned_detections = data.get("num_assigned_detections", 0) + num_assigned_trackings = data.get("num_assigned_trackings", 0) + return cls( + frame_object_detections=frame_object_detections, + object_trackings=object_trackings, + num_assigned_detections=num_assigned_detections, + num_assigned_trackings=num_assigned_trackings, + ) + + +class SingleVideoAnnotatorModel: + """ + Model of the video annotator application responsible for storing: + - Detections + - Single-video object tracking (without re-identification) + The model notifies observers when its state changes and exposes an interface for the controller. + """ + + def __init__( + self, + object_classes: Set[str] = {"person"} + ): + """ + Initializes the annotator model for a single video. + + Args: + video_id (str): Identifier for the video being annotated. + video_source_path (str): Path to the video file or image folder. + sorting_rule (callable, optional): Sorting function for image filenames if video_source_path is an image folder. 
+ """ + self.video_source = None + self.object_classes = object_classes + + self.yolo_model = None # Load an official Detect model + + self.state = SingleVideoAnnotatorState() + + # H, W, 3 + self.image_shape = self.get_frame(0).shape[:2] if self.video_source else None + + + # SAM2 related attributes + self.sam2_model_cfg = None + self.sam2_model_checkpoint_path = None + self.sam2_max_frames_to_track = None + self.prompt_type = None + self.annotation_workaround = None + self.sam2_predictor = None + + self._predictor_state_key = '' + self._inference_state = None + + @classmethod + def load_yolo_model(cls, checkpoint) -> YOLO: + """Load YOLO model from the file.""" + logger.info(f"Loading yolo model: {checkpoint}") + model = YOLO(checkpoint) + logger.info(f"Model {checkpoint} names:\n{model.names}") + return model + + @classmethod + def get_cached_model(cls, path: str) -> YOLO: + if path not in _model_cache: + _model_cache[path] = cls.load_yolo_model(path) + return _model_cache[path] + + @staticmethod + def get_video_fps_duration(path, fps): + if not os.path.exists(path): + raise ValueError(f"Video file not found: {path}") + container = av.open(path) + duration = container.duration / av.time_base # Duration in seconds + frame_count = int(duration * fps) + logger.info(f"Video duration: {duration} seconds, {frame_count} frames, {fps} fps") + return frame_count, duration + + def get_inference_state(self, video_dir): + """ + Get the inference state for the video directory. If the video directory is different from the previous one, + :param video_dir: + :return: + """ + if self._predictor_state_key != video_dir: + self._predictor_state_key = video_dir + self._inference_state = SAM2_PREDICTOR.init_state(video_path=video_dir) + return self._inference_state + + def build_sam2_predictor(self, sam2_model_checkpoint_path, sam2_model_cfg, sam2_max_frames_to_track, prompt_type, annotation_workaround): + """ + Update the SAM2 configuration. + :param sam2_model_checkpoint: + :param sam2_model_cfg: + :param sam2_max_frames_to_track: + :param prompt_type: + :param annotation_workaround: + :return: + """ + + self.sam2_model_cfg = sam2_model_cfg + self.sam2_model_checkpoint_path = sam2_model_checkpoint_path + self.sam2_max_frames_to_track = sam2_max_frames_to_track + self.prompt_type = prompt_type + self.annotation_workaround = annotation_workaround + + self.sam2_predictor = build_sam2_video_predictor( + self.sam2_model_cfg, self.sam2_model_checkpoint_path, device="cuda:0" + ) + + return self.sam2_predictor + + def convert_mask_to_bbox(self, mask: Mask): + """ + Function to convert a mask to a bounding box. + Used from Label Studio ML examples. 
+ :param mask: + :return: + """ + # Decode the mask + mask_np = coco_mask.decode(mask.encoded) + + # squeeze + logger.debug(f"Mask shape: {mask_np.shape}") + + y_indices, x_indices = np.where(mask_np == 1) + if len(x_indices) == 0 or len(y_indices) == 0: + return None + + # Find the min and max indices + xmin, xmax = np.min(x_indices), np.max(x_indices) + ymin, ymax = np.min(y_indices), np.max(y_indices) + + # Get mask dimensions + height, width = mask_np.shape + + # Calculate bounding box dimensions + box_width = xmax - xmin + 1 + box_height = ymax - ymin + 1 + + # Normalize and scale to percentage + x_pct = (xmin / width) * 100 + y_pct = (ymin / height) * 100 + width_pct = (box_width / width) * 100 + height_pct = (box_height / height) * 100 + + return { + "x": round(x_pct, 2), + "y": round(y_pct, 2), + "width": round(width_pct, 2), + "height": round(height_pct, 2) + } + + def split_frames(self, video_path, temp_dir, video_fps, start_frame=0, end_frame=100): + """ + Extracts and saves frames from a video file within the specified range. + + This method processes a video file, extracts frames between the provided + start and end frame indexes, and stores the extracted images in the specified + temporary directory. Each frame is saved as a `.jpg` file, and the method + yields the file path and the corresponding frame data for further processing. + + :param video_path: Path to the input video file. + :param temp_dir: Directory to store the extracted frame images temporarily. + :param start_frame: Index of the first frame to extract (inclusive). + :param end_frame: Index of the last frame to extract (exclusive). + :return: Yields a tuple containing the file path and raw frame data for each extracted frame. + """ + logger.debug(f'Opening video file: {video_path}') + + # Get video properties using the static method + frame_count, duration = self.get_video_fps_duration(video_path, fps=video_fps) + + # Open the video using OpenCV + video = cv2.VideoCapture(video_path) + + if not video.isOpened(): + raise ValueError(f"Could not open video file: {video_path}") + + logger.debug(f'Video duration: {duration} seconds, {frame_count} frames, {video_fps} fps') + + frame_count_current = 0 + while frame_count_current < frame_count: + # Calculate the timestamp in seconds for the current frame + timestamp = frame_count_current / video_fps + + # Set the video capture position to the corresponding timestamp in milliseconds + video.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000) + + # Read the frame at the calculated timestamp + success, frame = video.read() + + if not success: + logger.error(f'Failed to read frame {frame_count_current}, this could be due to an empty frame.') + break + + if frame_count_current < start_frame: + frame_count_current += 1 + continue + + if frame_count_current >= end_frame: + break + + frame_filename = os.path.join(temp_dir, f'{frame_count_current:05d}.jpg') + + if not os.path.exists(frame_filename): + cv2.imwrite(frame_filename, frame) + + logger.debug(f'Frame {frame_count_current}: {frame_filename}') + yield frame_filename, frame + frame_count_current += 1 + + video.release() + + def get_net_predictions_as_regions(self, video_path: str, video_fps: int, context: Optional[Dict] = None): + """ + Extract predictions from the ObjectTracking object and create video rectangle regions. 
+ """ + logger.debug(f"Getting net predictions for video: {video_path}") + sequences = [] + start_frame = 0 + + # Get video metadata + frame_count, duration = self.get_video_fps_duration(video_path, fps=video_fps) + + # Dictionary to store tracking sequences and labels + tracks = defaultdict(list) + track_labels = dict() + + # Iterate over tracking data + for tracking_id, tracking in self.state.object_trackings.items(): + label_map = context.get('label_map', {}) if context else {} + track_labels[tracking_id] = label_map.get(tracking_id, 'Unknown') + for frame_id in range(tracking.start_frame, tracking.start_frame + tracking.duration_frames + 1): + if frame_id in tracking.masks: + mask: Mask = tracking.masks[frame_id] + bbox = self.convert_mask_to_bbox(mask) + if bbox: + box = { + 'frame': frame_id + 1, + 'x': bbox['x'], + 'y': bbox['y'], + 'width': bbox['width'], + 'height': bbox['height'], + 'enabled': True, + 'rotation': 0, + 'time': frame_id / video_fps, + } + tracks[tracking_id].append(box) + + # Process tracks to create regions + regions = [] + for track_id, sequence in tracks.items(): + label = track_labels[track_id] + max_score = max([frame_info.get("score", 1.0) for frame_info in sequence]) # Use 1.0 as a default score + + region = { + "from_name": "box", + "to_name": "video", + "type": "videorectangle", + "value": { + "framesCount": frame_count, + "duration": duration, + "sequence": sequence, + "labels": [label], + }, + "score": max_score, + "origin": "manual", + } + regions.append(region) + + return regions + + + # state modification methods + def add_detection(self, frame_id: int, detection: ObjectDetection): + """ + Adds a detection to a specific frame. + """ + detections = self.state.frame_object_detections.setdefault(frame_id, []) + detections.append(detection) + # self.notify_observers( + # frame_id=frame_id, + # changed="detections" + # ) + + def clear_detections(self, frame_id: int): + """ + Clears all detections from a specific frame. + """ + self.state.frame_object_detections.pop(frame_id, None) + # self.notify_observers( + # frame_id=frame_id, + # changed="detections" + # ) + + def append_mask_to_tracking(self, tracking_id: int, frame_id: int, mask: Mask): + """ + Appends a mask to an existing tracking. + """ + if tracking_id not in self.state.object_trackings: + self.add_tracking( + ObjectTracking( + tracking_id=tracking_id, start_frame=frame_id, duration_frames=0 + ) + ) + + tracking = self.state.object_trackings[tracking_id] + tracking.masks[frame_id] = mask + tracking.duration_frames = max( + tracking.duration_frames, frame_id - tracking.start_frame + ) + + # self.notify_observers( + # frame_id=frame_id, + # tracking_id=tracking_id, + # changed="tracking" + # ) + + def add_tracking(self, tracking: ObjectTracking): + """ + Adds a tracking object. + """ + self.state.object_trackings[tracking.tracking_id] = tracking + + def get_YOLO_detections(self, + conf, + iou, + yolo_model_checkpoint, + video_source_path, + video_fps, + max_frames_to_track, + output_frames_dir="predictions/yolo"): + """ + Run YOLO detection on the entire video, overlay the detections on each frame, and save each frame as an image. 
+ """ + if DEVICE == 'cuda': + self.yolo_model = self.get_cached_model(yolo_model_checkpoint) + else: + logger.error("Only CUDA is supported for YOLO model") + return + + # Open video source + cap = cv2.VideoCapture(video_source_path) + frame_count, duration = self.get_video_fps_duration(video_source_path, fps=video_fps) + + # Get the frame width and height from the video + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + logger.info(f"Video dimensions: width={width}, height={height}") + + # Ensure output directory exists + if not os.path.exists(output_frames_dir): + os.makedirs(output_frames_dir) + + logger.info(f"Reading frames from {video_source_path} at {video_fps} FPS...") + + frame_id = 0 + frames_to_track = min(frame_count, max_frames_to_track) + + while cap.isOpened() and frame_id < frames_to_track: + # Read from a specific timestamp based on the fps and frame_id + timestamp = frame_id / video_fps + cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000) + ret, frame = cap.read() + logger.debug(f"Detection frame {frame_id}") + if not ret: + break # End of video + + # Run YOLO detection on the current frame, + # Image size must be multiple of max stride 32 + img_size = min(width, height) - (min(width, height) % 32) + detection_result = self.yolo_model.track( + frame, + conf=conf, + iou=iou, + imgsz=img_size, + persist=True, + show=False, + verbose=False, + )[0] + + detected_boxes = detection_result.boxes + det_class_mapping = detection_result.names + + self.clear_detections(frame_id) + + for detection in detected_boxes: + new_detection_id = self.assign_new_detection_id() + detected_class = det_class_mapping[int(detection.cls[0].item())] + + # Check if the detected class is in the object classes + if detected_class not in self.object_classes: + continue + + logger.debug(f"Detected class: {detected_class}, at {detection.xyxyn[0]} for frame {frame_id}") + + # Extract bounding box coordinates (normalize to [0, 1] range, multiply by image size) + x1, y1, x2, y2 = detection.xyxyn[0] # normalized coordinates + x1, y1, x2, y2 = int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height) + + # Draw bounding box and label on the frame + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) # Green rectangle + cv2.putText(frame, detected_class, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) + + # Add detection to internal structure + self.add_detection( + frame_id, + ObjectDetection( + detection_id=new_detection_id, + xyxyn=detection.xyxyn[0].tolist(), + object_class=detected_class, + ), + ) + + # Save the frame with the overlay as an image + frame_filename = os.path.join(output_frames_dir, f"frame_{frame_id:04d}.jpg") + + try: + # Attempt to write the frame to a file + if not cv2.imwrite(frame_filename, frame): + logger.error(f"Failed to write frame {frame_id} to {frame_filename}") + else: + logger.debug(f"Successfully saved frame {frame_id} to {frame_filename}") + except Exception as e: + logger.error(f"Error writing frame {frame_id} to {frame_filename}: {e}") + + frame_id += 1 + + # Release resources + cap.release() + logger.info(f"Frames with detections saved to {output_frames_dir}") + + def get_sam_tracking_with_yolo_prompts( + self, video_fps: int, frames_to_track: int, sam_batchsize: int = 100, video_source_path: str = None + ): + """ + Processes a video using SAM tracking with YOLO prompts. + + Args: + video_fps (int): Frames per second for the video. + frames_to_track (int): Total number of frames to process. 
+ sam_batchsize (int): Batch size for SAM processing. + video_source_path (str): Path to the source video. + """ + current_frame = 0 + # -1 is done to avoid out of range error as propagation only happens N -1 frames + sam_end_frame = min(self.get_video_fps_duration(video_source_path, fps=video_fps)[0], frames_to_track) - 1 + + logger.info(f"Processing video with SAM tracking for {sam_end_frame} frames...") + + with tempfile.TemporaryDirectory() as temp_img_dir: + # temp_img_dir = '/tmp/frames' # Use persisted directory for debugging + # os.makedirs(temp_img_dir, exist_ok=True) + + frames = list( + self.split_frames( + video_path=video_source_path, temp_dir=temp_img_dir, video_fps=video_fps, start_frame=0, end_frame=sam_end_frame + 1 + ) + ) + + self.video_source = ImageFolderSource(temp_img_dir, sorting_rule=lambda x: x) + self.image_shape = self.get_frame(0).shape[:2] + logger.info(f"Video dimensions: width={self.image_shape[1]}, height={self.image_shape[0]}") + + with torch.autocast("cuda", torch.bfloat16): + while current_frame < sam_end_frame - 1: + logger.debug(f"Processing frames using SAM Tracker...") + while current_frame < sam_end_frame: + logger.debug(f"Processing frame {current_frame}") + + sam_prompts = self.prepare_sam_prompts(current_frame) + if not sam_prompts: + logger.debug(f"No SAM prompts for frame {current_frame}. Skipping.") + current_frame += 1 + continue + + logger.debug(f"Initializing inference state for frame {current_frame}") + inference_state = self.get_inference_state(video_dir=temp_img_dir) + self.sam2_predictor.reset_state(inference_state) + + logger.debug(f"Adding {len(sam_prompts)} prompts to SAM predictor") + for tracking_id, mask_bbox in sam_prompts: + _, out_obj_ids, out_mask_logits = self.sam2_predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=current_frame, + obj_id=tracking_id, + box=mask_bbox, + ) + + for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(inference_state): + for i, out_obj_id in enumerate(out_obj_ids): + mask = (out_mask_logits[i] > 0.0).cpu().numpy()[0] + if mask.any(): + mask_obj = Mask( + encoded=coco_mask.encode(np.asfortranarray(mask)), + shape=mask.shape, + ) + self.append_mask_to_tracking(out_obj_id, out_frame_idx, mask_obj) + + # unexplained_detections = self.get_unexplained_detections_at_frame(out_frame_idx) + # logger.debug(f"Unexplained detections at frame {out_frame_idx}: {len(unexplained_detections)}") + # + # if unexplained_detections: + # logger.warning(f"Stopping propagation at frame {out_frame_idx} due to unexplained detections.") + # break + + current_frame = out_frame_idx + logger.info(f"Updated current frame to {current_frame}") + + def prepare_sam_prompts(self, frame_id: int): + """ + Prepares SAM prompts for a frame. + + Args: + frame_id (int): Frame ID for which to prepare the prompts. + + Returns: + List[str]: List of prompts for SAM. + """ + # For all object tracking, get the masks for the frame. These tracking ids # will be kept in the SAM prompts. + annotated_masks = self.get_annotated_masks_at_frame(frame_id) + unexplained_detections = self.get_unexplained_detections_at_frame(frame_id) + + # prepare the SAM prompts based on existing masks and unexplained detections. The prompts will be in the form of bounding boxes. 
+ sam_prompts = [] + + for tracking_id, mask in annotated_masks.items(): + mask_bbox = coco_mask.toBbox(mask.encoded) + + # NOTE: this bbox is in x, y, w, h format + mask_bbox_xyxy = np.array( + [ + mask_bbox[0], + mask_bbox[1], + mask_bbox[0] + mask_bbox[2], + mask_bbox[1] + mask_bbox[3], + ] + ) + + sam_prompts.append((tracking_id, mask_bbox_xyxy)) + + for detection in unexplained_detections: + detection_bbox = detection.xyxyn * np.array( + [ + self.image_shape[1], + self.image_shape[0], + self.image_shape[1], + self.image_shape[0], + ] + ) + sam_prompts.append((self.assign_new_tracking_id(), detection_bbox)) + + return sam_prompts + + def get_annotated_masks_at_frame(self, frame_id: int): + annotated_masks = {} # tracking_id to mask + for tracking_id, tracking in self.state.object_trackings.items(): + if frame_id in tracking.masks: + mask = tracking.masks[frame_id] + # Add the mask to the SAM prompts + annotated_masks[tracking_id] = mask + + return annotated_masks + + def get_unexplained_detections_at_frame(self, frame_id: int): + """ + This function returns all the bboxes that are not in the SAM propagation. + :param frame_id: + :return: + """ + + annotated_masks = self.get_annotated_masks_at_frame(frame_id) + + # Get the detections for the frame, try to explain them with the masks. If not possible, add the detection to the SAM prompts with a new tracking id. + all_detections = self.state.frame_object_detections.get(frame_id, []) + + # explain the detections with the masks + unexplained_detections = [] + for detection in all_detections: + explained = False + for tracking_id, mask in annotated_masks.items(): + if self.explain_detection_with_mask(detection, mask): + explained = True + break + if not explained: + unexplained_detections.append(detection) + + return unexplained_detections + + def explain_detection_with_mask(self, detection: ObjectDetection, mask: Mask): + """ + Explains a detection with a mask. + + Args: + detection (ObjectDetection): Detection to explain. + mask (Mask): Mask to explain the detection. + + Returns: + bool: True if the detection was explained, False otherwise. + """ + + # a detection is explained by a mask if the bounding of the mask + # and the detection have error less then 0.1 normalized error. + + # get the bounding box of the mask + decoded_mask = coco_mask.decode(mask.encoded) + mask_bbox = coco_mask.toBbox(mask.encoded) # notice! this is x, y, w, h + + detection_xyxy = np.array( + [ + int(detection.xyxyn[0] * self.image_shape[1]), + int(detection.xyxyn[1] * self.image_shape[0]), + int(detection.xyxyn[2] * self.image_shape[1]), + int(detection.xyxyn[3] * self.image_shape[0]), + ] + ) + + mask_in_bbox = np.sum( + decoded_mask[ + detection_xyxy[1] : detection_xyxy[3], + detection_xyxy[0] : detection_xyxy[2], + ] + ) + mask_pixels = np.sum(decoded_mask) + + mask_in_bbox = mask_in_bbox / (mask_pixels + 1e-6) + + return mask_in_bbox > SAM2YOLOBOX_THRESHOLD + + def assign_new_detection_id(self): + """ + Assigns a new detection ID for a new detection. + """ + assigned_id = self.state.num_assigned_detections + + self.state.num_assigned_detections += 1 + + + return assigned_id + + def assign_new_tracking_id(self): + """ + Assigns a new tracking ID for a new tracking. + """ + assigned_id = self.state.num_assigned_trackings + + self.state.num_assigned_trackings += 1 + + + return assigned_id + + def get_frame(self, frame_index: int): + """ + Retrieves a frame from the video source. + + Args: + frame_index (int): Index of the frame to retrieve. 
+
+        Returns:
+            The video frame as an image.
+        """
+        return self.video_source.get_frame(frame_index)
+
+    def get_frame_count(self) -> int:
+        """
+        Returns the total number of frames in the video.
+
+        Returns:
+            int: Total number of frames.
+        """
+        return self.video_source.get_frame_count()
+
+    def release_video(self):
+        """
+        Releases the video source resources.
+        """
+        self.video_source.release()
+
+    def get_all_tracking(self):
+        """
+        Returns all the tracking objects.
+        """
+        return self.state.object_trackings
+
+    def get_regions_from_yolo_sam2_tracker(self,
+                                           conf,
+                                           iou,
+                                           yolo_model_checkpoint,
+                                           image_size,
+                                           sam2_model_checkpoint_path,
+                                           sam2_model_cfg,
+                                           max_frames_to_track,
+                                           prompt_type,
+                                           annotation_workaround,
+                                           video_source_path,
+                                           video_fps):
+        """
+        Run YOLO detection and SAM tracking on the video source.
+        """
+
+        ######### YOLO DETECTION #########
+        self.get_YOLO_detections(conf=conf,
+                                 iou=iou,
+                                 yolo_model_checkpoint=yolo_model_checkpoint,
+                                 video_source_path=video_source_path,
+                                 video_fps=video_fps,
+                                 max_frames_to_track=max_frames_to_track)
+
+        ######### SAM TRACKING #########
+        self.build_sam2_predictor(sam2_model_checkpoint_path, sam2_model_cfg, max_frames_to_track, prompt_type, annotation_workaround)
+        self.get_sam_tracking_with_yolo_prompts(video_fps=video_fps,
+                                                frames_to_track=max_frames_to_track,
+                                                video_source_path=video_source_path)
+
+        # return the regions
+        return self.get_net_predictions_as_regions(video_path=video_source_path, video_fps=video_fps)
+
+
+class VideoRectangleWithYOLOSAM2TrackerModel(VideoRectangleModel):
+    """
+    Class representing a VideoRectangle control tag backed by the YOLO detector and the SAM2 tracker.
+    """
+
+    type = "VideoRectangleWithYOLOSAM2Tracker"
+    model_path = "yolov10x.pt"
+
+    @classmethod
+    def is_control_matched(cls, control: ControlTag) -> bool:
+        # check object tag type
+        if control.objects[0].tag != "Video":
+            return False
+        if not get_bool(control.attr, "model_sam_tracker", "false"):
+            return False
+        return True
+
+    def get_model_configs(self):
+        """
+        Read the model configuration from the control tag attributes.
+        :return:
+        """
+        conf = float(self.control.attr.get("model_conf", 0.25))
+        iou = float(self.control.attr.get("model_iou", 0.70))
+        yolo_model = self.control.attr.get("yolo_model", "yolov10x").lower()
+        yolo_model_checkpoint = yolo_model + ".pt"
+        image_size = int(self.control.attr.get("model_image_size", 2560))
+        frames_to_track = int(self.control.attr.get("frames_to_track", 100))
+        fps = int(self.control.attr.get("fps", 24))
+
+        return conf, iou, yolo_model_checkpoint, image_size, frames_to_track, fps
+
+    def predict_regions(self, path) -> List[Dict]:
+        """
+        Track regions with the YOLO + SAM2 tracker.
+        :param path:
+        :return:
+        """
+        single_video_annotator = SingleVideoAnnotatorModel()
+
+        conf, iou, yolo_model_checkpoint, image_size, frames_to_track, fps = self.get_model_configs()
+        regions = single_video_annotator.get_regions_from_yolo_sam2_tracker(
+            conf=conf,
+            iou=iou,
+            yolo_model_checkpoint=yolo_model_checkpoint,
+            image_size=image_size,
+            sam2_model_cfg=SAM2_MODEL_CONFIG,
+            sam2_model_checkpoint_path=SAM2_MODEL_CHECKPOINT_PATH,
+            max_frames_to_track=frames_to_track,
+            prompt_type=PROMPT_TYPE,
+            annotation_workaround=ANNOTATION_WORKAROUND,
+            video_source_path=path,
+            video_fps=fps
+        )
+
+        return regions
+
+# pre-load and cache default model at startup
+VideoRectangleWithYOLOSAM2TrackerModel.get_cached_model(VideoRectangleWithYOLOSAM2TrackerModel.model_path)
diff --git a/label_studio_ml/examples/yolo_sam2_tracker/docker-compose.yml 
b/label_studio_ml/examples/yolo_sam2_tracker/docker-compose.yml new file mode 100644 index 000000000..d2b80df15 --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/docker-compose.yml @@ -0,0 +1,52 @@ +version: "3.8" + +services: + MOT_yolo_sam_tracker: + container_name: MOT_yolo_sam_tracker + image: humansignal/mot_yolo_sam_tracker:v0 + build: + context: . + args: + TEST_ENV: ${TEST_ENV} + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [ gpu ] + environment: + # specify these parameters if you want to use basic auth for the model server + - BASIC_AUTH_USER= + - BASIC_AUTH_PASS= + # set the log level for the model server + - LOG_LEVEL=INFO + # any other parameters that you want to pass to the model server + - ANY=PARAMETER + # specify the number of workers and threads for the model server + - WORKERS=1 + - THREADS=8 + # specify the model directory (likely you don't need to change this) + - MODEL_DIR=/data/models + # specify device + - DEVICE=cuda # or 'cpu' (coming soon) + # SAM2 model config + - MODEL_CONFIG=./configs/sam2.1/sam2.1_hiera_l.yaml + - SEGMENT_ANYTHING_2_REPO_PATH=./sam2 + # SAM2 checkpoint + - MODEL_CHECKPOINT=sam2.1_hiera_large.pt + - MAX_FRAMES_TO_TRACK=2000 + - VIDEO_FRAME_RATE=24 + # Specify the Label Studio URL and API key to access + # uploaded, local storage and cloud storage files. + # Do not use 'localhost' as it does not work within Docker containers. + # Use prefix 'http://' or 'https://' for the URL always. + # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows). + - LABEL_STUDIO_URL= + - LABEL_STUDIO_API_KEY= + ports: + - "9090:9090" + volumes: + - "./data/server:/data" + - ./models:/app/models # Mount the local 'models' directory + - "./cache_dir:/app/cache_dir" diff --git a/label_studio_ml/examples/yolo_sam2_tracker/model.py b/label_studio_ml/examples/yolo_sam2_tracker/model.py new file mode 100644 index 000000000..49cb52d10 --- /dev/null +++ b/label_studio_ml/examples/yolo_sam2_tracker/model.py @@ -0,0 +1,124 @@ +import logging +# from PIL import Image +# YOLO + SAM2 related imports +from typing import List, Dict +from typing import Optional + +# YOLO imports: +from control_models.base import ControlModel +from control_models.video_rectangle_with_yolo_sam2_tracker import VideoRectangleWithYOLOSAM2TrackerModel +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.response import ModelResponse + +# Register available model classes +available_model_classes = [ + VideoRectangleWithYOLOSAM2TrackerModel, +] + +logger = logging.getLogger(__name__) + +class YoloSamMultiObjectTracking(LabelStudioMLBase): + """ + YOLO_SAM model for object detection and tracking. + Detection model based on YOLO and tracking based on Segment Anything 2. + """ + + def setup(self): + """Configure any parameters of your model here""" + self.set("model_version", "yolo_sam") + + def detect_control_models(self) -> List[ControlModel]: + """Detect control models based on the labeling config. + Control models are used to predict regions for different control tags in the labeling config. 
+ """ + control_models = [] + + for control in self.label_interface.controls: + # skipping tags without toName + if not control.to_name: + logger.warning( + f'{control.tag} {control.name} has no "toName" attribute, skipping it' + ) + continue + + # match control tag with available control models + for model_class in available_model_classes: + if model_class.is_control_matched(control): + instance = model_class.create(self, control) + if not instance: + logger.debug( + f"No instance created for {control.tag} {control.name}" + ) + continue + if not instance.label_map: + logger.error( + f"No label map built for the '{control.tag}' control tag '{instance.from_name}'.\n" + f"This indicates that your Label Studio config labels do not match the model's labels.\n" + f"To fix this, ensure that the 'value' or 'predicted_values' attribute " + f"in your Label Studio config matches one or more of these model labels.\n" + f"If you don't want to use this control tag for predictions, " + f'add `model_skip="true"` to it.\n' + f"Examples:\n" + f'