forked from ShuchangYe-bib/ProLearn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
executable file
·119 lines (102 loc) · 3.18 KB
/
Copy pathtest.py
File metadata and controls
executable file
·119 lines (102 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# === Standard Library ===
import argparse
import warnings
# === Third-party Libraries ===
import torch
# === Project Modules ===
from models import ProLearn, Prototype
from modules.dataset import MMSegDataset
from modules.dataloader import MMSegDataLoader
from modules.trainer import Trainer
import utils.config as config
# Silence every Python warning raised after this point (no category or module
# filter is given, so this is not limited to non-critical ones) to keep the
# console output clean.
warnings.filterwarnings("ignore")
def get_parser(argv=None):
    """
    Build the command-line parser for inference and parse the arguments.

    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the behaviour of a plain ``get_parser()`` call.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``v``, ``num_prototypes`` and ``num_candidate``.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid arguments, or when the
            configuration path is empty.
    """
    parser = argparse.ArgumentParser(
        description="Language-Guided Medical Image Segmentation Inference"
    )
    parser.add_argument(
        '--config',
        default='./config/training.yaml',
        type=str,
        help='Path to configuration YAML file.',
    )
    parser.add_argument(
        '--v',
        default='',
        type=str,
        help='Checkpoint version suffix (e.g., 1, 2, etc.).',
    )
    parser.add_argument(
        '--num_prototypes',
        type=int,
        default=None,
        help='Number of prototypes per pseudo-label (optional).',
    )
    parser.add_argument(
        '--num_candidate',
        type=int,
        default=None,
        help='Number of candidates used in response generation (optional).',
    )
    args = parser.parse_args(argv)

    # Validate explicitly instead of with ``assert``: assertions are removed
    # when Python runs with -O, and parser.error() reports the problem together
    # with the usage string and a conventional exit status of 2.
    if not args.config:
        parser.error("Configuration file must be provided.")
    return args
def main():
    """
    Load a trained ProLearn checkpoint and run it on the test split.

    Steps, in order:
        1. Parse the CLI arguments and load the YAML configuration; the
           optional ``--num_prototypes`` / ``--num_candidate`` values replace
           the corresponding configuration entries when given.
        2. Build the ``Prototype`` module and call its ``load()`` method.
        3. Build ``ProLearn`` and restore its weights from
           ``<model_save_path>/<model_save_filename>[-v<V>].ckpt``.
        4. Build the test dataset and dataloader.
        5. Hand everything to ``Trainer`` and call ``Trainer.test()``.
    """
    # Load configuration and override with CLI arguments
    args = get_parser()
    cfg = config.load_cfg_from_cfg_file(args.config)
    # Compare against None rather than relying on truthiness, so that an
    # explicitly supplied 0 is applied instead of silently falling back to
    # the value from the YAML file.
    if args.num_prototypes is not None:
        cfg.num_prototypes = args.num_prototypes
    if args.num_candidate is not None:
        cfg.num_candidate = args.num_candidate

    # === Initialize prototype encoder and load precomputed features ===
    prototype = Prototype(cfg).to(cfg.device)
    prototype.load()

    # === Load trained segmentation model ===
    model = ProLearn(cfg, prototype).to(cfg.device)
    ckpt_path = f"{cfg.model_save_path}/{cfg.model_save_filename}"
    if args.v:
        # Optional version suffix selected with --v, e.g. "-v2"
        ckpt_path += f"-v{args.v}"
    ckpt_path += ".ckpt"
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be loaded here (consider weights_only=True on
    # PyTorch versions that support it).
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # strict=True: fail on any missing or unexpected key in the state dict.
    model.load_state_dict(checkpoint, strict=True)

    # === Prepare test dataset and dataloader ===
    ds_test = MMSegDataset(
        ann_path=cfg.ann_path,
        root_path=cfg.root_path,
        tokenizer=cfg.bert_type,
        image_size=cfg.image_size,
        mode='test',
        lazy=cfg.lazy,
    )
    ds_test.precompute(encoder=prototype.encoder, device=cfg.device)
    dl_test = MMSegDataLoader(
        ds_test,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    # === Initialize trainer and run inference ===
    # Only the model, the test loader and the device are supplied; every
    # training-related argument is passed as None.
    trainer = Trainer(
        model=model,
        optimizer=None,
        scheduler=None,
        early_stopping_patience=None,
        train_loader=None,
        val_loader=None,
        test_loader=dl_test,
        model_save_path=None,
        model_name=None,
        max_epochs=None,
        device=cfg.device,
    )
    trainer.test()
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()