@@ -45,30 +45,30 @@ def test__prepare_output():
4545 metric = MeanAveragePrecision ()
4646
4747 metric ._type = "binary"
48- scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 )). bool () ))
48+ scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 ))))
4949 assert scores .shape == y .shape == (1 , 120 )
5050
5151 metric ._type = "multiclass"
5252 scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 4 , (5 , 3 , 2 ))))
5353 assert scores .shape == (4 , 30 ) and y .shape == (30 ,)
5454
5555 metric ._type = "multilabel"
56- scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 )). bool () ))
56+ scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 ))))
5757 assert scores .shape == y .shape == (4 , 30 )
5858
5959
6060def test_update ():
6161 metric = MeanAveragePrecision ()
6262 assert len (metric ._y_pred ) == len (metric ._y_true ) == 0
63- metric .update ((torch .rand ((5 , 4 )), torch .randint (0 , 2 , (5 , 4 )). bool () ))
63+ metric .update ((torch .rand ((5 , 4 )), torch .randint (0 , 2 , (5 , 4 ))))
6464 assert len (metric ._y_pred ) == len (metric ._y_true ) == 1
6565
6666
6767def test__compute_recall_and_precision ():
6868 m = MeanAveragePrecision ()
6969
7070 scores = torch .rand ((50 ,))
71- y_true = torch .randint (0 , 2 , (50 ,)). bool ()
71+ y_true = torch .randint (0 , 2 , (50 ,))
7272 precision , recall , _ = precision_recall_curve (y_true .numpy (), scores .numpy ())
7373 P = y_true .sum (dim = - 1 )
7474 ignite_recall , ignite_precision = m ._compute_recall_and_precision (y_true , scores , P )
@@ -77,7 +77,7 @@ def test__compute_recall_and_precision():
7777
7878 # When there's no actual positive. Numpy expectedly raises warning.
7979 scores = torch .rand ((50 ,))
80- y_true = torch .zeros ((50 ,)). bool ()
80+ y_true = torch .zeros ((50 ,))
8181 precision , recall , _ = precision_recall_curve (y_true .numpy (), scores .numpy ())
8282 P = torch .tensor (0 )
8383 ignite_recall , ignite_precision = m ._compute_recall_and_precision (y_true , scores , P )
@@ -147,7 +147,7 @@ def test_compute_nonbinary_data(class_mean):
147147
148148 # Multilabel
149149 m = MeanAveragePrecision (is_multilabel = True , class_mean = class_mean )
150- y_true = torch .randint (0 , 2 , (130 , 5 , 2 , 2 )). bool ()
150+ y_true = torch .randint (0 , 2 , (130 , 5 , 2 , 2 ))
151151 m .update ((scores [:50 ], y_true [:50 ]))
152152 m .update ((scores [50 :], y_true [50 :]))
153153 ignite_map = m .compute ().numpy ()
0 commit comments