
Add tensor-based annotation storage to reduce DDP RAM usage with large COCO-format datasets #1885


Open · wants to merge 6 commits into base: master
@@ -1,19 +1,22 @@
import copy
import dataclasses
import gc
import json
import os
import pickle
from collections import defaultdict
from typing import List, Optional, Tuple

import numpy as np
import torch

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.deprecate import deprecated_parameter
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
from super_gradients.common.registry import register_dataset
from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy_inplace
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.training.utils.detection_utils import change_bbox_bounds_for_image_size

logger = get_logger(__name__)
@@ -42,6 +45,7 @@ def __init__(
        with_crowd: bool = True,
        class_ids_to_ignore: Optional[List[int]] = None,
        tight_box_rotation=None,
        use_tensor_backed_storage: bool = False,
        *args,
        **kwargs,
    ):
@@ -52,6 +56,7 @@ def __init__(
        :param with_crowd: Add the crowd groundtruths to __getitem__
        :param class_ids_to_ignore: List of class ids to ignore in the dataset. By default, doesn't ignore any class.
        :param tight_box_rotation: This parameter is deprecated and will be removed in SuperGradients 3.8.
        :param use_tensor_backed_storage: Whether to store annotations in tensor-backed storage to mitigate the Python
            memory growth (fork copy-on-write) observed with large datasets under DDP.
Review comment (Contributor): Maybe give an estimate of what's considered "large" as a recommendation, from your experience.
"""
if tight_box_rotation is not None:
logger.warning(
@@ -61,6 +66,7 @@ def __init__(
        self.json_annotation_file = json_annotation_file
        self.with_crowd = with_crowd
        self.class_ids_to_ignore = class_ids_to_ignore or []
        self.use_tensor_backed_storage = use_tensor_backed_storage

        target_fields = ["target", "crowd_target"] if self.with_crowd else ["target"]
        kwargs["target_fields"] = target_fields
@@ -80,6 +86,11 @@ def __init__(
"Most likely this indicates an error in your all_classes_list parameter"
)

    @staticmethod
    def _serialize_annotations(data):
        # Pickle the annotation object with the highest available protocol and
        # expose the raw bytes as a flat uint8 tensor, so the data is held in
        # tensor storage rather than as refcounted Python objects.
        buffer = pickle.dumps(data, protocol=-1)
        return torch.frombuffer(buffer, dtype=torch.uint8)

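Note (not part of the diff): `_serialize_annotations` moves each record out of Python object land into a flat uint8 tensor whose memory is owned by PyTorch, so forked DataLoader workers no longer touch per-object refcounts. A minimal round-trip sketch, with an illustrative `record` dict standing in for the real annotation type:

import pickle
import torch

record = {"image_id": 1, "boxes": [[10.0, 20.0, 50.0, 80.0]], "labels": [3]}

buffer = pickle.dumps(record, protocol=-1)  # protocol=-1 picks the highest available pickle protocol
blob = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)  # zero-copy view; bytearray avoids the read-only-buffer warning

restored = pickle.loads(blob.numpy().tobytes())  # back to the original object
assert restored == record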
    def _setup_data_source(self) -> int:
        """
        Parse COCO annotation file
@@ -105,8 +116,23 @@ def _setup_data_source(self) -> int:

        self.original_classes = list(all_class_names)
        self.classes = copy.deepcopy(self.original_classes)

        self._annotations = annotations

        if self.use_tensor_backed_storage:
            # Pickle each annotation into a uint8 tensor, then pack everything
            # into one contiguous buffer indexed by cumulative byte offsets.
            self._annotations = [COCOFormatDetectionDataset._serialize_annotations(x) for x in self._annotations]

            del annotations
            gc.collect()

            self._addr = torch.tensor([len(x) for x in self._annotations], dtype=torch.int64)
            self._addr = torch.cumsum(self._addr, dim=0)
            self._annotations = torch.concatenate(self._annotations)

            return len(self._addr)

        else:
            return len(self._annotations)

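Note (not part of the diff): the layout built above is a packed-buffer index: all pickled records live in one contiguous uint8 tensor, and `_addr` stores cumulative end offsets, so record i occupies bytes [addr[i-1], addr[i]). A self-contained sketch of the same scheme, with illustrative names:

import pickle
import torch

records = [{"id": i, "payload": list(range(i + 1))} for i in range(5)]

blobs = [torch.frombuffer(bytearray(pickle.dumps(r, protocol=-1)), dtype=torch.uint8) for r in records]
addr = torch.cumsum(torch.tensor([len(b) for b in blobs], dtype=torch.int64), dim=0)  # cumulative end offsets
flat = torch.concatenate(blobs)  # one contiguous buffer, no per-record Python objects

def load_record(i: int):
    start = 0 if i == 0 else addr[i - 1].item()
    end = addr[i].item()
    return pickle.loads(flat[start:end].numpy().tobytes())

assert load_record(3) == records[3]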
    @property
    def _all_classes(self) -> List[str]:
@@ -125,7 +151,12 @@ def _load_annotation(self, sample_id: int) -> dict:
        :return img_path: Path to the associated image
        """

        if self.use_tensor_backed_storage:
            # Records are stored back-to-back in one uint8 tensor; self._addr
            # holds the cumulative end offset of each pickled record.
            start_addr = 0 if sample_id == 0 else self._addr[sample_id - 1].item()
            end_addr = self._addr[sample_id].item()
            annotation = pickle.loads(self._annotations[start_addr:end_addr].numpy().data)
        else:
            annotation = self._annotations[sample_id]

        width = annotation.image_width
        height = annotation.image_height
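Note (not part of the diff): a hedged usage sketch of the new flag. The module path and the data_dir/images_dir arguments below are assumptions for illustration; everything other than use_tensor_backed_storage follows the existing COCOFormatDetectionDataset API:

from super_gradients.training.datasets.detection_datasets.coco_format_detection import COCOFormatDetectionDataset

dataset = COCOFormatDetectionDataset(
    data_dir="/path/to/dataset",  # assumed dataset root (images + COCO json)
    images_dir="images",  # assumption; follows the usual dataset arguments
    json_annotation_file="annotations/instances_train.json",
    use_tensor_backed_storage=True,  # new flag from this PR
)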