-
Notifications
You must be signed in to change notification settings - Fork 69
support evaluation on vehicleid #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
fbec1c5
23598a6
cf4765f
a1043e6
235d3db
06b4a1f
66ef05e
90bccb3
4f7c35c
d7610c8
2c6005a
7acbfab
efc9bfc
4f399ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,9 @@ | |
|
||
import numpy as np | ||
import torch | ||
import torchvision | ||
|
||
from .train import set_random_seed | ||
from openunreid.data import build_test_dataloader | ||
from ..core.metrics.rank import evaluate_rank | ||
from ..core.utils.compute_dist import build_dist | ||
from ..models.utils.dsbn_utils import switch_target_bn | ||
|
@@ -26,11 +27,9 @@ | |
|
||
@torch.no_grad() | ||
def test_reid( | ||
cfg, model, data_loader, query, gallery, dataset_name=None, rank=None, **kwargs | ||
cfg, model, data_loader, query, gallery, dataset_name=None, num=1, rank=None, **kwargs | ||
): | ||
|
||
start_time = time.monotonic() | ||
|
||
if cfg.MODEL.dsbn: | ||
assert ( | ||
dataset_name is not None | ||
|
@@ -47,7 +46,7 @@ def test_reid( | |
|
||
sep = "*******************************" | ||
if dataset_name is not None: | ||
print(f"\n{sep} Start testing {dataset_name} {sep}\n") | ||
print(f"\n{sep} Start testing {dataset_name} {-num} {sep}\n") | ||
|
||
if rank is None: | ||
rank, _, _ = get_dist_info() | ||
|
@@ -78,7 +77,7 @@ def test_reid( | |
|
||
# evaluate with original distance | ||
dist = build_dist(cfg.TEST, query_features, gallery_features) | ||
cmc, map = evaluate_rank(dist, q_pids, g_pids, q_cids, g_cids) | ||
cmc, map = evaluate_rank(cfg, dist, q_pids, g_pids, q_cids, g_cids) | ||
else: | ||
cmc, map = np.empty(50), 0.0 | ||
|
||
|
@@ -98,14 +97,10 @@ def test_reid( | |
# dist_gg = build_dist(cfg, gallery_features, gallery_features) | ||
# final_dist = re_ranking_cpu(dist, dist_qq, dist_gg) | ||
|
||
cmc, map = evaluate_rank(final_dist, q_pids, g_pids, q_cids, g_cids) | ||
cmc, map = evaluate_rank(cfg, final_dist, q_pids, g_pids, q_cids, g_cids) | ||
else: | ||
cmc, map = np.empty(50), 0.0 | ||
|
||
end_time = time.monotonic() | ||
print("Testing time: ", timedelta(seconds=end_time - start_time)) | ||
print(f"\n{sep} Finished testing {sep}\n") | ||
|
||
return cmc, map | ||
|
||
|
||
|
@@ -142,7 +137,7 @@ def val_reid( | |
# evaluate with original distance | ||
if rank == 0: | ||
dist = build_dist(cfg.TEST, features) | ||
cmc, map = evaluate_rank(dist, pids, pids, cids, cids) | ||
cmc, map = evaluate_rank(cfg, dist, pids, pids, cids, cids) | ||
else: | ||
cmc, map = np.empty(50), 0.0 | ||
|
||
|
@@ -207,3 +202,31 @@ def infer_gan( | |
print(f"\n{sep} Finished translating {sep}\n") | ||
|
||
return | ||
|
||
|
||
@torch.no_grad() | ||
def final_test(cfg, model, cmc_topk=(1, 5, 10)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The module name "final_test" is somehow confusing. Maybe you could modify the original "test_reid" into "test_reid_once" and "final_test" into "test_reid". |
||
sep = "*******************************" | ||
start_time = time.monotonic() | ||
|
||
all_cmc = [] | ||
all_mAP = [] | ||
for num in range(cfg.TRAIN.num_repeat): | ||
set_random_seed(num + 1, cfg.TRAIN.deterministic) | ||
test_loaders, queries, galleries = build_test_dataloader(cfg) | ||
for i, (loader, query, gallery) in enumerate(zip(test_loaders, queries, galleries)): | ||
cmc, mAP = test_reid( | ||
cfg, model, loader, query, gallery, dataset_name=cfg.TEST.datasets[i], num=num+1 | ||
) | ||
all_cmc.append(cmc) | ||
all_mAP.append(mAP) | ||
|
||
if cfg.TRAIN.num_repeat != 1: | ||
print("\n ") | ||
print("Average CMC Scores:") | ||
for k in cmc_topk: | ||
print(" top-{:<4}{:12.1%}".format(k, np.mean(all_cmc, axis=0)[k - 1])) | ||
|
||
end_time = time.monotonic() | ||
print("Testing time: ", timedelta(seconds=end_time - start_time)) | ||
print(f"\n{sep} Finished testing {sep}\n") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import torch | ||
|
||
from openunreid.apis import BaseRunner, batch_processor, test_reid, set_random_seed | ||
from openunreid.apis.test import final_test | ||
from openunreid.core.metrics.accuracy import accuracy | ||
from openunreid.core.solvers import build_lr_scheduler, build_optimizer | ||
from openunreid.data import build_test_dataloader, build_train_dataloader | ||
|
@@ -198,21 +199,11 @@ def main(): | |
runner.resume(cfg.work_dir / "model_best.pth") | ||
|
||
# final testing | ||
test_loaders, queries, galleries = build_test_dataloader(cfg) | ||
for i, (loader, query, gallery) in enumerate(zip(test_loaders, queries, galleries)): | ||
|
||
for idx in range(len(runner.model)): | ||
print("==> Test on the no.{} model".format(idx)) | ||
# test_reid() on self.model[idx] will only evaluate the 'mean_net' | ||
# for testing 'net', use self.model[idx].module.net | ||
cmc, mAP = test_reid( | ||
cfg, | ||
runner.model[idx], | ||
loader, | ||
query, | ||
gallery, | ||
dataset_name=cfg.TEST.datasets[i], | ||
) | ||
for idx in range(len(runner.model)): | ||
print("==> Test on the no.{} model".format(idx)) | ||
# test_reid() on self.model[idx] will only evaluate the 'mean_net' | ||
# for testing 'net', use self.model[idx].module.net | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comments need to be modified accordingly. |
||
final_test(cfg, runner.model[idx]) | ||
|
||
# print time | ||
end_time = time.monotonic() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,9 @@ TRAIN: | |
datasets: {'market1501': 'trainval', 'dukemtmcreid': 'trainval'} | ||
unsup_dataset_indexes: [1,] | ||
|
||
# repeated number of evaluation | ||
num_repeat: 10 # 10 only for vehicleid dataset, otherwise 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here should be "1" I guess? |
||
|
||
epochs: 50 | ||
iters: 200 | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"current" here is not so good, please specify whether the metric is mAP or CMC here.