diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py
new file mode 100644
index 000000000..9d7ffb35d
--- /dev/null
+++ b/pytorch3d/implicitron/dataset/blob_loader.py
@@ -0,0 +1,526 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+import os
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.io import IO
+from pytorch3d.renderer.cameras import PerspectiveCameras
+from pytorch3d.structures.pointclouds import Pointclouds
+
+
+@dataclass
+class BlobLoader:
+ """
+ A loader for the blobs (images, masks, depth maps, point clouds) referenced
+ by a FrameData, loaded according to the loader's configuration.
+ This is used in the implementation of some dataset objects.
+
+ Args:
+ dataset_root: The root folder of the dataset; all the paths in jsons are
+ specified relative to this root (but not json paths themselves).
+ load_images: Enable loading the frame RGB data.
+ load_depths: Enable loading the frame depth maps.
+ load_depth_masks: Enable loading the frame depth map masks denoting the
+ depth values used for evaluation (the points consistent across views).
+ load_masks: Enable loading frame foreground masks.
+ load_point_clouds: Enable loading sequence-level point clouds.
+ max_points: Cap on the number of loaded points in the point cloud;
+ if reached, they are randomly sampled without replacement.
+ mask_images: Whether to mask the images with the loaded foreground masks;
+ 0 value is used for background.
+ mask_depths: Whether to mask the depth maps with the loaded foreground
+ masks; 0 value is used for background.
+ image_height: The height of the returned images, masks, and depth maps;
+ aspect ratio is preserved during cropping/resizing.
+ image_width: The width of the returned images, masks, and depth maps;
+ aspect ratio is preserved during cropping/resizing.
+ box_crop: Enable cropping of the image around the bounding box inferred
+ from the foreground region of the loaded segmentation mask; masks
+ and depth maps are cropped accordingly; cameras are corrected.
+ box_crop_mask_thr: The threshold used to separate pixels into foreground
+ and background based on the foreground_probability mask; if no value
+ is greater than this threshold, the loader lowers it and repeats.
+ box_crop_context: The amount of additional padding added to each
+ dimension of the cropping bounding box, relative to box size.
+ path_manager: Optionally a PathManager used to resolve the blob paths
+ (e.g. when the data is not stored on the local filesystem).
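+
+ Example (an illustrative sketch; the dataset root is a placeholder)::
+
+ loader = BlobLoader(dataset_root="/path/to/dataset", image_height=800, image_width=800)
+ # `frame_data`, `entry` and `seq_annotation` come from the dataset's frame and
+ # sequence annotations; the passed FrameData is filled in place (and returned)
+ loader.load_(frame_data, entry, seq_annotation)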
+ """
+
+ dataset_root: str = ""
+ load_images: bool = True
+ load_depths: bool = True
+ load_depth_masks: bool = True
+ load_masks: bool = True
+ load_point_clouds: bool = False
+ max_points: int = 0
+ mask_images: bool = False
+ mask_depths: bool = False
+ image_height: Optional[int] = 800
+ image_width: Optional[int] = 800
+ box_crop: bool = True
+ box_crop_mask_thr: float = 0.4
+ box_crop_context: float = 0.3
+ path_manager: Any = None
+
+ def load_(
+ self,
+ frame_data: FrameData,
+ entry: types.FrameAnnotation,
+ seq_annotation: types.SequenceAnnotation,
+ ) -> FrameData:
+ """Main loading method.
+
+ The passed FrameData is modified in place and also returned for convenience.
+ """
+ (
+ frame_data.fg_probability,
+ frame_data.mask_path,
+ frame_data.bbox_xywh,
+ clamp_bbox_xyxy,
+ frame_data.crop_bbox_xywh,
+ ) = self._load_crop_fg_probability(entry)
+
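+ # initial estimate of the rescaling factor (used for the camera even when
+ # the image itself is not loaded); overwritten below if the image is loaded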
+ scale = min(
+ self.image_height / entry.image.size[0],
+ self.image_width / entry.image.size[1],
+ )
+ if self.load_images and entry.image is not None:
+ # original image size
+ frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
+
+ (
+ frame_data.image_rgb,
+ frame_data.image_path,
+ frame_data.mask_crop,
+ scale,
+ ) = self._load_crop_images(
+ entry, frame_data.fg_probability, clamp_bbox_xyxy
+ )
+
+ if self.load_depths and entry.depth is not None:
+ (
+ frame_data.depth_map,
+ frame_data.depth_path,
+ frame_data.depth_mask,
+ ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
+
+ if entry.viewpoint is not None:
+ frame_data.camera = self._get_pytorch3d_camera(
+ entry,
+ scale,
+ clamp_bbox_xyxy,
+ )
+
+ if self.load_point_clouds and seq_annotation.point_cloud is not None:
+ pcl_path = self._fix_point_cloud_path(seq_annotation.point_cloud.path)
+ frame_data.sequence_point_cloud = _load_pointcloud(
+ self._local_path(pcl_path), max_points=self.max_points
+ )
+ frame_data.sequence_point_cloud_path = pcl_path
+
+ return frame_data
+
+ def _load_crop_fg_probability(
+ self, entry: types.FrameAnnotation
+ ) -> Tuple[
+ Optional[torch.Tensor],
+ Optional[str],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ ]:
+ fg_probability = None
+ full_path = None
+ bbox_xywh = None
+ clamp_bbox_xyxy = None
+ crop_box_xywh = None
+
+ if (self.load_masks or self.box_crop) and entry.mask is not None:
+ full_path = os.path.join(self.dataset_root, entry.mask.path)
+ mask = _load_mask(self._local_path(full_path))
+
+ if mask.shape[-2:] != entry.image.size:
+ raise ValueError(
+ f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
+ )
+
+ bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
+
+ if self.box_crop:
+ clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
+ _get_clamp_bbox(
+ bbox_xywh,
+ image_path=entry.image.path,
+ box_crop_context=self.box_crop_context,
+ ),
+ image_size_hw=tuple(mask.shape[-2:]),
+ )
+ crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
+
+ mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
+
+ fg_probability, _, _ = _resize_image(
+ mask,
+ image_height=self.image_height,
+ image_width=self.image_width,
+ mode="nearest",
+ )
+
+ return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
+
+ def _load_crop_images(
+ self,
+ entry: types.FrameAnnotation,
+ fg_probability: Optional[torch.Tensor],
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
+ assert self.dataset_root is not None and entry.image is not None
+ path = os.path.join(self.dataset_root, entry.image.path)
+ image_rgb = _load_image(self._local_path(path))
+
+ if image_rgb.shape[-2:] != entry.image.size:
+ raise ValueError(
+ f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
+ )
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
+
+ image_rgb, scale, mask_crop = _resize_image(
+ image_rgb, image_height=self.image_height, image_width=self.image_width
+ )
+
+ if self.mask_images:
+ assert fg_probability is not None
+ image_rgb *= fg_probability
+
+ return image_rgb, path, mask_crop, scale
+
+ def _load_mask_depth(
+ self,
+ entry: types.FrameAnnotation,
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ fg_probability: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, str, torch.Tensor]:
+ entry_depth = entry.depth
+ assert entry_depth is not None
+ path = os.path.join(self.dataset_root, entry_depth.path)
+ depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ depth_bbox_xyxy = _rescale_bbox(
+ clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
+ )
+ depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
+
+ depth_map, _, _ = _resize_image(
+ depth_map,
+ image_height=self.image_height,
+ image_width=self.image_width,
+ mode="nearest",
+ )
+
+ if self.mask_depths:
+ assert fg_probability is not None
+ depth_map *= fg_probability
+
+ if self.load_depth_masks:
+ assert entry_depth.mask_path is not None
+ mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
+ depth_mask = _load_depth_mask(self._local_path(mask_path))
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ depth_mask_bbox_xyxy = _rescale_bbox(
+ clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
+ )
+ depth_mask = _crop_around_box(
+ depth_mask, depth_mask_bbox_xyxy, mask_path
+ )
+
+ depth_mask, _, _ = _resize_image(
+ depth_mask,
+ image_height=self.image_height,
+ image_width=self.image_width,
+ mode="nearest",
+ )
+ else:
+ depth_mask = torch.ones_like(depth_map)
+
+ return depth_map, path, depth_mask
+
+ def _get_pytorch3d_camera(
+ self,
+ entry: types.FrameAnnotation,
+ scale: float,
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ ) -> PerspectiveCameras:
+ entry_viewpoint = entry.viewpoint
+ assert entry_viewpoint is not None
+ # principal point and focal length
+ principal_point = torch.tensor(
+ entry_viewpoint.principal_point, dtype=torch.float
+ )
+ focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
+
+ half_image_size_wh_orig = (
+ torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
+ )
+
+ # first, we convert from the dataset's NDC convention to pixels
+ format = entry_viewpoint.intrinsics_format
+ if format.lower() == "ndc_norm_image_bounds":
+ # this is e.g. currently used in CO3D for storing intrinsics
+ rescale = half_image_size_wh_orig
+ elif format.lower() == "ndc_isotropic":
+ rescale = half_image_size_wh_orig.min()
+ else:
+ raise ValueError(f"Unknown intrinsics format: {format}")
+
+ # principal point and focal length in pixels
+ principal_point_px = half_image_size_wh_orig - principal_point * rescale
+ focal_length_px = focal_length * rescale
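+ # worked example (illustrative): for an 800x600 (W x H) image stored with the
+ # "ndc_norm_image_bounds" convention, rescale == (400, 300), so a principal
+ # point of (0, 0) in NDC maps to principal_point_px == (400, 300), the centre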
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ principal_point_px -= clamp_bbox_xyxy[:2]
+
+ # now, convert from pixels to PyTorch3D v0.5+ NDC convention
+ if self.image_height is None or self.image_width is None:
+ out_size = list(reversed(entry.image.size))
+ else:
+ out_size = [self.image_width, self.image_height]
+
+ half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
+ half_min_image_size_output = half_image_size_output.min()
+
+ # rescaled principal point and focal length in ndc
+ principal_point = (
+ half_image_size_output - principal_point_px * scale
+ ) / half_min_image_size_output
+ focal_length = focal_length_px * scale / half_min_image_size_output
+
+ return PerspectiveCameras(
+ focal_length=focal_length[None],
+ principal_point=principal_point[None],
+ R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
+ T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
+ )
+
+ def _fix_point_cloud_path(self, path: str) -> str:
+ """
+ Fix up a point cloud path from the dataset.
+ Some files in Co3Dv2 have an accidental absolute path stored.
+ """
+ unwanted_prefix = (
+ "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
+ )
+ if path.startswith(unwanted_prefix):
+ path = path[len(unwanted_prefix) :]
+ return os.path.join(self.dataset_root, path)
+
+ def _local_path(self, path: str) -> str:
+ if self.path_manager is None:
+ return path
+ return self.path_manager.get_local_path(path)
+
+
+def _resize_image(
+ image, image_height, image_width, mode="bilinear"
+) -> Tuple[torch.Tensor, float, torch.Tensor]:
+ if image_height is None or image_width is None:
+ # skip the resizing
+ imre_ = torch.from_numpy(image)
+ return imre_, 1.0, torch.ones_like(imre_[:1])
+ # takes numpy array, returns pytorch tensor
+ minscale = min(
+ image_height / image.shape[-2],
+ image_width / image.shape[-1],
+ )
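+ # e.g. (illustrative) a (3, 300, 400) image with image_height=150 and
+ # image_width=250 gives minscale = min(0.5, 0.625) = 0.5; the resized content
+ # is (3, 150, 200) and is zero-padded below to (3, 150, 250), together with a
+ # mask marking the valid region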
+ imre = torch.nn.functional.interpolate(
+ torch.from_numpy(image)[None],
+ scale_factor=minscale,
+ mode=mode,
+ align_corners=False if mode == "bilinear" else None,
+ recompute_scale_factor=True,
+ )[0]
+ imre_ = torch.zeros(image.shape[0], image_height, image_width)
+ imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
+ mask = torch.zeros(1, image_height, image_width)
+ mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
+ return imre_, minscale, mask
+
+
+def _load_image(path) -> np.ndarray:
+ with Image.open(path) as pil_im:
+ im = np.array(pil_im.convert("RGB"))
+ im = im.transpose((2, 0, 1))
+ im = im.astype(np.float32) / 255.0
+ return im
+
+
+def _load_mask(path) -> np.ndarray:
+ with Image.open(path) as pil_im:
+ mask = np.array(pil_im)
+ mask = mask.astype(np.float32) / 255.0
+ return mask[None] # fake feature channel
+
+
+def _get_bbox_from_mask(
+ mask, thr, decrease_quant: float = 0.05
+) -> Tuple[int, int, int, int]:
+ # bbox in xywh
+ masks_for_box = np.zeros_like(mask)
+ while masks_for_box.sum() <= 1.0:
+ masks_for_box = (mask > thr).astype(np.float32)
+ thr -= decrease_quant
+ if thr <= 0.0:
+ warnings.warn(
+ f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1
+ )
+
+ x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
+ y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
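+ # e.g. (illustrative) for a (1, 3, 4) mask whose only above-threshold pixels
+ # are at row 1, columns 1-2, this returns (x0, y0, w, h) == (1, 1, 2, 1)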
+
+ return x0, y0, x1 - x0, y1 - y0
+
+
+def _crop_around_box(tensor, bbox, impath: str = ""):
+ # bbox is xyxy, where the upper bound is corrected with +1
+ bbox = _clamp_box_to_image_bounds_and_round(
+ bbox,
+ image_size_hw=tensor.shape[-2:],
+ )
+ tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
+ assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
+ return tensor
+
+
+def _clamp_box_to_image_bounds_and_round(
+ bbox_xyxy: torch.Tensor,
+ image_size_hw: Tuple[int, int],
+) -> torch.LongTensor:
+ bbox_xyxy = bbox_xyxy.clone()
+ bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
+ bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
+ if not isinstance(bbox_xyxy, torch.LongTensor):
+ bbox_xyxy = bbox_xyxy.round().long()
+ return bbox_xyxy # pyre-ignore [7]
+
+
+def _get_clamp_bbox(
+ bbox: torch.Tensor,
+ box_crop_context: float = 0.0,
+ image_path: str = "",
+) -> torch.Tensor:
+ # box_crop_context: rate of expansion for bbox
+ # returns possibly expanded bbox xyxy as float
+
+ bbox = bbox.clone() # do not edit bbox in place
+
+ # increase box size
+ if box_crop_context > 0.0:
+ c = box_crop_context
+ bbox = bbox.float()
+ bbox[0] -= bbox[2] * c / 2
+ bbox[1] -= bbox[3] * c / 2
+ bbox[2] += bbox[2] * c
+ bbox[3] += bbox[3] * c
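+ # e.g. (illustrative) xywh (1, 1, 4, 5) with box_crop_context=2 becomes
+ # (-3, -4, 12, 15), i.e. xyxy (-3, -4, 9, 11) after the conversion below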
+
+ if (bbox[2:] <= 1.0).any():
+ raise ValueError(
+ f"squashed image {image_path}!! The bounding box contains no pixels."
+ )
+
+ bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
+ bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
+
+ return bbox_xyxy
+
+
+def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
+ wh = xyxy[2:] - xyxy[:2]
+ xywh = torch.cat([xyxy[:2], wh])
+ return xywh
+
+
+def _load_depth(path, scale_adjustment) -> np.ndarray:
+ if not path.lower().endswith(".png"):
+ raise ValueError('unsupported depth file name "%s"' % path)
+
+ d = _load_16big_png_depth(path) * scale_adjustment
+ d[~np.isfinite(d)] = 0.0
+ return d[None] # fake feature channel
+
+
+def _load_16big_png_depth(depth_png) -> np.ndarray:
+ with Image.open(depth_png) as depth_pil:
+ # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
+ # we cast it to uint16, then reinterpret as float16, then cast to float32
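+ # e.g. (illustrative) a depth of 1.5 stored as float16 has the bit pattern
+ # 0x3E00; the PNG stores that 16-bit value, and reinterpreting the uint16
+ # buffer as float16 below recovers 1.5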
+ depth = (
+ np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
+ .astype(np.float32)
+ .reshape((depth_pil.size[1], depth_pil.size[0]))
+ )
+ return depth
+
+
+def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
+ assert bbox is not None
+ assert np.prod(orig_res) > 1e-8
+ # average ratio of dimensions
+ rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
+ return bbox * rel_size
+
+
+def _load_1bit_png_mask(file: str) -> np.ndarray:
+ with Image.open(file) as pil_im:
+ mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
+ return mask
+
+
+def _load_depth_mask(path: str) -> np.ndarray:
+ if not path.lower().endswith(".png"):
+ raise ValueError('unsupported depth mask file name "%s"' % path)
+ m = _load_1bit_png_mask(path)
+ return m[None] # fake feature channel
+
+
+def _get_1d_bounds(arr) -> Tuple[int, int]:
+ nz = np.flatnonzero(arr)
+ return nz[0], nz[-1] + 1
+
+
+def _bbox_xywh_to_xyxy(
+ xywh: torch.Tensor, clamp_size: Optional[int] = None
+) -> torch.Tensor:
+ xyxy = xywh.clone()
+ if clamp_size is not None:
+ xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
+ xyxy[2:] += xyxy[:2]
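+ # e.g. (illustrative) xywh (1, 1, 4, 5) -> xyxy (1, 1, 5, 6)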
+ return xyxy
+
+
+def _safe_as_tensor(data, dtype):
+ return torch.tensor(data, dtype=dtype) if data is not None else None
+
+
+# NOTE this cache is per-worker; they are implemented as processes.
+# each batch is loaded and collated by a single worker;
+# since sequences tend to co-occur within batches, this is useful.
+@functools.lru_cache(maxsize=256)
+def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
+ pcl = IO().load_pointcloud(pcl_path)
+ if max_points > 0:
+ pcl = pcl.subsample(max_points)
+
+ return pcl
diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py
index 669f4e9b6..636630680 100644
--- a/pytorch3d/implicitron/dataset/json_index_dataset.py
+++ b/pytorch3d/implicitron/dataset/json_index_dataset.py
@@ -14,8 +14,8 @@
import random
import warnings
from collections import defaultdict
+from dataclasses import field
from itertools import islice
-from pathlib import Path
from typing import (
Any,
ClassVar,
@@ -30,20 +30,18 @@
Union,
)
-import numpy as np
import torch
-from PIL import Image
+
+from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset.blob_loader import BlobLoader
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
+from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
+
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
-from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
-from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
-from pytorch3d.structures.pointclouds import Pointclouds
+from pytorch3d.renderer.cameras import CamerasBase
from tqdm import tqdm
-from . import types
-from .dataset_base import DatasetBase, FrameData
-from .utils import is_known_frame_scalar
-
logger = logging.getLogger(__name__)
@@ -65,7 +63,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
A dataset with annotations in json files like the Common Objects in 3D
(CO3D) dataset.
- Args:
+ Metadata-related args::
frame_annotations_file: A zipped json file containing metadata of the
frames in the dataset, serialized List[types.FrameAnnotation].
sequence_annotations_file: A zipped json file containing metadata of the
@@ -83,6 +81,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
pick_sequence: A list of sequence names to restrict the dataset to.
exclude_sequence: A list of the names of the sequences to exclude.
limit_category_to: Restrict the dataset to the given list of categories.
+ remove_empty_masks: Removes the frames with no active foreground pixels
+ in the segmentation mask after thresholding (see box_crop_mask_thr).
+ n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
+ frames in each sequence uniformly without replacement if it has
+ more frames than that; applied before other frame-level filters.
+ seed: The seed of the random generator sampling #n_frames_per_sequence
+ random frames per sequence.
+ sort_frames: Enable sorting of frame annotations to group frames from
+ the same sequence together and order them by timestamp.
+ eval_batches: A list of batches that form the evaluation set;
+ list of batch-sized lists of indices corresponding to __getitem__
+ of this class, thus it can be used directly as a batch sampler.
+ eval_batch_index:
+ ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
+ A list of batches of frames described as (sequence_name, frame_idx)
+ that can form the evaluation set; `eval_batches` will be set from this.
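+ (e.g., illustrative: [[("seq1", 0), ("seq1", 10)], [("seq2", 5)]])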
+
+ Blob-loading parameters:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
@@ -109,23 +125,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
- remove_empty_masks: Removes the frames with no active foreground pixels
- in the segmentation mask after thresholding (see box_crop_mask_thr).
- n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
- frames in each sequences uniformly without replacement if it has
- more frames than that; applied before other frame-level filters.
- seed: The seed of the random generator sampling #n_frames_per_sequence
- random frames per sequence.
- sort_frames: Enable frame annotations sorting to group frames from the
- same sequences together and order them by timestamps
- eval_batches: A list of batches that form the evaluation set;
- list of batch-sized lists of indices corresponding to __getitem__
- of this class, thus it can be used directly as a batch sampler.
- eval_batch_index:
- ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
- A list of batches of frames described as (sequence_name, frame_idx)
- that can form the evaluation set, `eval_batches` will be set from this.
-
"""
frame_annotations_type: ClassVar[
@@ -162,12 +161,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
sort_frames: bool = False
eval_batches: Any = None
eval_batch_index: Any = None
- # frame_annots: List[FrameAnnotsEntry] = field(init=False)
- # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
+ subset_to_image_path: Any = None
+ # initialised in __post_init__
+ blob_loader: BlobLoader = field(init=False)
+ frame_annots: List[FrameAnnotsEntry] = field(init=False)
+ seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
+ _seq_to_idx: Dict[str, List[int]] = field(init=False)
def __post_init__(self) -> None:
- # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
- self.subset_to_image_path = None
self._load_frames()
self._load_sequences()
if self.sort_frames:
@@ -175,6 +176,23 @@ def __post_init__(self) -> None:
self._load_subset_lists()
self._filter_db() # also computes sequence indices
self._extract_and_set_eval_batches()
+
+ self.blob_loader = BlobLoader(
+ dataset_root=self.dataset_root,
+ load_images=self.load_images,
+ load_depths=self.load_depths,
+ load_depth_masks=self.load_depth_masks,
+ load_masks=self.load_masks,
+ load_point_clouds=self.load_point_clouds,
+ max_points=self.max_points,
+ mask_images=self.mask_images,
+ mask_depths=self.mask_depths,
+ image_height=self.image_height,
+ image_width=self.image_width,
+ box_crop=self.box_crop,
+ box_crop_mask_thr=self.box_crop_mask_thr,
+ box_crop_context=self.box_crop_context,
+ )
logger.info(str(self))
def _extract_and_set_eval_batches(self):
@@ -190,7 +208,8 @@ def _extract_and_set_eval_batches(self):
self.eval_batch_index
)
- def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+ # pyre-ignore
+ def join(self, other_datasets: Iterable["JsonIndexDataset"]) -> None:
"""
Join the dataset with other JsonIndexDataset objects.
@@ -200,19 +219,16 @@ def join(self, other_datasets: Iterable[DatasetBase]) -> None:
"""
if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
raise ValueError("This function can only join a list of JsonIndexDataset")
- # pyre-ignore[16]
self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
- # pyre-ignore[16]
self.seq_annots.update(
# https://gist.github.com/treyhunner/f35292e676efa0be1728
functools.reduce(
lambda a, b: {**a, **b},
- [d.seq_annots for d in other_datasets], # pyre-ignore[16]
+ [d.seq_annots for d in other_datasets],
)
)
all_eval_batches = [
self.eval_batches,
- # pyre-ignore
*[d.eval_batches for d in other_datasets],
]
if not (
@@ -251,7 +267,7 @@ def seq_frame_index_to_dataset_index(
allow_missing_indices: bool = False,
remove_missing_indices: bool = False,
suppress_missing_index_warning: bool = True,
- ) -> List[List[Union[Optional[int], int]]]:
+ ) -> Union[List[List[Optional[int]]], List[List[int]]]:
"""
Obtain indices into the dataset object given a list of frame ids.
@@ -279,11 +295,9 @@ def seq_frame_index_to_dataset_index(
"""
_dataset_seq_frame_n_index = {
seq: {
- # pyre-ignore[16]
self.frame_annots[idx]["frame_annotation"].frame_number: idx
for idx in seq_idx
}
- # pyre-ignore[16]
for seq, seq_idx in self._seq_to_idx.items()
}
@@ -306,7 +320,6 @@ def _get_dataset_idx(
# Check that the loaded frame path is consistent
# with the one stored in self.frame_annots.
assert os.path.normpath(
- # pyre-ignore[16]
self.frame_annots[idx]["frame_annotation"].image.path
) == os.path.normpath(
path
@@ -323,9 +336,7 @@ def _get_dataset_idx(
valid_dataset_idx = [
[b for b in batch if b is not None] for batch in dataset_idx
]
- return [ # pyre-ignore[7]
- batch for batch in valid_dataset_idx if len(batch) > 0
- ]
+ return [batch for batch in valid_dataset_idx if len(batch) > 0]
return dataset_idx
@@ -358,7 +369,7 @@ def subset_from_frame_index(
# Deep copy the whole dataset except frame_annots, which are large so we
# deep copy only the requested subset of frame_annots.
- memo = {id(self.frame_annots): None} # pyre-ignore[16]
+ memo = {id(self.frame_annots): None}
dataset_new = copy.deepcopy(self, memo)
dataset_new.frame_annots = copy.deepcopy(
[self.frame_annots[i] for i in valid_dataset_indices]
@@ -386,11 +397,9 @@ def subset_from_frame_index(
return dataset_new
def __str__(self) -> str:
- # pyre-ignore[16]
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
def __len__(self) -> int:
- # pyre-ignore[16]
return len(self.frame_annots)
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
@@ -402,7 +411,6 @@ def get_all_train_cameras(self) -> CamerasBase:
"""
logger.info("Loading all train cameras.")
cameras = []
- # pyre-ignore[16]
for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
frame_type = self._get_frame_type(frame_annot)
if frame_type is None:
@@ -412,12 +420,10 @@ def get_all_train_cameras(self) -> CamerasBase:
return join_cameras_as_batch(cameras)
def __getitem__(self, index) -> FrameData:
- # pyre-ignore[16]
if index >= len(self.frame_annots):
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
entry = self.frame_annots[index]["frame_annotation"]
- # pyre-ignore[16]
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
frame_data = FrameData(
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
@@ -435,236 +441,12 @@ def __getitem__(self, index) -> FrameData:
else None,
)
- # The rest of the fields are optional
+ # Optional field
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
-
- (
- frame_data.fg_probability,
- frame_data.mask_path,
- frame_data.bbox_xywh,
- clamp_bbox_xyxy,
- frame_data.crop_bbox_xywh,
- ) = self._load_crop_fg_probability(entry)
-
- scale = 1.0
- if self.load_images and entry.image is not None:
- # original image size
- frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
-
- (
- frame_data.image_rgb,
- frame_data.image_path,
- frame_data.mask_crop,
- scale,
- ) = self._load_crop_images(
- entry, frame_data.fg_probability, clamp_bbox_xyxy
- )
-
- if self.load_depths and entry.depth is not None:
- (
- frame_data.depth_map,
- frame_data.depth_path,
- frame_data.depth_mask,
- ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
-
- if entry.viewpoint is not None:
- frame_data.camera = self._get_pytorch3d_camera(
- entry,
- scale,
- clamp_bbox_xyxy,
- )
-
- if self.load_point_clouds and point_cloud is not None:
- pcl_path = self._fix_point_cloud_path(point_cloud.path)
- frame_data.sequence_point_cloud = _load_pointcloud(
- self._local_path(pcl_path), max_points=self.max_points
- )
- frame_data.sequence_point_cloud_path = pcl_path
-
- return frame_data
-
- def _fix_point_cloud_path(self, path: str) -> str:
- """
- Fix up a point cloud path from the dataset.
- Some files in Co3Dv2 have an accidental absolute path stored.
- """
- unwanted_prefix = (
- "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
- )
- if path.startswith(unwanted_prefix):
- path = path[len(unwanted_prefix) :]
- return os.path.join(self.dataset_root, path)
-
- def _load_crop_fg_probability(
- self, entry: types.FrameAnnotation
- ) -> Tuple[
- Optional[torch.Tensor],
- Optional[str],
- Optional[torch.Tensor],
- Optional[torch.Tensor],
- Optional[torch.Tensor],
- ]:
- fg_probability = None
- full_path = None
- bbox_xywh = None
- clamp_bbox_xyxy = None
- crop_box_xywh = None
-
- if (self.load_masks or self.box_crop) and entry.mask is not None:
- full_path = os.path.join(self.dataset_root, entry.mask.path)
- mask = _load_mask(self._local_path(full_path))
-
- if mask.shape[-2:] != entry.image.size:
- raise ValueError(
- f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
- )
-
- bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
-
- if self.box_crop:
- clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
- _get_clamp_bbox(
- bbox_xywh,
- image_path=entry.image.path,
- box_crop_context=self.box_crop_context,
- ),
- image_size_hw=tuple(mask.shape[-2:]),
- )
- crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
-
- mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
-
- fg_probability, _, _ = self._resize_image(mask, mode="nearest")
-
- return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
-
- def _load_crop_images(
- self,
- entry: types.FrameAnnotation,
- fg_probability: Optional[torch.Tensor],
- clamp_bbox_xyxy: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
- assert self.dataset_root is not None and entry.image is not None
- path = os.path.join(self.dataset_root, entry.image.path)
- image_rgb = _load_image(self._local_path(path))
-
- if image_rgb.shape[-2:] != entry.image.size:
- raise ValueError(
- f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
- )
-
- if self.box_crop:
- assert clamp_bbox_xyxy is not None
- image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
-
- image_rgb, scale, mask_crop = self._resize_image(image_rgb)
-
- if self.mask_images:
- assert fg_probability is not None
- image_rgb *= fg_probability
-
- return image_rgb, path, mask_crop, scale
-
- def _load_mask_depth(
- self,
- entry: types.FrameAnnotation,
- clamp_bbox_xyxy: Optional[torch.Tensor],
- fg_probability: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, str, torch.Tensor]:
- entry_depth = entry.depth
- assert entry_depth is not None
- path = os.path.join(self.dataset_root, entry_depth.path)
- depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
-
- if self.box_crop:
- assert clamp_bbox_xyxy is not None
- depth_bbox_xyxy = _rescale_bbox(
- clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
- )
- depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
-
- depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
-
- if self.mask_depths:
- assert fg_probability is not None
- depth_map *= fg_probability
-
- if self.load_depth_masks:
- assert entry_depth.mask_path is not None
- mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
- depth_mask = _load_depth_mask(self._local_path(mask_path))
-
- if self.box_crop:
- assert clamp_bbox_xyxy is not None
- depth_mask_bbox_xyxy = _rescale_bbox(
- clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
- )
- depth_mask = _crop_around_box(
- depth_mask, depth_mask_bbox_xyxy, mask_path
- )
-
- depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
- else:
- depth_mask = torch.ones_like(depth_map)
-
- return depth_map, path, depth_mask
-
- def _get_pytorch3d_camera(
- self,
- entry: types.FrameAnnotation,
- scale: float,
- clamp_bbox_xyxy: Optional[torch.Tensor],
- ) -> PerspectiveCameras:
- entry_viewpoint = entry.viewpoint
- assert entry_viewpoint is not None
- # principal point and focal length
- principal_point = torch.tensor(
- entry_viewpoint.principal_point, dtype=torch.float
- )
- focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
-
- half_image_size_wh_orig = (
- torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
- )
-
- # first, we convert from the dataset's NDC convention to pixels
- format = entry_viewpoint.intrinsics_format
- if format.lower() == "ndc_norm_image_bounds":
- # this is e.g. currently used in CO3D for storing intrinsics
- rescale = half_image_size_wh_orig
- elif format.lower() == "ndc_isotropic":
- rescale = half_image_size_wh_orig.min()
- else:
- raise ValueError(f"Unknown intrinsics format: {format}")
-
- # principal point and focal length in pixels
- principal_point_px = half_image_size_wh_orig - principal_point * rescale
- focal_length_px = focal_length * rescale
- if self.box_crop:
- assert clamp_bbox_xyxy is not None
- principal_point_px -= clamp_bbox_xyxy[:2]
-
- # now, convert from pixels to PyTorch3D v0.5+ NDC convention
- if self.image_height is None or self.image_width is None:
- out_size = list(reversed(entry.image.size))
- else:
- out_size = [self.image_width, self.image_height]
-
- half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
- half_min_image_size_output = half_image_size_output.min()
-
- # rescaled principal point and focal length in ndc
- principal_point = (
- half_image_size_output - principal_point_px * scale
- ) / half_min_image_size_output
- focal_length = focal_length_px * scale / half_min_image_size_output
-
- return PerspectiveCameras(
- focal_length=focal_length[None],
- principal_point=principal_point[None],
- R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
- T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
+ self.blob_loader.load_(
+ frame_data, entry, self.seq_annots[entry.sequence_name]
)
+ return frame_data
def _load_frames(self) -> None:
logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
@@ -675,7 +457,6 @@ def _load_frames(self) -> None:
)
if not frame_annots_list:
raise ValueError("Empty dataset!")
- # pyre-ignore[16]
self.frame_annots = [
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
]
@@ -687,7 +468,6 @@ def _load_sequences(self) -> None:
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
if not seq_annots:
raise ValueError("Empty sequences file!")
- # pyre-ignore[16]
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
def _load_subset_lists(self) -> None:
@@ -703,7 +483,6 @@ def _load_subset_lists(self) -> None:
for subset, frames in subset_to_seq_frame.items()
for _, _, path in frames
}
- # pyre-ignore[16]
for frame in self.frame_annots:
frame["subset"] = frame_path_to_subset.get(
frame["frame_annotation"].image.path, None
@@ -716,7 +495,6 @@ def _load_subset_lists(self) -> None:
def _sort_frames(self) -> None:
# Sort frames to have them grouped by sequence, ordered by timestamp
- # pyre-ignore[16]
self.frame_annots = sorted(
self.frame_annots,
key=lambda f: (
@@ -728,7 +506,6 @@ def _sort_frames(self) -> None:
def _filter_db(self) -> None:
if self.remove_empty_masks:
logger.info("Removing images with empty masks.")
- # pyre-ignore[16]
old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
@@ -769,7 +546,6 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
if len(self.limit_category_to) > 0:
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
- # pyre-ignore[16]
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
@@ -807,7 +583,6 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
if self.n_frames_per_sequence > 0:
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
- # pyre-ignore[16]
for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible
# and makes the selection differ for different sequences
@@ -837,51 +612,16 @@ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
self._invalidate_seq_to_idx()
if filter_seq_annots:
- # pyre-ignore[16]
self.seq_annots = {
- k: v
- for k, v in self.seq_annots.items()
- # pyre-ignore[16]
- if k in self._seq_to_idx
+ k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
}
def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list)
- # pyre-ignore[16]
for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
- # pyre-ignore[16]
self._seq_to_idx = seq_to_idx
- def _resize_image(
- self, image, mode="bilinear"
- ) -> Tuple[torch.Tensor, float, torch.Tensor]:
- image_height, image_width = self.image_height, self.image_width
- if image_height is None or image_width is None:
- # skip the resizing
- imre_ = torch.from_numpy(image)
- return imre_, 1.0, torch.ones_like(imre_[:1])
- # takes numpy array, returns pytorch tensor
- minscale = min(
- image_height / image.shape[-2],
- image_width / image.shape[-1],
- )
- imre = torch.nn.functional.interpolate(
- torch.from_numpy(image)[None],
- scale_factor=minscale,
- mode=mode,
- align_corners=False if mode == "bilinear" else None,
- recompute_scale_factor=True,
- )[0]
- # pyre-fixme[19]: Expected 1 positional argument.
- imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
- imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
- # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
- # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
- mask = torch.zeros(1, self.image_height, self.image_width)
- mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
- return imre_, minscale, mask
-
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
@@ -894,7 +634,6 @@ def get_frame_numbers_and_timestamps(
for idx in idxs:
if (
subset_filter is not None
- # pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`.
and self.frame_annots[idx]["subset"] not in subset_filter
):
continue
@@ -907,7 +646,6 @@ def get_frame_numbers_and_timestamps(
def category_to_sequence_names(self) -> Dict[str, List[str]]:
c2seq = defaultdict(list)
- # pyre-ignore
for sequence_name, sa in self.seq_annots.items():
c2seq[sa.category].append(sequence_name)
return dict(c2seq)
@@ -920,167 +658,5 @@ def _seq_name_to_seed(seq_name) -> int:
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
-def _load_image(path) -> np.ndarray:
- with Image.open(path) as pil_im:
- im = np.array(pil_im.convert("RGB"))
- im = im.transpose((2, 0, 1))
- im = im.astype(np.float32) / 255.0
- return im
-
-
-def _load_16big_png_depth(depth_png) -> np.ndarray:
- with Image.open(depth_png) as depth_pil:
- # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
- # we cast it to uint16, then reinterpret as float16, then cast to float32
- depth = (
- np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
- .astype(np.float32)
- .reshape((depth_pil.size[1], depth_pil.size[0]))
- )
- return depth
-
-
-def _load_1bit_png_mask(file: str) -> np.ndarray:
- with Image.open(file) as pil_im:
- mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
- return mask
-
-
-def _load_depth_mask(path: str) -> np.ndarray:
- if not path.lower().endswith(".png"):
- raise ValueError('unsupported depth mask file name "%s"' % path)
- m = _load_1bit_png_mask(path)
- return m[None] # fake feature channel
-
-
-def _load_depth(path, scale_adjustment) -> np.ndarray:
- if not path.lower().endswith(".png"):
- raise ValueError('unsupported depth file name "%s"' % path)
-
- d = _load_16big_png_depth(path) * scale_adjustment
- d[~np.isfinite(d)] = 0.0
- return d[None] # fake feature channel
-
-
-def _load_mask(path) -> np.ndarray:
- with Image.open(path) as pil_im:
- mask = np.array(pil_im)
- mask = mask.astype(np.float32) / 255.0
- return mask[None] # fake feature channel
-
-
-def _get_1d_bounds(arr) -> Tuple[int, int]:
- nz = np.flatnonzero(arr)
- return nz[0], nz[-1] + 1
-
-
-def _get_bbox_from_mask(
- mask, thr, decrease_quant: float = 0.05
-) -> Tuple[int, int, int, int]:
- # bbox in xywh
- masks_for_box = np.zeros_like(mask)
- while masks_for_box.sum() <= 1.0:
- masks_for_box = (mask > thr).astype(np.float32)
- thr -= decrease_quant
- if thr <= 0.0:
- warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
-
- x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
- y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
-
- return x0, y0, x1 - x0, y1 - y0
-
-
-def _get_clamp_bbox(
- bbox: torch.Tensor,
- box_crop_context: float = 0.0,
- image_path: str = "",
-) -> torch.Tensor:
- # box_crop_context: rate of expansion for bbox
- # returns possibly expanded bbox xyxy as float
-
- bbox = bbox.clone() # do not edit bbox in place
-
- # increase box size
- if box_crop_context > 0.0:
- c = box_crop_context
- bbox = bbox.float()
- bbox[0] -= bbox[2] * c / 2
- bbox[1] -= bbox[3] * c / 2
- bbox[2] += bbox[2] * c
- bbox[3] += bbox[3] * c
-
- if (bbox[2:] <= 1.0).any():
- raise ValueError(
- f"squashed image {image_path}!! The bounding box contains no pixels."
- )
-
- bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
- bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
-
- return bbox_xyxy
-
-
-def _crop_around_box(tensor, bbox, impath: str = ""):
- # bbox is xyxy, where the upper bound is corrected with +1
- bbox = _clamp_box_to_image_bounds_and_round(
- bbox,
- image_size_hw=tensor.shape[-2:],
- )
- tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
- assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
- return tensor
-
-
-def _clamp_box_to_image_bounds_and_round(
- bbox_xyxy: torch.Tensor,
- image_size_hw: Tuple[int, int],
-) -> torch.LongTensor:
- bbox_xyxy = bbox_xyxy.clone()
- bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
- bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
- if not isinstance(bbox_xyxy, torch.LongTensor):
- bbox_xyxy = bbox_xyxy.round().long()
- return bbox_xyxy # pyre-ignore [7]
-
-
-def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
- assert bbox is not None
- assert np.prod(orig_res) > 1e-8
- # average ratio of dimensions
- rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
- return bbox * rel_size
-
-
-def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
- wh = xyxy[2:] - xyxy[:2]
- xywh = torch.cat([xyxy[:2], wh])
- return xywh
-
-
-def _bbox_xywh_to_xyxy(
- xywh: torch.Tensor, clamp_size: Optional[int] = None
-) -> torch.Tensor:
- xyxy = xywh.clone()
- if clamp_size is not None:
- xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
- xyxy[2:] += xyxy[:2]
- return xyxy
-
-
def _safe_as_tensor(data, dtype):
- if data is None:
- return None
- return torch.tensor(data, dtype=dtype)
-
-
-# NOTE this cache is per-worker; they are implemented as processes.
-# each batch is loaded and collated by a single worker;
-# since sequences tend to co-occur within batches, this is useful.
-@functools.lru_cache(maxsize=256)
-def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
- pcl = IO().load_pointcloud(pcl_path)
- if max_points > 0:
- pcl = pcl.subsample(max_points)
-
- return pcl
+ return torch.tensor(data, dtype=dtype) if data is not None else None
diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py
index 6d0be0362..284e903a0 100644
--- a/pytorch3d/implicitron/dataset/visualize.py
+++ b/pytorch3d/implicitron/dataset/visualize.py
@@ -44,7 +44,6 @@ def get_implicitron_sequence_pointcloud(
sequence_entries = [
ei
for ei in sequence_entries
- # pyre-ignore[16]
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
== sequence_name
]
diff --git a/tests/implicitron/test_bbox.py b/tests/implicitron/test_bbox.py
index 999dfc924..48a8421bb 100644
--- a/tests/implicitron/test_bbox.py
+++ b/tests/implicitron/test_bbox.py
@@ -9,11 +9,18 @@
import numpy as np
import torch
-from pytorch3d.implicitron.dataset.json_index_dataset import (
+from pytorch3d.implicitron.dataset.blob_loader import (
_bbox_xywh_to_xyxy,
_bbox_xyxy_to_xywh,
+ _clamp_box_to_image_bounds_and_round,
+ _crop_around_box,
+ _get_1d_bounds,
_get_bbox_from_mask,
+ _get_clamp_bbox,
+ _rescale_bbox,
+ _resize_image,
)
+
from tests.common_testing import TestCaseMixin
@@ -76,3 +83,59 @@ def test_mask_to_bbox(self):
expected_bbox_xywh = [2, 1, 2, 1]
bbox_xywh = _get_bbox_from_mask(mask, 0.5)
self.assertClose(bbox_xywh, expected_bbox_xywh)
+
+ def test_crop_around_box(self):
+ bbox = torch.LongTensor([0, 1, 2, 3]) # (x_min, y_min, x_max, y_max)
+ image = torch.LongTensor(
+ [
+ [0, 0, 10, 20],
+ [10, 20, 5, 1],
+ [10, 20, 1, 1],
+ [5, 4, 0, 1],
+ ]
+ )
+ cropped = _crop_around_box(image, bbox)
+ self.assertClose(cropped, image[1:3, 0:2])
+
+ def test_clamp_box_to_image_bounds_and_round(self):
+ bbox = torch.LongTensor([0, 1, 10, 12])
+ image_size = (5, 6)
+ expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]])
+ clamped_bbox = _clamp_box_to_image_bounds_and_round(bbox, image_size)
+ self.assertClose(clamped_bbox, expected_clamped_bbox)
+
+ def test_get_clamp_bbox(self):
+ bbox_xywh = torch.LongTensor([1, 1, 4, 5])
+ clamped_bbox_xyxy = _get_clamp_bbox(bbox_xywh, box_crop_context=2)
+ # the box is expanded by box_crop_context times its size and converted to xyxy
+ self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11]))
+
+ def test_rescale_bbox(self):
+ bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0])
+ original_resolution = (4, 4)
+ new_resolution = (8, 8)  # twice as big
+ rescaled_bbox = _rescale_bbox(bbox, original_resolution, new_resolution)
+ self.assertClose(bbox * 2, rescaled_bbox)
+
+ def test_get_1d_bounds(self):
+ array = [0, 1, 2]
+ bounds = _get_1d_bounds(array)
+ # bounds of the nonzero region: first nonzero index, last nonzero index + 1
+ self.assertClose(bounds, [1, 3])
+
+ def test_resize_image(self):
+ image = np.random.rand(3, 300, 500) # rgb image 300x500
+ expected_shape = (150, 250)
+
+ resized_image, scale, mask_crop = _resize_image(
+ image, image_height=expected_shape[0], image_width=expected_shape[1]
+ )
+
+ original_shape = image.shape[-2:]
+ expected_scale = min(
+ expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1]
+ )
+
+ self.assertEqual(scale, expected_scale)
+ self.assertEqual(resized_image.shape[-2:], expected_shape)
+ self.assertEqual(mask_crop.shape[-2:], expected_shape)
diff --git a/tests/implicitron/test_blob_loader.py b/tests/implicitron/test_blob_loader.py
new file mode 100644
index 000000000..fd8d8fd81
--- /dev/null
+++ b/tests/implicitron/test_blob_loader.py
@@ -0,0 +1,167 @@
+import contextlib
+import gzip
+import os
+import unittest
+from typing import List
+
+import numpy as np
+import torch
+
+from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset.blob_loader import (
+ _load_16big_png_depth,
+ _load_1bit_png_mask,
+ _load_depth,
+ _load_depth_mask,
+ _load_image,
+ _load_mask,
+ BlobLoader,
+)
+from pytorch3d.implicitron.tools.config import get_default_args
+from pytorch3d.renderer.cameras import PerspectiveCameras
+
+from tests.common_testing import TestCaseMixin
+from tests.implicitron.common_resources import get_skateboard_data
+
+
+class TestBlobLoader(TestCaseMixin, unittest.TestCase):
+ def setUp(self):
+ torch.manual_seed(42)
+
+ category = "skateboard"
+ stack = contextlib.ExitStack()
+ self.dataset_root, self.path_manager = stack.enter_context(
+ get_skateboard_data()
+ )
+ self.addCleanup(stack.close)
+ self.image_height = 768
+ self.image_width = 512
+
+ self.blob_loader = BlobLoader(
+ image_height=self.image_height,
+ image_width=self.image_width,
+ dataset_root=self.dataset_root,
+ path_manager=self.path_manager,
+ )
+
+ # load a single frame annotation from the dataset (see JsonIndexDataset._load_frames())
+ frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
+ local_file = self.path_manager.get_local_path(frame_file)
+ with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
+ frame_annots_list = types.load_dataclass(
+ zipfile, List[types.FrameAnnotation]
+ )
+ self.frame_annotation = frame_annots_list[0]
+
+ def test_BlobLoader_args(self):
+ # test that BlobLoader works with get_default_args
+ get_default_args(BlobLoader)
+
+ def test_fix_point_cloud_path(self):
+ """Some files in Co3Dv2 have an accidental absolute path stored."""
+ original_path = "some_file_path"
+ modified_path = self.blob_loader._fix_point_cloud_path(original_path)
+ assert original_path in modified_path
+ assert self.blob_loader.dataset_root in modified_path
+
+ def test_load_(self):
+ (
+ fg_probability,
+ mask_path,
+ bbox_xywh,
+ clamp_bbox_xyxy,
+ crop_bbox_xywh,
+ ) = self.blob_loader._load_crop_fg_probability(self.frame_annotation)
+
+ assert mask_path
+ assert torch.is_tensor(fg_probability)
+ assert torch.is_tensor(bbox_xywh)
+ assert torch.is_tensor(clamp_bbox_xyxy)
+ assert torch.is_tensor(crop_bbox_xywh)
+ # assert bboxes shape
+ self.assertEqual(
+ fg_probability.shape, torch.Size([1, self.image_height, self.image_width])
+ )
+ self.assertEqual(bbox_xywh.shape, torch.Size([4]))
+ self.assertEqual(clamp_bbox_xyxy.shape, torch.Size([4]))
+ self.assertEqual(crop_bbox_xywh.shape, torch.Size([4]))
+ (image_rgb, image_path, mask_crop, scale,) = self.blob_loader._load_crop_images(
+ self.frame_annotation, fg_probability, clamp_bbox_xyxy
+ )
+ assert torch.is_tensor(image_rgb)
+ assert image_path
+ assert torch.is_tensor(mask_crop)
+ assert scale
+ # assert image and mask shapes
+ self.assertEqual(
+ image_rgb.shape, torch.Size([3, self.image_height, self.image_width])
+ )
+ self.assertEqual(
+ mask_crop.shape, torch.Size([1, self.image_height, self.image_width])
+ )
+
+ (depth_map, depth_path, depth_mask,) = self.blob_loader._load_mask_depth(
+ self.frame_annotation,
+ clamp_bbox_xyxy,
+ fg_probability,
+ )
+ assert torch.is_tensor(depth_map)
+ assert depth_path
+ assert torch.is_tensor(depth_mask)
+ # assert image and mask shapes
+ self.assertEqual(
+ depth_map.shape, torch.Size([1, self.image_height, self.image_width])
+ )
+ self.assertEqual(
+ depth_mask.shape, torch.Size([1, self.image_height, self.image_width])
+ )
+
+ camera = self.blob_loader._get_pytorch3d_camera(
+ self.frame_annotation,
+ scale,
+ clamp_bbox_xyxy,
+ )
+ self.assertEqual(type(camera), PerspectiveCameras)
+
+ def test_load_image(self):
+ path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
+ local_path = self.path_manager.get_local_path(path)
+ image = _load_image(local_path)
+ self.assertEqual(image.dtype, np.float32)
+ assert np.max(image) <= 1.0
+ assert np.min(image) >= 0.0
+
+ def test_load_mask(self):
+ path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
+ mask = _load_mask(path)
+ self.assertEqual(mask.dtype, np.float32)
+ assert np.max(mask) <= 1.0
+ assert np.min(mask) >= 0.0
+
+ def test_load_depth(self):
+ path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
+ depth_map = _load_depth(path, self.frame_annotation.depth.scale_adjustment)
+ self.assertEqual(depth_map.dtype, np.float32)
+ self.assertEqual(len(depth_map.shape), 3)
+
+ def test_load_16big_png_depth(self):
+ path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
+ depth_map = _load_16big_png_depth(path)
+ self.assertEqual(depth_map.dtype, np.float32)
+ self.assertEqual(len(depth_map.shape), 2)
+
+ def test_load_1bit_png_mask(self):
+ mask_path = os.path.join(
+ self.dataset_root, self.frame_annotation.depth.mask_path
+ )
+ mask = _load_1bit_png_mask(mask_path)
+ self.assertEqual(mask.dtype, np.float32)
+ self.assertEqual(len(mask.shape), 2)
+
+ def test_load_depth_mask(self):
+ mask_path = os.path.join(
+ self.dataset_root, self.frame_annotation.depth.mask_path
+ )
+ mask = _load_depth_mask(mask_path)
+ self.assertEqual(mask.dtype, np.float32)
+ self.assertEqual(len(mask.shape), 3)