Skip to content

Commit 47e3607

Browse files
Pau Gargallofacebook-github-bot
authored andcommitted
pyre strict in dataset and io
Differential Revision: D73199680 fbshipit-source-id: 39e745f4669fe1c84fda951ec879160f80eb771c
1 parent 51dad31 commit 47e3607

2 files changed

Lines changed: 110 additions & 111 deletions

File tree

opensfm/dataset.py

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# pyre-unsafe
1+
# pyre-strict
22
import gzip
33
import json
44
import logging
@@ -8,6 +8,7 @@
88
from typing import Any, Dict, IO, List, Optional, Tuple
99

1010
import numpy as np
11+
from numpy.typing import NDArray
1112
from opensfm import config, features, geo, io, masking, pygeometry, pymap, rig, types
1213
from opensfm.dataset_base import DataSetBase
1314
from PIL.PngImagePlugin import PngImageFile
@@ -30,14 +31,17 @@ class DataSet(DataSetBase):
3031
"""
3132

3233
io_handler: io.IoFilesystemBase = io.IoFilesystemDefault()
33-
config = None
34+
config: Dict[str, Any] = {}
3435
image_files: Dict[str, str] = {}
3536
mask_files: Dict[str, str] = {}
3637
image_list: List[str] = []
3738

38-
def __init__(self, data_path: str, io_handler=io.IoFilesystemDefault) -> None:
39+
def __init__(
40+
self, data_path: str, io_handler: Optional[io.IoFilesystemBase] = None
41+
) -> None:
3942
"""Init dataset associated to a folder."""
40-
self.io_handler = io_handler
43+
if io_handler is not None:
44+
self.io_handler = io_handler
4145
self.data_path = data_path
4246
self.load_config()
4347
self.load_image_list()
@@ -80,7 +84,7 @@ def _image_file(self, image: str) -> str:
8084
"""Path to the image file."""
8185
return self.image_files[image]
8286

83-
def open_image_file(self, image: str) -> IO[Any]:
87+
def open_image_file(self, image: str) -> IO[bytes]:
8488
"""Open image file and return file object."""
8589
return self.io_handler.open_rb(self._image_file(image))
8690

@@ -90,7 +94,7 @@ def load_image(
9094
unchanged: bool = False,
9195
anydepth: bool = False,
9296
grayscale: bool = False,
93-
) -> np.ndarray:
97+
) -> NDArray:
9498
"""Load image pixels as numpy array.
9599
96100
The array is 3D, indexed by y-coord, x-coord, channel.
@@ -117,7 +121,7 @@ def load_mask_list(self) -> None:
117121
else:
118122
self._set_mask_path(os.path.join(self.data_path, "masks"))
119123

120-
def load_mask(self, image: str) -> Optional[np.ndarray]:
124+
def load_mask(self, image: str) -> Optional[NDArray]:
121125
"""Load image mask if it exists, otherwise return None."""
122126
if image in self.mask_files:
123127
mask_path = self.mask_files[image]
@@ -138,7 +142,7 @@ def _instances_path(self) -> str:
138142
def _instances_file(self, image: str) -> str:
139143
return os.path.join(self._instances_path(), image + ".png")
140144

141-
def load_instances(self, image: str) -> Optional[np.ndarray]:
145+
def load_instances(self, image: str) -> Optional[NDArray]:
142146
"""Load image instances file if it exists, otherwise return None."""
143147
instances_file = self._instances_file(image)
144148
if self.io_handler.isfile(instances_file):
@@ -153,10 +157,10 @@ def _segmentation_path(self) -> str:
153157
def _segmentation_file(self, image: str) -> str:
154158
return os.path.join(self._segmentation_path(), image + ".png")
155159

156-
def segmentation_labels(self) -> List[Any]:
160+
def segmentation_labels(self) -> List[Dict[str, Any]]:
157161
return []
158162

159-
def load_segmentation(self, image: str) -> Optional[np.ndarray]:
163+
def load_segmentation(self, image: str) -> Optional[NDArray]:
160164
"""Load image segmentation if it exists, otherwise return None."""
161165
segmentation_file = self._segmentation_file(image)
162166
if self.io_handler.isfile(segmentation_file):
@@ -318,12 +322,12 @@ def _words_file(self, image: str) -> str:
318322
def words_exist(self, image: str) -> bool:
319323
return self.io_handler.isfile(self._words_file(image))
320324

321-
def load_words(self, image: str) -> np.ndarray:
325+
def load_words(self, image: str) -> NDArray:
322326
with self.io_handler.open_rb(self._words_file(image)) as f:
323327
s = np.load(f)
324328
return s["words"].astype(np.int32)
325329

326-
def save_words(self, image: str, words: np.ndarray) -> None:
330+
def save_words(self, image: str, words: NDArray) -> None:
327331
with self.io_handler.open_wb(self._words_file(image)) as f:
328332
np.savez_compressed(f, words=words.astype(np.uint16))
329333

@@ -338,7 +342,7 @@ def _matches_file(self, image: str) -> str:
338342
def matches_exists(self, image: str) -> bool:
339343
return self.io_handler.isfile(self._matches_file(image))
340344

341-
def load_matches(self, image: str) -> Dict[str, np.ndarray]:
345+
def load_matches(self, image: str) -> Dict[str, NDArray]:
342346
# Prevent pickling of anything except what we strictly need
343347
# as 'pickle.load' is RCE-prone. Will raise on any class other
344348
# than the numpy ones we allow.
@@ -363,7 +367,7 @@ def find_class(self, module, name):
363367
matches = MatchingUnpickler(BytesIO(gzip.decompress(fin.read()))).load()
364368
return matches
365369

366-
def save_matches(self, image: str, matches: Dict[str, np.ndarray]) -> None:
370+
def save_matches(self, image: str, matches: Dict[str, NDArray]) -> None:
367371
self.io_handler.mkdir_p(self._matches_path())
368372

369373
with BytesIO() as buffer:
@@ -372,7 +376,7 @@ def save_matches(self, image: str, matches: Dict[str, np.ndarray]) -> None:
372376
with self.io_handler.open_wb(self._matches_file(image)) as fw:
373377
fw.write(buffer.getvalue())
374378

375-
def find_matches(self, im1: str, im2: str) -> np.ndarray:
379+
def find_matches(self, im1: str, im2: str) -> NDArray:
376380
if self.matches_exists(im1):
377381
im1_matches = self.load_matches(im1)
378382
if im2 in im1_matches:
@@ -422,7 +426,7 @@ def save_reconstruction(
422426
self,
423427
reconstruction: List[types.Reconstruction],
424428
filename: Optional[str] = None,
425-
minify=False,
429+
minify: bool = False,
426430
) -> None:
427431
with self.io_handler.open_wt(self._reconstruction_file(filename)) as fout:
428432
io.json_dump(io.reconstructions_to_json(reconstruction), fout, minify)
@@ -628,11 +632,11 @@ def save_ground_control_points(
628632
with self.io_handler.open_wt(self._ground_control_points_file()) as fout:
629633
io.write_ground_control_points(points, fout)
630634

631-
def image_as_array(self, image: str) -> np.ndarray:
635+
def image_as_array(self, image: str) -> NDArray:
632636
logger.warning("image_as_array() is deprecated. Use load_image() instead.")
633637
return self.load_image(image)
634638

635-
def mask_as_array(self, image: str) -> Optional[np.ndarray]:
639+
def mask_as_array(self, image: str) -> Optional[NDArray]:
636640
logger.warning("mask_as_array() is deprecated. Use load_mask() instead.")
637641
return self.load_mask(image)
638642

@@ -707,18 +711,20 @@ class UndistortedDataSet:
707711
base: DataSetBase
708712
config: Dict[str, Any] = {}
709713
data_path: str
714+
io_handler: io.IoFilesystemBase = io.IoFilesystemDefault()
710715

711716
def __init__(
712717
self,
713718
base_dataset: DataSetBase,
714719
undistorted_data_path: str,
715-
io_handler=io.IoFilesystemDefault,
720+
io_handler: Optional[io.IoFilesystemBase] = None,
716721
) -> None:
717722
"""Init dataset associated to a folder."""
718723
self.base = base_dataset
719724
self.config = self.base.config
720725
self.data_path = undistorted_data_path
721-
self.io_handler = io_handler
726+
if io_handler is not None:
727+
self.io_handler = io_handler
722728

723729
def load_undistorted_shot_ids(self) -> Dict[str, List[str]]:
724730
filename = os.path.join(self.data_path, "undistorted_shot_ids.json")
@@ -738,11 +744,11 @@ def _undistorted_image_file(self, image: str) -> str:
738744
"""Path of undistorted version of an image."""
739745
return os.path.join(self._undistorted_image_path(), image)
740746

741-
def load_undistorted_image(self, image: str) -> np.ndarray:
747+
def load_undistorted_image(self, image: str) -> NDArray:
742748
"""Load undistorted image pixels as a numpy array."""
743749
return self.io_handler.imread(self._undistorted_image_file(image))
744750

745-
def save_undistorted_image(self, image: str, array: np.ndarray) -> None:
751+
def save_undistorted_image(self, image: str, array: NDArray) -> None:
746752
"""Save undistorted image pixels."""
747753
self.io_handler.mkdir_p(self._undistorted_image_path())
748754
self.io_handler.imwrite(self._undistorted_image_file(image), array)
@@ -762,13 +768,13 @@ def undistorted_mask_exists(self, image: str) -> bool:
762768
"""Check if the undistorted mask file exists."""
763769
return self.io_handler.isfile(self._undistorted_mask_file(image))
764770

765-
def load_undistorted_mask(self, image: str) -> np.ndarray:
771+
def load_undistorted_mask(self, image: str) -> NDArray:
766772
"""Load undistorted mask pixels as a numpy array."""
767773
return self.io_handler.imread(
768774
self._undistorted_mask_file(image), grayscale=True
769775
)
770776

771-
def save_undistorted_mask(self, image: str, array: np.ndarray) -> None:
777+
def save_undistorted_mask(self, image: str, array: NDArray) -> None:
772778
"""Save the undistorted image mask."""
773779
self.io_handler.mkdir_p(self._undistorted_mask_path())
774780
self.io_handler.imwrite(self._undistorted_mask_file(image), array)
@@ -784,7 +790,7 @@ def undistorted_segmentation_exists(self, image: str) -> bool:
784790
"""Check if the undistorted segmentation file exists."""
785791
return self.io_handler.isfile(self._undistorted_segmentation_file(image))
786792

787-
def load_undistorted_segmentation(self, image: str) -> np.ndarray:
793+
def load_undistorted_segmentation(self, image: str) -> NDArray:
788794
"""Load an undistorted image segmentation."""
789795
segmentation_file = self._undistorted_segmentation_file(image)
790796
with self.io_handler.open_rb(segmentation_file) as fp:
@@ -804,12 +810,12 @@ def load_undistorted_segmentation(self, image: str) -> np.ndarray:
804810
else:
805811
raise IndexError
806812

807-
def save_undistorted_segmentation(self, image: str, array: np.ndarray) -> None:
813+
def save_undistorted_segmentation(self, image: str, array: NDArray) -> None:
808814
"""Save the undistorted image segmentation."""
809815
self.io_handler.mkdir_p(self._undistorted_segmentation_path())
810816
self.io_handler.imwrite(self._undistorted_segmentation_file(image), array)
811817

812-
def load_undistorted_segmentation_mask(self, image: str) -> Optional[np.ndarray]:
818+
def load_undistorted_segmentation_mask(self, image: str) -> Optional[NDArray]:
813819
"""Build a mask from the undistorted segmentation.
814820
815821
The mask is non-zero only for pixels with segmentation
@@ -828,7 +834,7 @@ def load_undistorted_segmentation_mask(self, image: str) -> Optional[np.ndarray]
828834

829835
return masking.mask_from_segmentation(segmentation, ignore_values)
830836

831-
def load_undistorted_combined_mask(self, image: str) -> Optional[np.ndarray]:
837+
def load_undistorted_combined_mask(self, image: str) -> Optional[NDArray]:
832838
"""Combine undistorted binary mask with segmentation mask.
833839
834840
Return a mask that is non-zero only where the binary
@@ -854,16 +860,16 @@ def point_cloud_file(self, filename: str = "merged.ply") -> str:
854860

855861
def load_point_cloud(
856862
self, filename: str = "merged.ply"
857-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
863+
) -> Tuple[NDArray, NDArray, NDArray, NDArray]:
858864
with self.io_handler.open_rt(self.point_cloud_file(filename)) as fp:
859865
return io.point_cloud_from_ply(fp)
860866

861867
def save_point_cloud(
862868
self,
863-
points: np.ndarray,
864-
normals: np.ndarray,
865-
colors: np.ndarray,
866-
labels: np.ndarray,
869+
points: NDArray,
870+
normals: NDArray,
871+
colors: NDArray,
872+
labels: NDArray,
867873
filename: str = "merged.ply",
868874
) -> None:
869875
self.io_handler.mkdir_p(self._depthmap_path())
@@ -876,11 +882,11 @@ def raw_depthmap_exists(self, image: str) -> bool:
876882
def save_raw_depthmap(
877883
self,
878884
image: str,
879-
depth: np.ndarray,
880-
plane: np.ndarray,
881-
score: np.ndarray,
882-
nghbr: np.ndarray,
883-
nghbrs: np.ndarray,
885+
depth: NDArray,
886+
plane: NDArray,
887+
score: NDArray,
888+
nghbr: NDArray,
889+
nghbrs: NDArray,
884890
) -> None:
885891
self.io_handler.mkdir_p(self._depthmap_path())
886892
filepath = self.depthmap_file(image, "raw.npz")
@@ -891,7 +897,7 @@ def save_raw_depthmap(
891897

892898
def load_raw_depthmap(
893899
self, image: str
894-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
900+
) -> Tuple[NDArray, NDArray, NDArray, NDArray, NDArray]:
895901
with self.io_handler.open_rb(self.depthmap_file(image, "raw.npz")) as f:
896902
o = np.load(f)
897903
return o["depth"], o["plane"], o["score"], o["nghbr"], o["nghbrs"]
@@ -900,16 +906,14 @@ def clean_depthmap_exists(self, image: str) -> bool:
900906
return self.io_handler.isfile(self.depthmap_file(image, "clean.npz"))
901907

902908
def save_clean_depthmap(
903-
self, image: str, depth: np.ndarray, plane: np.ndarray, score: np.ndarray
909+
self, image: str, depth: NDArray, plane: NDArray, score: NDArray
904910
) -> None:
905911
self.io_handler.mkdir_p(self._depthmap_path())
906912
filepath = self.depthmap_file(image, "clean.npz")
907913
with self.io_handler.open_wb(filepath) as f:
908914
np.savez_compressed(f, depth=depth, plane=plane, score=score)
909915

910-
def load_clean_depthmap(
911-
self, image: str
912-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
916+
def load_clean_depthmap(self, image: str) -> Tuple[NDArray, NDArray, NDArray]:
913917
with self.io_handler.open_rb(self.depthmap_file(image, "clean.npz")) as f:
914918
o = np.load(f)
915919
return o["depth"], o["plane"], o["score"]
@@ -920,10 +924,10 @@ def pruned_depthmap_exists(self, image: str) -> bool:
920924
def save_pruned_depthmap(
921925
self,
922926
image: str,
923-
points: np.ndarray,
924-
normals: np.ndarray,
925-
colors: np.ndarray,
926-
labels: np.ndarray,
927+
points: NDArray,
928+
normals: NDArray,
929+
colors: NDArray,
930+
labels: NDArray,
927931
) -> None:
928932
self.io_handler.mkdir_p(self._depthmap_path())
929933
filepath = self.depthmap_file(image, "pruned.npz")
@@ -938,7 +942,7 @@ def save_pruned_depthmap(
938942

939943
def load_pruned_depthmap(
940944
self, image: str
941-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
945+
) -> Tuple[NDArray, NDArray, NDArray, NDArray]:
942946
with self.io_handler.open_rb(self.depthmap_file(image, "pruned.npz")) as f:
943947
o = np.load(f)
944948
return (

0 commit comments

Comments
 (0)