# compute_scores.py
# Fix the RNG seeds before the model modules are imported, so that any
# randomness triggered at import time is also deterministic.
fixed_seed = 1
import random
import numpy as np
random.seed(fixed_seed)
np.random.seed(fixed_seed)

import os
import logging
import argparse
import pandas as pd

from helpers.utils import init_logger, get_model_name
from models.estimators import (SEvaluator, TEvaluator, TSEvaluator, XEvaluator,
                               DREvaluator, DRSEvaluator, DMLEvaluator,
                               DMLSEvaluator, IPSWEvaluator, IPSWSEvaluator,
                               CausalForestEvaluator)
from models.scorers import PluginScorer, RScorerEvaluator, MatchingEvaluator

def get_parser():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument('--results_path', type=str)
    parser.add_argument('--scorer_path', type=str)
    parser.add_argument('--sf', dest='splits_file', type=str)
    parser.add_argument('--iters', type=int, default=-1)
    parser.add_argument('-o', type=str, dest='output_path', default='./')
    parser.add_argument('--seed', type=int, default=1)

    # Estimation
    parser.add_argument('--em', dest='estimation_model', type=str,
                        choices=['sl', 'tl', 'tls', 'xl', 'dr', 'drs', 'dml', 'dmls',
                                 'ipsw', 'ipsws', 'two-head', 'cf'], default='sl')
    parser.add_argument('--bm', dest='base_model', type=str,
                        choices=['l1', 'l2', 'tr', 'dt', 'rf', 'et', 'kr', 'cb', 'lgbm', 'mlp'],
                        default='l1')
    parser.add_argument('--st', type=str, dest='scorer_type',
                        choices=['plugin', 'r_score', 'matching'], default='plugin')
    parser.add_argument('--sn', type=str, dest='scorer_name')
    return parser
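
# Example invocation (the paths and the scorer name below are illustrative,
# not taken from the repository):
#   python compute_scores.py --sf splits.npy --em dr --bm lgbm \
#       --st plugin --sn cate_plugin -o ./results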

def get_scorer(opt):
    if opt.scorer_type == 'plugin':
        return PluginScorer(opt)
    elif opt.scorer_type == 'r_score':
        return RScorerEvaluator(opt)
    elif opt.scorer_type == 'matching':
        return MatchingEvaluator(opt)
    else:
        raise ValueError("Unrecognised 'get_scorer' key.")

def get_evaluator(opt):
    if opt.estimation_model in ('sl', 'two-head'):
        return SEvaluator(opt)
    elif opt.estimation_model == 'tl':
        return TEvaluator(opt)
    elif opt.estimation_model == 'tls':
        return TSEvaluator(opt)
    elif opt.estimation_model == 'xl':
        return XEvaluator(opt)
    elif opt.estimation_model == 'dr':
        return DREvaluator(opt)
    elif opt.estimation_model == 'drs':
        return DRSEvaluator(opt)
    elif opt.estimation_model == 'dml':
        return DMLEvaluator(opt)
    elif opt.estimation_model == 'dmls':
        return DMLSEvaluator(opt)
    elif opt.estimation_model == 'ipsw':
        return IPSWEvaluator(opt)
    elif opt.estimation_model == 'ipsws':
        return IPSWSEvaluator(opt)
    elif opt.estimation_model == 'cf':
        return CausalForestEvaluator(opt)
    else:
        raise ValueError("Unrecognised 'get_evaluator' key.")
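
# How the two factories compose (argument values here are hypothetical):
#   opt = get_parser().parse_args(['--em', 'dr', '--st', 'r_score'])
#   scorer, evaluator = get_scorer(opt), get_evaluator(opt)
#   # -> RScorerEvaluator, DREvaluator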

if __name__ == "__main__":
    parser = get_parser()
    options = parser.parse_args()

    # Create the output folder (and any missing parents) if necessary.
    os.makedirs(options.output_path, exist_ok=True)

    # Initialise the logger (writes simultaneously to a file and the console).
    init_logger(options)
    logging.debug(options)

    # Cross-validation splits, laid out as (iters, folds, idx).
    splits = np.load(options.splits_file, allow_pickle=True)
    n_iters = options.iters if options.iters > 0 else splits.shape[0]
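    # The splits file is assumed (not verified here) to load as an object that
    # supports both splits.shape[0] (the number of data iterations) and
    # splits['train'][i] (an iterable of per-fold training indices), e.g. a
    # structured NumPy array saved with np.save(..., allow_pickle=True).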

    scorer = get_scorer(options)
    evaluator = get_evaluator(options)

    df_val_all = None
    df_test_all = None

    # Data iterations
    for i in range(n_iters):
        # CV iterations
        for k, _ in enumerate(splits['train'][i]):
            logging.info(f'Iter {i+1}, Fold {k+1}')
            df_fold = scorer.score(evaluator, i+1, k+1)
            # pd.concat silently drops the initial None.
            df_val_all = pd.concat([df_val_all, df_fold], ignore_index=True)

        # Test-set scores are computed once per data iteration.
        df_iter = scorer.score_test(evaluator, i+1)
        df_test_all = pd.concat([df_test_all, df_iter], ignore_index=True)

    model_name = get_model_name(options)
    df_val_all.to_csv(os.path.join(options.output_path, f'{model_name}_{options.scorer_type}_{options.scorer_name}.csv'), index=False)
    df_test_all.to_csv(os.path.join(options.output_path, f'{model_name}_{options.scorer_type}_{options.scorer_name}_test.csv'), index=False)
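
    # Outputs: two CSVs in the output folder, e.g. with --st plugin --sn tau
    # (the exact model-name prefix comes from get_model_name, which is defined
    # outside this file):
    #   <model_name>_plugin_tau.csv       per-fold validation scores
    #   <model_name>_plugin_tau_test.csv  per-iteration test scores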