Skip to content

Commit bc37753

Browse files
committed
update 02_run_sa.py scripts
1 parent 2e67b02 commit bc37753

File tree

3 files changed

+15
-262
lines changed

3 files changed

+15
-262
lines changed

02_run_sa.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,20 @@
66
from os.path import abspath, join, isdir
77
import sys
88

9-
module_path = abspath("nn4dms_nn-extrapolate/code")
9+
# for relative paths in nn4dms code to work properly, we need to set the current working
10+
# directory to the root of the project
11+
# we also need to add the code folder to the system path for imports to work properly
12+
13+
print('Setting working directory to nn4dms root.')
14+
os.chdir('nn4dms_nn-extrapolate')
15+
module_path = abspath("code")
1016
if module_path not in sys.path:
1117
sys.path.append(module_path)
1218

19+
# add relative path to write directory (nn-extrapolate)
20+
nnextrap_root_relpath = '..'
21+
pretrained_dir = "nn-extrapolation-models/pretrained_models"
22+
1323
import design_tools as tools
1424
import pickle
1525
import random
@@ -18,10 +28,11 @@
1828
import yaml
1929
import importlib
2030

31+
2132
AAs = 'ACDEFGHIKLMNPQRSTVWY'
2233

2334
def load_config(config_file):
24-
with open(config_file, 'r') as stream:
35+
with open(join(nnextrap_root_relpath, config_file), 'r') as stream:
2536
try:
2637
return yaml.safe_load(stream)
2738
except yaml.YAMLError as exc:
@@ -40,11 +51,11 @@ def run_simulated_annealing(config):
4051
cool_sched=config['cool_sched'])
4152
print('running optimization...')
4253
best_mut, fitness = sa_optimizer.optimize(seed=config['seed'])
43-
with open(config['export_best_seqs'], 'wb') as f:
54+
with open(join(nnextrap_root_relpath, config['export_best_seqs']), 'wb') as f:
4455
pickle.dump([best_mut, fitness], f)
4556

4657
if config['save_plot_trajectory']:
47-
sa_optimizer.plot_trajectory(savefig_name=config['file_plot_trajectory'])
58+
sa_optimizer.plot_trajectory(savefig_name=join(nnextrap_root_relpath, config['file_plot_trajectory']))
4859

4960

5061
if __name__ == '__main__':

data/design_tools.py

Lines changed: 0 additions & 229 deletions
This file was deleted.

data/seq2fitness_tools_lr.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)