@@ -672,6 +672,32 @@ def _convert_TrueFalse_to_bool(dict_input):
672672maxsize_subgraph = max (list_maxsize_subgraph )
673673
674674
# Locate the checkpoint that this script expects to find at
# <args.path_output>/CheckpointAndPredictions/inflow_model.pt.
path_dump_checkpoint = os.path.join(
    args.path_output,
    'CheckpointAndPredictions'
)
# Build the checkpoint file path once and reuse it for the existence check,
# the load and the log message.
path_inflow_checkpoint = os.path.join(path_dump_checkpoint, 'inflow_model.pt')

# `os.path.isfile` is also False when the parent directory does not exist, so
# this single test covers both "no directory" and "no file".
if not os.path.isfile(path_inflow_checkpoint):
    # Report the output path that was actually inspected (`args.path_output`),
    # so the message always matches the check that failed.
    # FileNotFoundError is a subclass of Exception, so existing
    # `except Exception` handlers keep working.
    raise FileNotFoundError(
        "The file 'CheckpointAndPredictions/inflow_model.pt' was not found in the output path: \n {}".format(args.path_output)
    )

# The loaded object is indexed with 'module_inflow' to obtain the module;
# `map_location=device` remaps the stored tensors onto `device`.
# NOTE(review): `torch.load` unpickles arbitrary objects -- only load
# checkpoints produced by a trusted run. Recent PyTorch releases default to
# `weights_only=True`, which may refuse a checkpoint that stores a whole module
# object; confirm against the PyTorch version in use.
module_vardist = torch.load(
    path_inflow_checkpoint,
    map_location=device
)['module_inflow']
# TODO:HERE1
print("Loaded the mintflow module on device {} from checkpoint {}".format(
    device,
    path_inflow_checkpoint
))
# NOTE(review): unconditional stop, presumably a work-in-progress marker --
# nothing after this line runs (and `assert` is skipped under `python -O`).
assert False
698+
699+
700+
675701exec ('disent_dict_CTNNC_usage = {}' .format (config_model ['CTNCC_usage_moduledisent' ]))
676702assert (
677703 config_model ['str_mode_headxint_headxspl_headboth_twosep' ] in [
@@ -944,218 +970,6 @@ def _convert_TrueFalse_to_bool(dict_input):
944970
945971assert False
946972
947- # start a new wandb run to track this script
948- if config_training ['flag_enable_wandb' ]:
949- wandb .init (
950- project = config_training ['wandb_project_name' ],
951- name = config_training ['wandb_run_name' ],
952- config = {
953- 'dd' :'dd'
954- }
955- )
956- itrcount_wandbstep = None
957-
958- paramlist_optim = module_vardist .parameters ()
959- flag_freezeencdec = False
960- optim_training = torch .optim .Adam (
961- params = paramlist_optim ,
962- lr = config_training ['lr_training' ]
963- )
964- optim_training .flag_freezeencdec = flag_freezeencdec
965-
966- optim_afterGRLpreds = torch .optim .Adam (
967- params = list (module_vardist .module_predictor_xbarint2notNCC .parameters ()) + \
968- list (module_vardist .module_predictor_z2notNCC .parameters ()) + \
969- list (module_vardist .module_predictor_xbarint2notbatchID .parameters ()) + \
970- list (module_vardist .module_predictor_xbarspl2notbatchID .parameters ()),
971- lr = config_training ['lr_training' ]
972- ) # the optimizer for the dual functions (i.e. predictor Z2NotNCC, xbarint2NotNCC)
973- # TODO:NOTE:BUG module_predictor_xbarint2notbatchID and module_predictor_xbarspl2notbatchID had not been included,
974-
975- # log the inflow module
976- with open (os .path .join (args .path_output , "log_inflow_module.txt" ), 'w' ) as f :
977- f .write (str (module_vardist ))
978-
979-
980- if 'dict_measname_to_histmeas' not in globals ():
981- dict_measname_to_histmeas = {}
982- dict_measname_to_evalpredxspl = {}
983- total_cnt_epoch = 0
984- list_coef_anneal = []
985-
986- # dump the config dictionaries again, so any inconsistency (e.g. due to boolean variables being treated as str) becomes obvious.
987- tmp_check_unique = [
988- os .path .split (args .file_config_data_train )[1 ],
989- os .path .split (args .file_config_data_test )[1 ],
990- os .path .split (args .file_config_model )[1 ],
991- os .path .split (args .file_config_training )[1 ]
992- ]
993- for u in tmp_check_unique :
994- if tmp_check_unique .count (u ) > 1 :
995- raise Exception (
996- "In the provided config files the file name '{}' is repeated {} times, although probably in different directories. \n Please avoid this repeatition and try again" .format (
997- u ,
998- tmp_check_unique .count (u )
999- )
1000- )
1001-
1002- try_mkdir (os .path .join (args .path_output , 'ConfigFilesCopiedOver' ))
1003- os .system (
1004- "cp {} {}" .format (
1005- os .path .abspath (args .file_config_data_train ),
1006- os .path .join (args .path_output , 'ConfigFilesCopiedOver' )
1007- )
1008- )
1009- os .system (
1010- "cp {} {}" .format (
1011- os .path .abspath (args .file_config_data_test ),
1012- os .path .join (args .path_output , 'ConfigFilesCopiedOver' )
1013- )
1014- )
1015- os .system (
1016- "cp {} {}" .format (
1017- os .path .abspath (args .file_config_model ),
1018- os .path .join (args .path_output , 'ConfigFilesCopiedOver' )
1019- )
1020- )
1021- os .system (
1022- "cp {} {}" .format (
1023- os .path .abspath (args .file_config_training ),
1024- os .path .join (args .path_output , 'ConfigFilesCopiedOver' )
1025- )
1026- )
1027-
1028- with open (os .path .join (args .path_output , 'ConfigFilesCopiedOver' , 'args.yml' ), 'w' ) as f :
1029- yaml .dump (
1030- {
1031- 'file_config_data_train' :os .path .split (args .file_config_data_train )[1 ],
1032- 'file_config_data_test' :os .path .split (args .file_config_data_test )[1 ],
1033- 'file_config_model' :os .path .split (args .file_config_model )[1 ],
1034- 'file_config_training' :os .path .split (args .file_config_training )[1 ]
1035- },
1036- f ,
1037- default_flow_style = False
1038- )
1039-
1040-
1041- t_before_training = time .time ()
1042-
1043- for idx_epoch in range (config_training ['num_training_epochs' ]):
1044- print ("\n \n Epoch {} from {} ================ " .format (
1045- idx_epoch + 1 ,
1046- config_training ['num_training_epochs' ]
1047- ))
1048- # ten_Z, ten_xbarint, ten_CT, ten_NCC, ten_xy_absolute are obtained using all tissues.
1049- # update the dual functions separately =============
1050- with torch .no_grad ():
1051- forduals_ten_Z , forduals_ten_CT , forduals_ten_NCC , forduals_ten_BatchEmb , \
1052- forduals_ten_xbarint , forduals_ten_xy_absolute , forduals_ten_xbarspl = \
1053- [], [], [], [], [], [], []
1054-
1055- print (" Getting different embeddings to update the dual functions separately." )
1056- for idx_sl , sl in enumerate (list_slice .list_slice ):
1057- anal_dict_varname_to_output = module_vardist .eval_on_pygneighloader_dense (
1058- dl = sl .pyg_dl_test , # this is correct, because all neighbours are to be included (not a subset of neighbours).
1059- ten_xy_absolute = sl .ten_xy_absolute ,
1060- tqdm_desc = "Tissue {}" .format (idx_sl )
1061- )
1062-
1063- forduals_ten_Z .append (
1064- torch .tensor (anal_dict_varname_to_output ['mu_z' ] + 0.0 )
1065- )
1066-
1067- forduals_ten_CT .append (
1068- sl .ten_CT + 0.0
1069- )
1070-
1071- forduals_ten_NCC .append (
1072- sl .ten_NCC + 0.0
1073- )
1074-
1075- forduals_ten_BatchEmb .append (
1076- sl .ten_BatchEmb + 0.0
1077- )
1078-
1079- forduals_ten_xbarint .append (
1080- torch .tensor (anal_dict_varname_to_output ['muxbar_int' ] + 0.0 )
1081- )
1082-
1083- forduals_ten_xy_absolute .append (
1084- sl .ten_xy_absolute + 0.0
1085- )
1086-
1087- forduals_ten_xbarspl .append (
1088- torch .tensor (anal_dict_varname_to_output ['muxbar_spl' ] + 0.0 )
1089- )
1090-
1091- del anal_dict_varname_to_output
1092-
1093- forduals_ten_Z = torch .concat (forduals_ten_Z , 0 )
1094- forduals_ten_CT = torch .concat (forduals_ten_CT , 0 )
1095- forduals_ten_NCC = torch .concat (forduals_ten_NCC , 0 )
1096- forduals_ten_BatchEmb = torch .concat (forduals_ten_BatchEmb , 0 )
1097- forduals_ten_xbarint = torch .concat (forduals_ten_xbarint , 0 )
1098- forduals_ten_xbarspl = torch .concat (forduals_ten_xbarspl , 0 )
1099- forduals_ten_xy_absolute = torch .concat (forduals_ten_xy_absolute , 0 )
1100-
1101- module_vardist ._trainsep_GradRevPreds (
1102- optim_gradrevpreds = optim_afterGRLpreds ,
1103- numiters = config_training ['numiters_updateduals_seprately_perepoch' ],
1104- ten_Z = forduals_ten_Z ,
1105- ten_CT = forduals_ten_CT ,
1106- ten_NCC = forduals_ten_NCC ,
1107- ten_xbarint = forduals_ten_xbarint ,
1108- ten_BatchEmb = forduals_ten_BatchEmb ,
1109- ten_xbarspl = forduals_ten_xbarspl ,
1110- ten_xy_absolute = forduals_ten_xy_absolute ,
1111- # Note: the arg `ten_xy_absolute` is not internally used, but kept for backward comptbility.
1112- device = device ,
1113- kwargs_dl = {
1114- 'batch_size' :config_training ['batchsize_updateduals_seprately_perepoch' ]
1115- }
1116- )
1117-
1118- # gccollect
1119- torch .cuda .empty_cache ()
1120- gc .collect ()
1121- torch .cuda .empty_cache ()
1122- gc .collect ()
1123-
1124- # train all modules ===============
1125- itrcount_wandbstep , list_coef_anneal_ = module_vardist .training_epoch (
1126- flag_lockencdec_duringtraining = False , # unused arg
1127- dl = [sl .pyg_dl_train for sl in list_slice .list_slice ],
1128- prob_maskknowngenes = 0.0 , # unused arg
1129- t_num_steps = config_model ['neuralODE_t_num_steps' ],
1130- ten_xy_absolute = [sl .ten_xy_absolute for sl in list_slice .list_slice ],
1131- optim_training = optim_training ,
1132- tensorboard_stepsize_save = config_training ['wandb_stepsize_log' ],
1133- itrcount_wandbstep_input = itrcount_wandbstep ,
1134- list_flag_elboloss_imputationloss = [True , False ], # unused arg
1135- coef_loss_closeness_zz = config_model ['coef_loss_closeness_zz' ],
1136- coef_loss_closeness_xbarintxbarint = config_model ['coef_loss_closeness_xbarintxbarint' ],
1137- coef_loss_closeness_xintxint = config_model ['coef_loss_closeness_xintxint' ],
1138- prob_applytfm_affinexy = 0.0 , # unused arg
1139- coef_flowmatchingloss = config_model ['coef_flowmatchingloss' ],
1140- np_size_factor = [
1141- np .array (sl .adata .shape [0 ] * [config_training ['val_scppnorm_total' ]]) for sl in list_slice .list_slice
1142- ],
1143- numsteps_accumgrad = config_training ['numsteps_accumgrad' ],
1144- num_updateseparate_afterGRLs = config_training ['num_updateseparate_afterGRLs' ],
1145- flag_verbose = False ,
1146- flag_enable_wandb = config_training ['flag_enable_wandb' ]
1147- )
1148- list_coef_anneal = list_coef_anneal + list_coef_anneal_
1149- total_cnt_epoch += 1
1150-
1151-
1152-
1153-
1154- if args .flag_verbose :
1155- print ("Training for {} epochs took {} seconds." .format (
1156- config_training ['num_training_epochs' ],
1157- time .time () - t_before_training
1158- ))
1159973
1160974
1161975# gccollect
@@ -1167,48 +981,43 @@ def _convert_TrueFalse_to_bool(dict_input):
1167981gc .collect ()
1168982torch .cuda .empty_cache ()
1169983gc .collect ()
1170- time .sleep (config_training ['sleeptime_gccollect_aftertraining' ])
1171984
1172985
1173- # load LR-DB and the ones found in the shared gene panel of tissues ===
1174- df_LRpairs = pd .read_csv ("./Files2Use_CLI/df_LRpairs_Armingoletal.txt" )
1175- list_known_LRgenes_inDB = [
1176- genename
1177- for colname in ['LigName' , 'RecName' ] for group in df_LRpairs [colname ].tolist () for genename in str (group ).split ("__" )
1178- ]
1179- list_known_LRgenes_inDB = set (list_known_LRgenes_inDB )
1180- list_LR = []
1181- for gene_name in list_slice .list_slice [0 ].adata .var .index .tolist ():
1182- if gene_name in list_known_LRgenes_inDB :
1183- list_LR .append (gene_name )
1184-
1185- if args .flag_verbose :
1186- print ("\n \n Among the {} genes in tissues' gene panels, {} genes were found in the ligand-receptor database.\n \n " .format (
1187- len (list_slice .list_slice [0 ].adata .var .index .tolist ()),
1188- len (list_LR )
1189- ))
1190-
1191- # dump the inflow model as well as the inferred latent factors ===
# Locate the checkpoint that this script expects to find at
# <args.path_output>/CheckpointAndPredictions/inflow_model.pt.
path_dump_checkpoint = os.path.join(
    args.path_output,
    'CheckpointAndPredictions'
)
# Build the checkpoint file path once and reuse it for the existence check
# and the load.
path_inflow_checkpoint = os.path.join(path_dump_checkpoint, 'inflow_model.pt')

# `os.path.isfile` is also False when the parent directory does not exist, so
# this single test covers both "no directory" and "no file".
if not os.path.isfile(path_inflow_checkpoint):
    # Report the output path that was actually inspected (`args.path_output`),
    # so the message always matches the check that failed.
    # FileNotFoundError is a subclass of Exception, so existing
    # `except Exception` handlers keep working.
    raise FileNotFoundError(
        "The file 'CheckpointAndPredictions/inflow_model.pt' was not found in the output path: \n {}".format(args.path_output)
    )

# The loaded object is indexed with 'module_inflow' to obtain the module;
# `map_location=device` remaps the stored tensors onto `device`.
# NOTE(review): `torch.load` unpickles arbitrary objects -- only load
# checkpoints produced by a trusted run. Recent PyTorch releases default to
# `weights_only=True`, which may refuse a checkpoint that stores a whole module
# object; confirm against the PyTorch version in use.
module_vardist = torch.load(
    path_inflow_checkpoint,
    map_location=device
)['module_inflow']
# TODO:HERE2


# NOTE(review): unconditional stop, presumably a work-in-progress marker --
# nothing after this line runs (and `assert` is skipped under `python -O`).
assert False
1007+
1008+ # # dump the inflow checkpoint
1009+ # module_vardist.module_annealing = "NONE" # so it can be dumped.
1010+ # module_vardist.module_annealing_decoderXintXspl = "NONE" # so it can be dumped.
1011+ # torch.save(
1012+ # {
1013+ # 'module_inflow': module_vardist,
1014+ # },
1015+ # os.path.join(
1016+ # path_dump_checkpoint,
1017+ # 'inflow_model.pt'
1018+ # ),
1019+ # pickle_protocol=4
1020+ # )
12121021
12131022
12141023# dump predictions per-tissue
0 commit comments