2121class ImageSet :
2222 """
2323 Subscriptable dataset-like class for loading, storing and processing image collections
24-
24+
2525 :param root: Path to project root directory, which contains data/image_corpus/ or data/query catalog
2626 :param base: Build ImageSet on top of image_corpus if True, else on top of query catalog
2727 :param build: Build ImageSet from filesystem instead of using saved version
@@ -30,31 +30,33 @@ class ImageSet:
3030 :param greyscale: Load images in grayscale if True, else use 3-channel RGB
3131 :param normalize: If True, images will be normalized image-wise when loaded from disk
3232 """
33- def __init__ (self ,
34- root : str ,
35- base : bool = True ,
36- build : bool = False ,
37- transform : Callable = None ,
38- compatibility_mode : bool = False ,
39- greyscale : bool = False ,
40- normalize : bool = True ) -> None :
41-
33+
34+ def __init__ (
35+ self ,
36+ root : str ,
37+ base : bool = True ,
38+ build : bool = False ,
39+ transform : Callable = None ,
40+ compatibility_mode : bool = False ,
41+ greyscale : bool = False ,
42+ normalize : bool = True ,
43+ ) -> None :
44+
4245 self .root = root
4346 self .compatibility_mode = compatibility_mode
4447 self .greyscale = greyscale
45- self .colormode = 'L' if greyscale else ' RGB'
48+ self .colormode = "L" if greyscale else " RGB"
4649 self .transform = transform
4750 self .base = base
4851 self .normalize = normalize
49-
52+
5053 if build :
5154 self .embeddings = []
5255 self .data , self .names = self ._build ()
5356 return
54-
57+
5558 self .data = self ._load ()
56-
57-
59+
5860 def _build (self ) -> Tuple [torch .Tensor , str ]:
5961
6062 dirpath = f"{ self .root } /data/{ 'image_corpus' if self .base else 'query' } "
@@ -66,39 +68,41 @@ def _build(self) -> Tuple[torch.Tensor, str]:
6668 # resize into common shape
6769 im = im .convert (self .colormode ).resize ((118 , 143 ))
6870 if self .normalize :
69- im = cv2 .normalize (np .array (im ), None , 0.0 , 1.0 , cv2 .NORM_MINMAX , cv2 .CV_32FC1 )
70- image = np .array (im , dtype = np .float32 )
71- fname = filename .split ('/' )[- 1 ]
71+ im = cv2 .normalize (
72+ np .array (im ), None , 0.0 , 1.0 , cv2 .NORM_MINMAX , cv2 .CV_32FC1
73+ )
74+ image = np .array (im , dtype = np .float32 )
75+ fname = filename .split ("/" )[- 1 ]
7276 data .append (image )
7377 names .append (fname )
7478 return torch .from_numpy (np .array (data )), names
75-
76- def _load (self ) -> Tuple [torch .Tensor , str ]:
77- ...
78-
79- def save (self ) -> None :
80- ...
81-
79+
80+ def _load (self ) -> Tuple [torch .Tensor , str ]: ...
81+
82+ def save (self ) -> None : ...
83+
8284 def build_embeddings (self , model : SiameseNetwork , device : torch .cuda .device = None ):
83-
85+
8486 if device is None :
8587 device = detect_device ()
86-
88+
8789 with torch .no_grad ():
8890 model .eval ()
8991 for img , name in self :
90- img_input = img .transpose (2 ,0 ).transpose (2 ,1 ).to (device ).unsqueeze (0 )
92+ img_input = img .transpose (2 , 0 ).transpose (2 , 1 ).to (device ).unsqueeze (0 )
9193 embedding = model .get_embedding (img_input )
9294 self .embeddings .append ((embedding , name ))
93-
95+
9496 return self
95-
97+
9698 def get_embeddings (self ) -> List [Tuple [torch .Tensor , str ]]:
9799 if self .embeddings is None :
98- raise RuntimeError ('Embedding collection is empty. Run self.build_embeddings() method to build it' )
99-
100+ raise RuntimeError (
101+ "Embedding collection is empty. Run self.build_embeddings() method to build it"
102+ )
103+
100104 return self .embeddings
101-
105+
102106 def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
103107 """
104108 Args:
@@ -118,40 +122,51 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
118122 img = self .transform (img )
119123
120124 return img , name
121-
122-
125+
126+
123127class SearchTree :
124128 """
125129 Wrapper for k-d tree built on image embeddings
126-
130+
127131 :param query_set: instance of base ImageSet with built embedding representation
128132 """
133+
129134 def __init__ (self , query_set : ImageSet ) -> None :
130135 embeddings = query_set .get_embeddings ()
131- self .embeddings = np .concatenate ([x [0 ].cpu ().numpy () for x in embeddings ], axis = 0 )
136+ self .embeddings = np .concatenate (
137+ [x [0 ].cpu ().numpy () for x in embeddings ], axis = 0
138+ )
132139 self .names = np .array ([x [1 ] for x in embeddings ])
133140 self .kdtree = self ._build_kdtree ()
134-
141+
135142 def _build_kdtree (self ) -> KDTree :
136- print (' Building KD-Tree from embeddings' )
143+ print (" Building KD-Tree from embeddings" )
137144 return KDTree (self .embeddings )
138-
139- def query (self , anchors : ImageSet , k : int = 3 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
145+
146+ def query (
147+ self , anchors : ImageSet , k : int = 3
148+ ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
140149 """
141150 Search for k nearest neighbors of provided anchor embeddings
142-
151+
143152 :param anchors: instance of query (reference) ImageSet with built embedding representation
144-
145- :returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels
153+
154+ :returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels
146155 """
147-
156+
148157 reference = anchors .get_embeddings ()
149- reference_embeddings = np .concatenate ([x [0 ].cpu ().numpy () for x in reference ], axis = 0 )
158+ reference_embeddings = np .concatenate (
159+ [x [0 ].cpu ().numpy () for x in reference ], axis = 0
160+ )
150161 reference_labels = np .array ([x [1 ] for x in reference ])
151-
152- distances , indices = self .kdtree .query (reference_embeddings , k = k , workers = - 1 )
153- return reference_labels , distances , self .embeddings [indices ], self .names [indices ]
154-
162+
163+ distances , indices = self .kdtree .query (reference_embeddings , k = k , workers = - 1 )
164+ return (
165+ reference_labels ,
166+ distances ,
167+ self .embeddings [indices ],
168+ self .names [indices ],
169+ )
170+
155171 def __call__ (self , * args , ** kwargs ) -> Any :
156172 return self .query (* args , ** kwargs )
157-
0 commit comments