1- # pyre-unsafe
1+ # pyre-strict
22import gzip
33import json
44import logging
88from typing import Any , Dict , IO , List , Optional , Tuple
99
1010import numpy as np
11+ from numpy .typing import NDArray
1112from opensfm import config , features , geo , io , masking , pygeometry , pymap , rig , types
1213from opensfm .dataset_base import DataSetBase
1314from PIL .PngImagePlugin import PngImageFile
@@ -30,14 +31,17 @@ class DataSet(DataSetBase):
3031 """
3132
3233 io_handler : io .IoFilesystemBase = io .IoFilesystemDefault ()
33- config = None
34+ config : Dict [ str , Any ] = {}
3435 image_files : Dict [str , str ] = {}
3536 mask_files : Dict [str , str ] = {}
3637 image_list : List [str ] = []
3738
38- def __init__ (self , data_path : str , io_handler = io .IoFilesystemDefault ) -> None :
39+ def __init__ (
40+ self , data_path : str , io_handler : Optional [io .IoFilesystemBase ] = None
41+ ) -> None :
3942 """Init dataset associated to a folder."""
40- self .io_handler = io_handler
43+ if io_handler is not None :
44+ self .io_handler = io_handler
4145 self .data_path = data_path
4246 self .load_config ()
4347 self .load_image_list ()
@@ -80,7 +84,7 @@ def _image_file(self, image: str) -> str:
8084 """Path to the image file."""
8185 return self .image_files [image ]
8286
83- def open_image_file (self , image : str ) -> IO [Any ]:
87+ def open_image_file (self , image : str ) -> IO [bytes ]:
8488 """Open image file and return file object."""
8589 return self .io_handler .open_rb (self ._image_file (image ))
8690
@@ -90,7 +94,7 @@ def load_image(
9094 unchanged : bool = False ,
9195 anydepth : bool = False ,
9296 grayscale : bool = False ,
93- ) -> np . ndarray :
97+ ) -> NDArray :
9498 """Load image pixels as numpy array.
9599
96100 The array is 3D, indexed by y-coord, x-coord, channel.
@@ -117,7 +121,7 @@ def load_mask_list(self) -> None:
117121 else :
118122 self ._set_mask_path (os .path .join (self .data_path , "masks" ))
119123
120- def load_mask (self , image : str ) -> Optional [np . ndarray ]:
124+ def load_mask (self , image : str ) -> Optional [NDArray ]:
121125 """Load image mask if it exists, otherwise return None."""
122126 if image in self .mask_files :
123127 mask_path = self .mask_files [image ]
@@ -138,7 +142,7 @@ def _instances_path(self) -> str:
138142 def _instances_file (self , image : str ) -> str :
139143 return os .path .join (self ._instances_path (), image + ".png" )
140144
141- def load_instances (self , image : str ) -> Optional [np . ndarray ]:
145+ def load_instances (self , image : str ) -> Optional [NDArray ]:
142146 """Load image instances file if it exists, otherwise return None."""
143147 instances_file = self ._instances_file (image )
144148 if self .io_handler .isfile (instances_file ):
@@ -153,10 +157,10 @@ def _segmentation_path(self) -> str:
153157 def _segmentation_file (self , image : str ) -> str :
154158 return os .path .join (self ._segmentation_path (), image + ".png" )
155159
156- def segmentation_labels (self ) -> List [Any ]:
160+ def segmentation_labels (self ) -> List [Dict [ str , Any ] ]:
157161 return []
158162
159- def load_segmentation (self , image : str ) -> Optional [np . ndarray ]:
163+ def load_segmentation (self , image : str ) -> Optional [NDArray ]:
160164 """Load image segmentation if it exists, otherwise return None."""
161165 segmentation_file = self ._segmentation_file (image )
162166 if self .io_handler .isfile (segmentation_file ):
@@ -318,12 +322,12 @@ def _words_file(self, image: str) -> str:
318322 def words_exist (self , image : str ) -> bool :
319323 return self .io_handler .isfile (self ._words_file (image ))
320324
321- def load_words (self , image : str ) -> np . ndarray :
325+ def load_words (self , image : str ) -> NDArray :
322326 with self .io_handler .open_rb (self ._words_file (image )) as f :
323327 s = np .load (f )
324328 return s ["words" ].astype (np .int32 )
325329
326- def save_words (self , image : str , words : np . ndarray ) -> None :
330+ def save_words (self , image : str , words : NDArray ) -> None :
327331 with self .io_handler .open_wb (self ._words_file (image )) as f :
328332 np .savez_compressed (f , words = words .astype (np .uint16 ))
329333
@@ -338,7 +342,7 @@ def _matches_file(self, image: str) -> str:
338342 def matches_exists (self , image : str ) -> bool :
339343 return self .io_handler .isfile (self ._matches_file (image ))
340344
341- def load_matches (self , image : str ) -> Dict [str , np . ndarray ]:
345+ def load_matches (self , image : str ) -> Dict [str , NDArray ]:
342346 # Prevent pickling of anything except what we strictly need
343347 # as 'pickle.load' is RCE-prone. Will raise on any class other
344348 # than the numpy ones we allow.
@@ -363,7 +367,7 @@ def find_class(self, module, name):
363367 matches = MatchingUnpickler (BytesIO (gzip .decompress (fin .read ()))).load ()
364368 return matches
365369
366- def save_matches (self , image : str , matches : Dict [str , np . ndarray ]) -> None :
370+ def save_matches (self , image : str , matches : Dict [str , NDArray ]) -> None :
367371 self .io_handler .mkdir_p (self ._matches_path ())
368372
369373 with BytesIO () as buffer :
@@ -372,7 +376,7 @@ def save_matches(self, image: str, matches: Dict[str, np.ndarray]) -> None:
372376 with self .io_handler .open_wb (self ._matches_file (image )) as fw :
373377 fw .write (buffer .getvalue ())
374378
375- def find_matches (self , im1 : str , im2 : str ) -> np . ndarray :
379+ def find_matches (self , im1 : str , im2 : str ) -> NDArray :
376380 if self .matches_exists (im1 ):
377381 im1_matches = self .load_matches (im1 )
378382 if im2 in im1_matches :
@@ -422,7 +426,7 @@ def save_reconstruction(
422426 self ,
423427 reconstruction : List [types .Reconstruction ],
424428 filename : Optional [str ] = None ,
425- minify = False ,
429+ minify : bool = False ,
426430 ) -> None :
427431 with self .io_handler .open_wt (self ._reconstruction_file (filename )) as fout :
428432 io .json_dump (io .reconstructions_to_json (reconstruction ), fout , minify )
@@ -628,11 +632,11 @@ def save_ground_control_points(
628632 with self .io_handler .open_wt (self ._ground_control_points_file ()) as fout :
629633 io .write_ground_control_points (points , fout )
630634
631- def image_as_array (self , image : str ) -> np . ndarray :
635+ def image_as_array (self , image : str ) -> NDArray :
632636 logger .warning ("image_as_array() is deprecated. Use load_image() instead." )
633637 return self .load_image (image )
634638
635- def mask_as_array (self , image : str ) -> Optional [np . ndarray ]:
639+ def mask_as_array (self , image : str ) -> Optional [NDArray ]:
636640 logger .warning ("mask_as_array() is deprecated. Use load_mask() instead." )
637641 return self .load_mask (image )
638642
@@ -707,18 +711,20 @@ class UndistortedDataSet:
707711 base : DataSetBase
708712 config : Dict [str , Any ] = {}
709713 data_path : str
714+ io_handler : io .IoFilesystemBase = io .IoFilesystemDefault ()
710715
711716 def __init__ (
712717 self ,
713718 base_dataset : DataSetBase ,
714719 undistorted_data_path : str ,
715- io_handler = io .IoFilesystemDefault ,
720+ io_handler : Optional [ io .IoFilesystemBase ] = None ,
716721 ) -> None :
717722 """Init dataset associated to a folder."""
718723 self .base = base_dataset
719724 self .config = self .base .config
720725 self .data_path = undistorted_data_path
721- self .io_handler = io_handler
726+ if io_handler is not None :
727+ self .io_handler = io_handler
722728
723729 def load_undistorted_shot_ids (self ) -> Dict [str , List [str ]]:
724730 filename = os .path .join (self .data_path , "undistorted_shot_ids.json" )
@@ -738,11 +744,11 @@ def _undistorted_image_file(self, image: str) -> str:
738744 """Path of undistorted version of an image."""
739745 return os .path .join (self ._undistorted_image_path (), image )
740746
741- def load_undistorted_image (self , image : str ) -> np . ndarray :
747+ def load_undistorted_image (self , image : str ) -> NDArray :
742748 """Load undistorted image pixels as a numpy array."""
743749 return self .io_handler .imread (self ._undistorted_image_file (image ))
744750
745- def save_undistorted_image (self , image : str , array : np . ndarray ) -> None :
751+ def save_undistorted_image (self , image : str , array : NDArray ) -> None :
746752 """Save undistorted image pixels."""
747753 self .io_handler .mkdir_p (self ._undistorted_image_path ())
748754 self .io_handler .imwrite (self ._undistorted_image_file (image ), array )
@@ -762,13 +768,13 @@ def undistorted_mask_exists(self, image: str) -> bool:
762768 """Check if the undistorted mask file exists."""
763769 return self .io_handler .isfile (self ._undistorted_mask_file (image ))
764770
765- def load_undistorted_mask (self , image : str ) -> np . ndarray :
771+ def load_undistorted_mask (self , image : str ) -> NDArray :
766772 """Load undistorted mask pixels as a numpy array."""
767773 return self .io_handler .imread (
768774 self ._undistorted_mask_file (image ), grayscale = True
769775 )
770776
771- def save_undistorted_mask (self , image : str , array : np . ndarray ) -> None :
777+ def save_undistorted_mask (self , image : str , array : NDArray ) -> None :
772778 """Save the undistorted image mask."""
773779 self .io_handler .mkdir_p (self ._undistorted_mask_path ())
774780 self .io_handler .imwrite (self ._undistorted_mask_file (image ), array )
@@ -784,7 +790,7 @@ def undistorted_segmentation_exists(self, image: str) -> bool:
784790 """Check if the undistorted segmentation file exists."""
785791 return self .io_handler .isfile (self ._undistorted_segmentation_file (image ))
786792
787- def load_undistorted_segmentation (self , image : str ) -> np . ndarray :
793+ def load_undistorted_segmentation (self , image : str ) -> NDArray :
788794 """Load an undistorted image segmentation."""
789795 segmentation_file = self ._undistorted_segmentation_file (image )
790796 with self .io_handler .open_rb (segmentation_file ) as fp :
@@ -804,12 +810,12 @@ def load_undistorted_segmentation(self, image: str) -> np.ndarray:
804810 else :
805811 raise IndexError
806812
807- def save_undistorted_segmentation (self , image : str , array : np . ndarray ) -> None :
813+ def save_undistorted_segmentation (self , image : str , array : NDArray ) -> None :
808814 """Save the undistorted image segmentation."""
809815 self .io_handler .mkdir_p (self ._undistorted_segmentation_path ())
810816 self .io_handler .imwrite (self ._undistorted_segmentation_file (image ), array )
811817
812- def load_undistorted_segmentation_mask (self , image : str ) -> Optional [np . ndarray ]:
818+ def load_undistorted_segmentation_mask (self , image : str ) -> Optional [NDArray ]:
813819 """Build a mask from the undistorted segmentation.
814820
815821 The mask is non-zero only for pixels with segmentation
@@ -828,7 +834,7 @@ def load_undistorted_segmentation_mask(self, image: str) -> Optional[np.ndarray]
828834
829835 return masking .mask_from_segmentation (segmentation , ignore_values )
830836
831- def load_undistorted_combined_mask (self , image : str ) -> Optional [np . ndarray ]:
837+ def load_undistorted_combined_mask (self , image : str ) -> Optional [NDArray ]:
832838 """Combine undistorted binary mask with segmentation mask.
833839
834840 Return a mask that is non-zero only where the binary
@@ -854,16 +860,16 @@ def point_cloud_file(self, filename: str = "merged.ply") -> str:
854860
855861 def load_point_cloud (
856862 self , filename : str = "merged.ply"
857- ) -> Tuple [np . ndarray , np . ndarray , np . ndarray , np . ndarray ]:
863+ ) -> Tuple [NDArray , NDArray , NDArray , NDArray ]:
858864 with self .io_handler .open_rt (self .point_cloud_file (filename )) as fp :
859865 return io .point_cloud_from_ply (fp )
860866
861867 def save_point_cloud (
862868 self ,
863- points : np . ndarray ,
864- normals : np . ndarray ,
865- colors : np . ndarray ,
866- labels : np . ndarray ,
869+ points : NDArray ,
870+ normals : NDArray ,
871+ colors : NDArray ,
872+ labels : NDArray ,
867873 filename : str = "merged.ply" ,
868874 ) -> None :
869875 self .io_handler .mkdir_p (self ._depthmap_path ())
@@ -876,11 +882,11 @@ def raw_depthmap_exists(self, image: str) -> bool:
876882 def save_raw_depthmap (
877883 self ,
878884 image : str ,
879- depth : np . ndarray ,
880- plane : np . ndarray ,
881- score : np . ndarray ,
882- nghbr : np . ndarray ,
883- nghbrs : np . ndarray ,
885+ depth : NDArray ,
886+ plane : NDArray ,
887+ score : NDArray ,
888+ nghbr : NDArray ,
889+ nghbrs : NDArray ,
884890 ) -> None :
885891 self .io_handler .mkdir_p (self ._depthmap_path ())
886892 filepath = self .depthmap_file (image , "raw.npz" )
@@ -891,7 +897,7 @@ def save_raw_depthmap(
891897
892898 def load_raw_depthmap (
893899 self , image : str
894- ) -> Tuple [np . ndarray , np . ndarray , np . ndarray , np . ndarray , np . ndarray ]:
900+ ) -> Tuple [NDArray , NDArray , NDArray , NDArray , NDArray ]:
895901 with self .io_handler .open_rb (self .depthmap_file (image , "raw.npz" )) as f :
896902 o = np .load (f )
897903 return o ["depth" ], o ["plane" ], o ["score" ], o ["nghbr" ], o ["nghbrs" ]
@@ -900,16 +906,14 @@ def clean_depthmap_exists(self, image: str) -> bool:
900906 return self .io_handler .isfile (self .depthmap_file (image , "clean.npz" ))
901907
902908 def save_clean_depthmap (
903- self , image : str , depth : np . ndarray , plane : np . ndarray , score : np . ndarray
909+ self , image : str , depth : NDArray , plane : NDArray , score : NDArray
904910 ) -> None :
905911 self .io_handler .mkdir_p (self ._depthmap_path ())
906912 filepath = self .depthmap_file (image , "clean.npz" )
907913 with self .io_handler .open_wb (filepath ) as f :
908914 np .savez_compressed (f , depth = depth , plane = plane , score = score )
909915
910- def load_clean_depthmap (
911- self , image : str
912- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
916+ def load_clean_depthmap (self , image : str ) -> Tuple [NDArray , NDArray , NDArray ]:
913917 with self .io_handler .open_rb (self .depthmap_file (image , "clean.npz" )) as f :
914918 o = np .load (f )
915919 return o ["depth" ], o ["plane" ], o ["score" ]
@@ -920,10 +924,10 @@ def pruned_depthmap_exists(self, image: str) -> bool:
920924 def save_pruned_depthmap (
921925 self ,
922926 image : str ,
923- points : np . ndarray ,
924- normals : np . ndarray ,
925- colors : np . ndarray ,
926- labels : np . ndarray ,
927+ points : NDArray ,
928+ normals : NDArray ,
929+ colors : NDArray ,
930+ labels : NDArray ,
927931 ) -> None :
928932 self .io_handler .mkdir_p (self ._depthmap_path ())
929933 filepath = self .depthmap_file (image , "pruned.npz" )
@@ -938,7 +942,7 @@ def save_pruned_depthmap(
938942
939943 def load_pruned_depthmap (
940944 self , image : str
941- ) -> Tuple [np . ndarray , np . ndarray , np . ndarray , np . ndarray ]:
945+ ) -> Tuple [NDArray , NDArray , NDArray , NDArray ]:
942946 with self .io_handler .open_rb (self .depthmap_file (image , "pruned.npz" )) as f :
943947 o = np .load (f )
944948 return (
0 commit comments