-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_eval.py
More file actions
99 lines (82 loc) · 3.66 KB
/
run_eval.py
File metadata and controls
99 lines (82 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
from pathlib import Path
import torch.distributed as dist
from logging import getLogger
from recbole.data import (
create_dataset,
data_preparation,
)
from recbole.utils import (
init_logger,
get_model,
get_trainer,
init_seed,
set_color,
get_environment,
)
def run_recbole(
        model_file,
        device='cuda',
        write_predictions=None,
        test_empty_sequences=False,
        test_only_users_with_infos=False,
        eval_per_user=False):
    r"""Evaluate a saved RecBole model checkpoint on the test split.

    Loads the checkpoint (model weights + the Config it was trained with),
    rebuilds the dataset/dataloaders from that config, and runs the trainer's
    ``evaluate`` on the test set.

    Args:
        model_file (str or Path): Path to the checkpoint file saved by a
            RecBole trainer.
        device (str, optional): torch device string used to map the checkpoint,
            e.g. ``'cuda'`` or ``'cpu'``. Defaults to ``'cuda'``. (The old
            default ``'gpu'`` is not a valid torch device and made
            ``torch.device`` raise.)
        write_predictions (str, optional): If given, forwarded to
            ``trainer.evaluate`` as the path to write predictions to.
            Defaults to ``None``.
        test_empty_sequences (bool, optional): Stored into the loaded config as
            ``config["test_empty_sequences"]``. Defaults to ``False``.
        test_only_users_with_infos (bool, optional): Stored into the loaded
            config as ``config["test_only_users_with_infos"]``.
            Defaults to ``False``.
        eval_per_user (bool, optional): Stored under
            ``config["eval_args"]["eval_per_user"]``. Defaults to ``False``.

    Returns:
        dict: ``{"test_result": <result of trainer.evaluate>}``.
    """
    # Local import: the module top level only imports torch.distributed,
    # so the bare `torch` name is bound here.
    import torch

    model_file = Path(model_file)
    # weights_only=False: the checkpoint stores a full Config object (and other
    # pickled state), not just tensors. torch >= 2.6 defaults weights_only=True,
    # which would fail to unpickle it. The checkpoint must come from a trusted
    # source, since this allows arbitrary pickle deserialization.
    checkpoint = torch.load(
        model_file, map_location=torch.device(device), weights_only=False
    )
    config = checkpoint["config"]

    # Override evaluation-time switches on the restored training config.
    config["test_empty_sequences"] = test_empty_sequences
    config["test_only_users_with_infos"] = test_only_users_with_infos
    config["eval_args"]["eval_per_user"] = eval_per_user

    init_seed(config["seed"], config["reproducibility"])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # Dataset filtering/splitting driven entirely by the restored config.
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    logger.info(train_data)

    # Re-seed so model construction is reproducible regardless of how much
    # RNG state data preparation consumed.
    init_seed(config["seed"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    model.load_other_parameter(checkpoint.get("other_parameter"))
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config["MODEL_TYPE"], config["model"])(config, model)

    # model evaluation — load_best_model=False because the checkpoint weights
    # were already loaded above.
    test_result = trainer.evaluate(
        test_data,
        load_best_model=False,
        show_progress=config["show_progress"],
        write_predictions=write_predictions,
        is_final_test_stage=True,
    )
    logger.info(test_result)

    environment_tb = get_environment(config)
    logger.info(
        "The running environment of this training is as follows:\n"
        + environment_tb.draw()
    )
    logger.info(set_color("test result", "yellow") + f": {test_result}")

    result = {
        "test_result": test_result,
    }

    # Tear down the process group when running under a distributed launcher.
    if not config["single_spec"]:
        dist.destroy_process_group()
    return result  # for the single process
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_file", "-m", type=str, default=None, help="saved model")
parser.add_argument("--device", type=str, default='cuda', help="device")
parser.add_argument("--write_predictions", default=None, help="path to pred file")
parser.add_argument("--test_empty_sequences", type=bool, default=False, help="test empty sequences")
parser.add_argument("--test_only_users_with_infos", type=bool, default=False, help="test_only_users_with_infos")
parser.add_argument("--eval_per_user", type=bool, default=False, help="eval per user")
args, _ = parser.parse_known_args()
res = run_recbole(args.model_file, args.device, args.write_predictions, args.test_empty_sequences,
args.test_only_users_with_infos, args.eval_per_user)