Skip to content

Commit bc7b59e

Browse files
update
1 parent 017bb8c commit bc7b59e

1 file changed

Lines changed: 53 additions & 244 deletions

File tree

cli/mintflow_cli_recover_outputs.py

Lines changed: 53 additions & 244 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,32 @@ def _convert_TrueFalse_to_bool(dict_input):
672672
maxsize_subgraph = max(list_maxsize_subgraph)
673673

674674

675+
# check if the inflow checkpoint is dumped
676+
path_dump_checkpoint = os.path.join(
677+
args.path_output,
678+
'CheckpointAndPredictions'
679+
)
680+
if (not os.path.isdir(path_dump_checkpoint)) or (not os.path.isfile(os.path.join(path_dump_checkpoint, 'inflow_model.pt'))):
681+
raise Exception(
682+
"The file 'CheckpointAndPredictions/inflow_model.pt' was not found in the output path: \n {}".format(args.original_CLI_run_path_output)
683+
)
684+
685+
module_vardist = torch.load(
686+
os.path.join(
687+
path_dump_checkpoint,
688+
'inflow_model.pt'
689+
),
690+
map_location=device
691+
)['module_inflow']
692+
# TODO:HERE1
693+
print("Loaded the mintflow module on device {} from checkpiont {}".format(
694+
device,
695+
os.path.join(path_dump_checkpoint, 'inflow_model.pt')
696+
))
697+
assert False
698+
699+
700+
675701
exec('disent_dict_CTNNC_usage = {}'.format(config_model['CTNCC_usage_moduledisent']))
676702
assert (
677703
config_model['str_mode_headxint_headxspl_headboth_twosep'] in [
@@ -944,218 +970,6 @@ def _convert_TrueFalse_to_bool(dict_input):
944970

945971
assert False
946972

947-
# start a new wandb run to track this script
948-
if config_training['flag_enable_wandb']:
949-
wandb.init(
950-
project=config_training['wandb_project_name'],
951-
name=config_training['wandb_run_name'],
952-
config={
953-
'dd':'dd'
954-
}
955-
)
956-
itrcount_wandbstep = None
957-
958-
paramlist_optim = module_vardist.parameters()
959-
flag_freezeencdec = False
960-
optim_training = torch.optim.Adam(
961-
params=paramlist_optim,
962-
lr=config_training['lr_training']
963-
)
964-
optim_training.flag_freezeencdec = flag_freezeencdec
965-
966-
optim_afterGRLpreds = torch.optim.Adam(
967-
params=list(module_vardist.module_predictor_xbarint2notNCC.parameters()) +\
968-
list(module_vardist.module_predictor_z2notNCC.parameters()) +\
969-
list(module_vardist.module_predictor_xbarint2notbatchID.parameters()) +\
970-
list(module_vardist.module_predictor_xbarspl2notbatchID.parameters()),
971-
lr=config_training['lr_training']
972-
) # the optimizer for the dual functions (i.e. predictor Z2NotNCC, xbarint2NotNCC)
973-
# TODO:NOTE:BUG module_predictor_xbarint2notbatchID and module_predictor_xbarspl2notbatchID had not been included,
974-
975-
# log the inflow module
976-
with open(os.path.join(args.path_output, "log_inflow_module.txt"), 'w') as f:
977-
f.write(str(module_vardist))
978-
979-
980-
if 'dict_measname_to_histmeas' not in globals():
981-
dict_measname_to_histmeas = {}
982-
dict_measname_to_evalpredxspl = {}
983-
total_cnt_epoch = 0
984-
list_coef_anneal = []
985-
986-
# dump the config dictionaries again, so any inconsistency (e.g. due to boolean variables being treated as str) becomes obvious.
987-
tmp_check_unique = [
988-
os.path.split(args.file_config_data_train)[1],
989-
os.path.split(args.file_config_data_test)[1],
990-
os.path.split(args.file_config_model)[1],
991-
os.path.split(args.file_config_training)[1]
992-
]
993-
for u in tmp_check_unique:
994-
if tmp_check_unique.count(u) > 1:
995-
raise Exception(
996-
"In the provided config files the file name '{}' is repeated {} times, although probably in different directories. \n Please avoid this repeatition and try again".format(
997-
u,
998-
tmp_check_unique.count(u)
999-
)
1000-
)
1001-
1002-
try_mkdir(os.path.join(args.path_output, 'ConfigFilesCopiedOver'))
1003-
os.system(
1004-
"cp {} {}".format(
1005-
os.path.abspath(args.file_config_data_train),
1006-
os.path.join(args.path_output, 'ConfigFilesCopiedOver')
1007-
)
1008-
)
1009-
os.system(
1010-
"cp {} {}".format(
1011-
os.path.abspath(args.file_config_data_test),
1012-
os.path.join(args.path_output, 'ConfigFilesCopiedOver')
1013-
)
1014-
)
1015-
os.system(
1016-
"cp {} {}".format(
1017-
os.path.abspath(args.file_config_model),
1018-
os.path.join(args.path_output, 'ConfigFilesCopiedOver')
1019-
)
1020-
)
1021-
os.system(
1022-
"cp {} {}".format(
1023-
os.path.abspath(args.file_config_training),
1024-
os.path.join(args.path_output, 'ConfigFilesCopiedOver')
1025-
)
1026-
)
1027-
1028-
with open(os.path.join(args.path_output, 'ConfigFilesCopiedOver', 'args.yml'), 'w') as f:
1029-
yaml.dump(
1030-
{
1031-
'file_config_data_train':os.path.split(args.file_config_data_train)[1],
1032-
'file_config_data_test':os.path.split(args.file_config_data_test)[1],
1033-
'file_config_model':os.path.split(args.file_config_model)[1],
1034-
'file_config_training':os.path.split(args.file_config_training)[1]
1035-
},
1036-
f,
1037-
default_flow_style=False
1038-
)
1039-
1040-
1041-
t_before_training = time.time()
1042-
1043-
for idx_epoch in range(config_training['num_training_epochs']):
1044-
print("\n\nEpoch {} from {} ================ ".format(
1045-
idx_epoch+1,
1046-
config_training['num_training_epochs']
1047-
))
1048-
# ten_Z, ten_xbarint, ten_CT, ten_NCC, ten_xy_absolute are obtained using all tissues.
1049-
# update the dual functions separately =============
1050-
with torch.no_grad():
1051-
forduals_ten_Z, forduals_ten_CT, forduals_ten_NCC, forduals_ten_BatchEmb, \
1052-
forduals_ten_xbarint, forduals_ten_xy_absolute, forduals_ten_xbarspl = \
1053-
[], [], [], [], [], [], []
1054-
1055-
print(" Getting different embeddings to update the dual functions separately.")
1056-
for idx_sl, sl in enumerate(list_slice.list_slice):
1057-
anal_dict_varname_to_output = module_vardist.eval_on_pygneighloader_dense(
1058-
dl=sl.pyg_dl_test, # this is correct, because all neighbours are to be included (not a subset of neighbours).
1059-
ten_xy_absolute=sl.ten_xy_absolute,
1060-
tqdm_desc="Tissue {}".format(idx_sl)
1061-
)
1062-
1063-
forduals_ten_Z.append(
1064-
torch.tensor(anal_dict_varname_to_output['mu_z'] + 0.0)
1065-
)
1066-
1067-
forduals_ten_CT.append(
1068-
sl.ten_CT + 0.0
1069-
)
1070-
1071-
forduals_ten_NCC.append(
1072-
sl.ten_NCC + 0.0
1073-
)
1074-
1075-
forduals_ten_BatchEmb.append(
1076-
sl.ten_BatchEmb + 0.0
1077-
)
1078-
1079-
forduals_ten_xbarint.append(
1080-
torch.tensor(anal_dict_varname_to_output['muxbar_int'] + 0.0)
1081-
)
1082-
1083-
forduals_ten_xy_absolute.append(
1084-
sl.ten_xy_absolute + 0.0
1085-
)
1086-
1087-
forduals_ten_xbarspl.append(
1088-
torch.tensor(anal_dict_varname_to_output['muxbar_spl'] + 0.0)
1089-
)
1090-
1091-
del anal_dict_varname_to_output
1092-
1093-
forduals_ten_Z = torch.concat(forduals_ten_Z, 0)
1094-
forduals_ten_CT = torch.concat(forduals_ten_CT, 0)
1095-
forduals_ten_NCC = torch.concat(forduals_ten_NCC, 0)
1096-
forduals_ten_BatchEmb = torch.concat(forduals_ten_BatchEmb, 0)
1097-
forduals_ten_xbarint = torch.concat(forduals_ten_xbarint, 0)
1098-
forduals_ten_xbarspl = torch.concat(forduals_ten_xbarspl, 0)
1099-
forduals_ten_xy_absolute = torch.concat(forduals_ten_xy_absolute, 0)
1100-
1101-
module_vardist._trainsep_GradRevPreds(
1102-
optim_gradrevpreds=optim_afterGRLpreds,
1103-
numiters=config_training['numiters_updateduals_seprately_perepoch'],
1104-
ten_Z=forduals_ten_Z,
1105-
ten_CT=forduals_ten_CT,
1106-
ten_NCC=forduals_ten_NCC,
1107-
ten_xbarint=forduals_ten_xbarint,
1108-
ten_BatchEmb=forduals_ten_BatchEmb,
1109-
ten_xbarspl=forduals_ten_xbarspl,
1110-
ten_xy_absolute=forduals_ten_xy_absolute,
1111-
# Note: the arg `ten_xy_absolute` is not internally used, but kept for backward comptbility.
1112-
device=device,
1113-
kwargs_dl={
1114-
'batch_size':config_training['batchsize_updateduals_seprately_perepoch']
1115-
}
1116-
)
1117-
1118-
# gccollect
1119-
torch.cuda.empty_cache()
1120-
gc.collect()
1121-
torch.cuda.empty_cache()
1122-
gc.collect()
1123-
1124-
# train all modules ===============
1125-
itrcount_wandbstep, list_coef_anneal_ = module_vardist.training_epoch(
1126-
flag_lockencdec_duringtraining=False, # unused arg
1127-
dl=[sl.pyg_dl_train for sl in list_slice.list_slice],
1128-
prob_maskknowngenes=0.0, # unused arg
1129-
t_num_steps=config_model['neuralODE_t_num_steps'],
1130-
ten_xy_absolute=[sl.ten_xy_absolute for sl in list_slice.list_slice],
1131-
optim_training=optim_training,
1132-
tensorboard_stepsize_save=config_training['wandb_stepsize_log'],
1133-
itrcount_wandbstep_input=itrcount_wandbstep,
1134-
list_flag_elboloss_imputationloss=[True, False], # unused arg
1135-
coef_loss_closeness_zz=config_model['coef_loss_closeness_zz'],
1136-
coef_loss_closeness_xbarintxbarint=config_model['coef_loss_closeness_xbarintxbarint'],
1137-
coef_loss_closeness_xintxint=config_model['coef_loss_closeness_xintxint'],
1138-
prob_applytfm_affinexy=0.0, # unused arg
1139-
coef_flowmatchingloss=config_model['coef_flowmatchingloss'],
1140-
np_size_factor=[
1141-
np.array(sl.adata.shape[0] * [config_training['val_scppnorm_total']]) for sl in list_slice.list_slice
1142-
],
1143-
numsteps_accumgrad=config_training['numsteps_accumgrad'],
1144-
num_updateseparate_afterGRLs=config_training['num_updateseparate_afterGRLs'],
1145-
flag_verbose=False,
1146-
flag_enable_wandb=config_training['flag_enable_wandb']
1147-
)
1148-
list_coef_anneal = list_coef_anneal + list_coef_anneal_
1149-
total_cnt_epoch += 1
1150-
1151-
1152-
1153-
1154-
if args.flag_verbose:
1155-
print("Training for {} epochs took {} seconds.".format(
1156-
config_training['num_training_epochs'],
1157-
time.time() - t_before_training
1158-
))
1159973

1160974

1161975
# gccollect
@@ -1167,48 +981,43 @@ def _convert_TrueFalse_to_bool(dict_input):
1167981
gc.collect()
1168982
torch.cuda.empty_cache()
1169983
gc.collect()
1170-
time.sleep(config_training['sleeptime_gccollect_aftertraining'])
1171984

1172985

1173-
# load LR-DB and the ones found in the shared gene panel of tissues ===
1174-
df_LRpairs = pd.read_csv("./Files2Use_CLI/df_LRpairs_Armingoletal.txt")
1175-
list_known_LRgenes_inDB = [
1176-
genename
1177-
for colname in ['LigName', 'RecName'] for group in df_LRpairs[colname].tolist() for genename in str(group).split("__")
1178-
]
1179-
list_known_LRgenes_inDB = set(list_known_LRgenes_inDB)
1180-
list_LR = []
1181-
for gene_name in list_slice.list_slice[0].adata.var.index.tolist():
1182-
if gene_name in list_known_LRgenes_inDB:
1183-
list_LR.append(gene_name)
1184-
1185-
if args.flag_verbose:
1186-
print("\n\n Among the {} genes in tissues' gene panels, {} genes were found in the ligand-receptor database.\n\n".format(
1187-
len(list_slice.list_slice[0].adata.var.index.tolist()),
1188-
len(list_LR)
1189-
))
1190-
1191-
# dump the inflow model as well as the inferred latent factors ===
986+
# check if the inflow checkpoint is dumped
1192987
path_dump_checkpoint = os.path.join(
1193988
args.path_output,
1194989
'CheckpointAndPredictions'
1195990
)
1196-
if not os.path.isdir(path_dump_checkpoint):
1197-
os.mkdir(path_dump_checkpoint)
1198-
1199-
# dump the inflow checkpoint
1200-
module_vardist.module_annealing = "NONE" # so it can be dumped.
1201-
module_vardist.module_annealing_decoderXintXspl = "NONE" # so it can be dumped.
1202-
torch.save(
1203-
{
1204-
'module_inflow': module_vardist,
1205-
},
991+
if (not os.path.isdir(path_dump_checkpoint)) or (not os.path.isfile(os.path.join(path_dump_checkpoint, 'inflow_model.pt'))):
992+
raise Exception(
993+
"The file 'CheckpointAndPredictions/inflow_model.pt' was not found in the output path: \n {}".format(args.original_CLI_run_path_output)
994+
)
995+
996+
module_vardist = torch.load(
1206997
os.path.join(
1207998
path_dump_checkpoint,
1208999
'inflow_model.pt'
12091000
),
1210-
pickle_protocol=4
1211-
)
1001+
map_location=device
1002+
)['module_inflow']
1003+
# TODO:HERE2
1004+
1005+
1006+
assert False
1007+
1008+
# # dump the inflow checkpoint
1009+
# module_vardist.module_annealing = "NONE" # so it can be dumped.
1010+
# module_vardist.module_annealing_decoderXintXspl = "NONE" # so it can be dumped.
1011+
# torch.save(
1012+
# {
1013+
# 'module_inflow': module_vardist,
1014+
# },
1015+
# os.path.join(
1016+
# path_dump_checkpoint,
1017+
# 'inflow_model.pt'
1018+
# ),
1019+
# pickle_protocol=4
1020+
# )
12121021

12131022

12141023
# dump predictions per-tissue

0 commit comments

Comments
 (0)