forked from yuzhTHU/EIC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheic_for_evaluate.py
More file actions
156 lines (139 loc) · 5.52 KB
/
eic_for_evaluate.py
File metadata and controls
156 lines (139 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import re
import os
import json
import yaml
import time
import signal
import socket
import random
import logging
import datetime
import traceback
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from socket import gethostname
from argparse import ArgumentParser
from setproctitle import setproctitle
from src.nd2py import nd2py as nd
from src.nd2py.nd2py.utils import seed_all, init_logger
from src.eic.eic import get_eic
_logger = logging.getLogger('src')
_cache = {}
def rename(name):
    """Map a dataset column name to the variable name used for it.

    ``'class'`` (a Python keyword) becomes ``'klass'``; for any other name each
    ``'-'`` and ``'.'`` is replaced by ``'_'`` and all other characters are kept.
    """
    if name == 'class':
        return 'klass'
    # One pass that rewrites both '-' and '.' to '_'.
    return name.translate(str.maketrans('-.', '__'))
def get_X(dataset):
    """Load a PMLB dataset as ``(X, y)``, memoised in the module-level ``_cache``.

    The file is read from ``./data/pmlb/datasets/<dataset>/<dataset>.tsv.gz``.
    If that path does not exist, the ``_deprecated_<dataset>`` directory is
    tried instead. Column names are passed through ``rename`` and the
    ``target`` column is split off as ``y``.

    Args:
        dataset: Dataset name as stored in the results table.

    Returns:
        ``(X, y)``:
        - ``X`` maps renamed column names to their values (``Series.values``).
        - ``y`` holds the values of the ``target`` column.
        Both are fresh copies, so callers may modify them without affecting
        the cached data.

    Raises:
        FileNotFoundError: if neither the regular nor the deprecated file exists.
        KeyError: if the file has no ``target`` column.
    """
    def _fresh(X, y):
        # Copy every array, not just the dict, so in-place edits by the caller
        # cannot leak back into the cached entry.
        return {k: v.copy() for k, v in X.items()}, y.copy()

    # The cache is keyed by the name the caller asked for (never the
    # `_deprecated_` one), so deprecated datasets also hit the cache.
    if dataset in _cache:
        return _fresh(*_cache[dataset])

    root = Path('./data/pmlb/datasets/')
    path = root / dataset / f'{dataset}.tsv.gz'
    if not path.exists():
        deprecated = f'_deprecated_{dataset}'
        fallback = root / deprecated / f'{deprecated}.tsv.gz'
        if not fallback.exists():
            raise FileNotFoundError(
                f'Dataset {dataset!r} not found at {path} or {fallback}')
        path = fallback

    data = pd.read_csv(path, sep='\t', compression='gzip')
    # If two columns collide after renaming, the later one wins.
    X = {rename(col): data[col].values for col in data.columns}
    y = X.pop('target')
    _cache[dataset] = (X, y)
    return _fresh(X, y)
def main(args):
    """Compute the EIC of every selected result row and save the augmented table.

    Steps:
    1. Read the tab-separated, gzip-compressed results table at
       ``args.data_path``.
    2. Keep rows whose ``algorithm`` is in ``args.algorithms`` and whose
       formula column ``f`` is present.
    3. For each kept row, parse the formula with ``nd.parse``, load the row's
       dataset with ``get_X``, and store ``get_eic(f, X)`` in the ``eic``
       column.

    Rows that raise an exception are logged and get ``NaN``. A
    ``KeyboardInterrupt`` (Ctrl-C, or SIGTERM through the handler installed
    by the script entry point) stops the loop, and the rows finished so far
    are still saved.

    Args:
        args: Parsed namespace providing ``data_path``, ``algorithms`` and
            ``save_dir``.

    The table is written to ``<save_dir>/<basename of data_path>``.
    """
    df = pd.read_csv(args.data_path, sep='\t', compression='gzip')
    # Make sure the output always carries an `eic` column, even if no row is evaluated.
    if 'eic' not in df.columns:
        df['eic'] = np.nan

    # Vectorised selection; `notna` treats None, NaN and pd.NA as missing.
    # NOTE: rows that already have an `eic` value are recomputed; add
    # `& df['eic'].isna()` to the mask to skip them instead.
    selected = df['algorithm'].isin(args.algorithms) & df['f'].notna()
    loaders = df.index[selected].tolist()

    # Guard the whole loop so an interrupt arriving anywhere (including inside
    # tqdm) still falls through to the save below.
    try:
        for idx in tqdm(loaders, dynamic_ncols=True):
            # `idx` is an index label, so use label-based access throughout.
            row = df.loc[idx]
            try:
                f = nd.parse(row['f'])
                X, _ = get_X(row['dataset'])
                df.at[idx, 'eic'] = get_eic(f, X)
            except Exception as e:
                _logger.error(f"Error processing index {idx}, dataset {row['dataset']}: [{type(e).__name__}] {e}: {traceback.format_exc()}")
                df.at[idx, 'eic'] = np.nan
    except KeyboardInterrupt:
        _logger.error("Interrupted by user.")

    save_path = Path(args.save_dir) / Path(args.data_path).name
    save_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(save_path, sep='\t', index=False)
    _logger.info(f'Saved results to {save_path}')
if __name__ == '__main__':
    # Algorithm labels for the white-box and black-box result tables. The two
    # groups use different spellings for some methods (e.g. 'GPlearn' vs
    # 'gplearn'), so both are kept.
    WHITEBOX_ALGORITHMS = [
        'SNIP', 'NeurSR', 'SR4MDL', 'AFP-FE', 'Operon', 'BSR',
        'GP-GOMEA', 'DSR', 'AFP', 'SBP-GP', 'EPLEX', 'GPlearn',
        'FEAT', 'AIFeynman2', 'RSRM', 'E2ESR', 'PySR',
    ]
    BLACKBOX_ALGORITHMS = [
        'AFP', 'BSR', 'DSR', 'EPLEX', 'AFP_FE', 'GP-GOMEA',
        'FFX', 'ITEA', 'gplearn', 'MLP', 'RandomForest',
        'Operon', 'SBP-GP', 'snip', 'KernelRidge', 'AdaBoost',
        'sr4mdl', 'XGB', 'LGBM', 'Linear', 'FEAT', 'neuralsr',
        'e2esr', 'MRGP', 'AIFeynman',
    ]
    # Order-preserving de-duplication: several labels appear in both groups.
    ALL_ALGORITHMS = list(dict.fromkeys(WHITEBOX_ALGORITHMS + BLACKBOX_ALGORITHMS))

    parser = ArgumentParser()
    parser.add_argument('--algorithms', type=str, nargs='+',
                        choices=ALL_ALGORITHMS, default=list(ALL_ALGORITHMS))
    parser.add_argument('--data_path', type=str, default='./data/srbench/whitebox_results_full.csv.gz')
    parser.add_argument('-n', '--name', type=str, default=None)
    parser.add_argument('-s', '--seed', type=int, default=None)
    parser.add_argument('--quiet', action='store_true')  # not read by this script
    parser.add_argument('--keep_name', action='store_true')
    parser.add_argument('--skip_existing', action='store_true')
    parser.add_argument('--save_dir', type=str, default='./logs/evaluate')
    args, unknown = parser.parse_known_args()

    # Characters not allowed in file names on Windows (and generally unsafe).
    _illegal = re.compile(r'[<>:"/\\|?*\x00-\x1f]')

    def sanitize_filename(s):
        """Return ``s`` (may be None) as a file-name-safe string, or 'unnamed' if empty."""
        s = (s or '').strip()
        s = _illegal.sub('_', s)
        s = s.replace(' ', '_')
        return s or 'unnamed'

    if args.keep_name:
        # Without a name there is no directory to log/save into.
        if args.name is None:
            parser.error('--keep_name requires --name')
    else:
        date = datetime.datetime.now()
        hostname = socket.gethostname()
        yymmdd = date.strftime('%y%m%d')
        hhmmss = date.strftime('%H%M%S')
        safe_name = sanitize_filename(args.name)
        safe_host = sanitize_filename(hostname)
        args.name = f"{yymmdd}_{safe_name}_{hhmmss}_{safe_host}"
        if args.skip_existing:
            existing = next(Path(args.save_dir).glob(f'*_{safe_name}_*_*'), None)
            if existing is not None:
                _logger.warning(f'Existing experiment found for name={safe_name} in {existing}, skip it.')
                raise SystemExit(0)

    # Pick the seed before logging `args`, so the log records the seed actually used.
    if args.seed is None:
        args.seed = random.randint(0, 10000)

    args.save_dir = os.path.join(args.save_dir, args.name)
    init_logger('src', args.name, Path(args.save_dir) / 'info.log')
    # Warn only after the logger is configured, so the warning reaches its handlers.
    if unknown:
        _logger.warning(f'unknown args: {unknown}')
    _logger.info(args)
    setproctitle(f'{args.name}@YuZihan')

    def handler(signum, frame):
        # Turn SIGTERM (and SIGINT) into KeyboardInterrupt, which `main`
        # catches so it can stop early and still save partial results.
        raise KeyboardInterrupt

    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)
    seed_all(args.seed)
    main(args)