11from typing import Callable , cast , Dict , List , Optional , Sequence , Tuple , Union
22
33import torch
4+ from packaging .version import Version
45from typing_extensions import Literal
56
67from ignite .metrics import MetricGroup
910from ignite .metrics .metric import Metric , reinit__is_reduced , sync_all_reduce
1011
1112
13+ _torch_version_lt_113 = Version (torch .__version__ ) < Version ("1.13.0" )
14+
15+
1216def coco_tensor_list_to_dict_list (
1317 output : Tuple [
1418 Union [List [torch .Tensor ], List [Dict [str , torch .Tensor ]]],
@@ -213,7 +217,8 @@ def _compute_recall_and_precision(
213217 Returns:
214218 `(recall, precision)`
215219 """
216- indices = torch .argsort (scores , dim = - 1 , stable = True , descending = True )
220+ kwargs = {} if _torch_version_lt_113 else {"stable" : True }
221+ indices = torch .argsort (scores , descending = True , ** kwargs )
217222 tp = TP [..., indices ]
218223 tp_summation = tp .cumsum (dim = - 1 )
219224 if tp_summation .device .type != "mps" :
@@ -226,7 +231,7 @@ def _compute_recall_and_precision(
226231
227232 recall = tp_summation / y_true_count
228233 predicted_positive = tp_summation + fp_summation
229- precision = tp_summation / torch .where (predicted_positive == 0 , 1 , predicted_positive )
234+ precision = tp_summation / torch .where (predicted_positive == 0 , 1.0 , predicted_positive )
230235
231236 return recall , precision
232237
@@ -258,9 +263,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
258263 if recall .size (- 1 ) != 0
259264 else torch .LongTensor ([], device = self ._device )
260265 )
261- precision_integrand = precision_integrand .take_along_dim (
262- rec_thresh_indices .where (rec_thresh_indices != recall .size (- 1 ), 0 ), dim = - 1
263- ).where (rec_thresh_indices != recall .size (- 1 ), 0 )
266+ recall_mask = rec_thresh_indices != recall .size (- 1 )
267+ precision_integrand = torch .where (
268+ recall_mask ,
269+ precision_integrand .take_along_dim (torch .where (recall_mask , rec_thresh_indices , 0 ), dim = - 1 ),
270+ 0.0 ,
271+ )
264272 return torch .sum (precision_integrand , dim = - 1 ) / len (cast (torch .Tensor , self .rec_thresholds ))
265273
266274 @reinit__is_reduced
@@ -298,6 +306,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
298306 This key is optional.
299307 ========= ================= =================================================
300308 """
309+ kwargs = {} if _torch_version_lt_113 else {"stable" : True }
301310 self ._check_matching_input (output )
302311 for pred , target in zip (* output ):
303312 labels = target ["labels" ]
@@ -312,7 +321,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
312321
313322 # Matching logic of object detection mAP, according to COCO reference implementation.
314323 if len (pred ["labels" ]):
315- best_detections_index = torch .argsort (pred ["scores" ], stable = True , descending = True )
324+ best_detections_index = torch .argsort (pred ["scores" ], descending = True , ** kwargs )
316325 max_best_detections_index = torch .cat (
317326 [
318327 best_detections_index [pred ["labels" ][best_detections_index ] == c ][
0 commit comments