File tree 2 files changed +2
-8
lines changed
2 files changed +2
-8
lines changed Original file line number Diff line number Diff line change @@ -172,7 +172,7 @@ Performance on `v1.0` val (trained on `v1.0` train):
172
172
173
173
| R@1 | R@5 | R@10 | MeanR | MRR |
174
174
| ------ | ------ | ------ | ------ | ------ |
175
- | 0.4194 | 0.7345 | 0.8387 | 5.9876 | 0.5650 |
175
+ | 0.4298 | 0.7464 | 0.8491 | 5.4874 | 0.5757 |
176
176
177
177
178
178
Acknowledgements
Original file line number Diff line number Diff line change 3
3
4
4
def get_gt_ranks (ranks , ans_ind ):
5
5
ans_ind = ans_ind .view (- 1 )
6
- num_opts = 100
7
- ranks = ranks .view (- 1 , num_opts )
8
6
gt_ranks = torch .LongTensor (ans_ind .size (0 ))
9
7
for i in range (ans_ind .size (0 )):
10
- gt_binary = torch .zeros (num_opts )
11
- gt_binary [ans_ind [i ]] = 1
12
- sorted_gt = gt_binary .index_select (0 , ranks [i ].sort ()[1 ].cpu ())
13
- gt_rank = (sorted_gt == 1 ).nonzero () + 1
14
- gt_ranks [i ] = int (gt_rank ) # gt_rank is 1x1 LongTensor
8
+ gt_ranks [i ] = int (ranks [i , ans_ind [i ]])
15
9
return gt_ranks
16
10
17
11
You can’t perform that action at this time.
0 commit comments