Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
7a65fa4
update
amirakbarnejad Jun 2, 2025
bed2367
update
amirakbarnejad Jun 2, 2025
fcc36cf
update
amirakbarnejad Jun 2, 2025
94860fa
update
amirakbarnejad Jun 2, 2025
e8d187f
update
amirakbarnejad Jun 2, 2025
b368499
update
amirakbarnejad Jun 2, 2025
eedb7fb
update
amirakbarnejad Jun 2, 2025
24bd68a
update
amirakbarnejad Jun 2, 2025
26f99ec
update
amirakbarnejad Jun 2, 2025
ddbf12e
update
amirakbarnejad Jun 2, 2025
96b9167
update
amirakbarnejad Jun 2, 2025
b13ffad
update
amirakbarnejad Jun 2, 2025
6fda3fe
update
amirakbarnejad Jun 2, 2025
f301fd3
update
amirakbarnejad Jun 2, 2025
6805a16
update
amirakbarnejad Jun 2, 2025
a181f41
update
amirakbarnejad Jun 2, 2025
3d9911b
update
amirakbarnejad Jun 2, 2025
3271c8d
update
amirakbarnejad Jun 2, 2025
55fd7ce
update
amirakbarnejad Jun 2, 2025
205873d
update
amirakbarnejad Jun 2, 2025
558cb11
update
amirakbarnejad Jun 2, 2025
5e6eabc
update
amirakbarnejad Jun 2, 2025
d651615
update
amirakbarnejad Jun 2, 2025
017bcbd
update
amirakbarnejad Jun 2, 2025
b62821d
update
amirakbarnejad Jun 2, 2025
9f1b0d5
update
amirakbarnejad Jun 3, 2025
8625b5e
update
amirakbarnejad Jun 3, 2025
bf6ce70
update
amirakbarnejad Jun 3, 2025
9fd98a4
update
amirakbarnejad Jun 4, 2025
b48bee0
update
amirakbarnejad Jun 4, 2025
cb40221
update
amirakbarnejad Jun 4, 2025
2075558
update
amirakbarnejad Jun 4, 2025
a01fbc6
update
amirakbarnejad Jun 4, 2025
26c8436
update
amirakbarnejad Jun 4, 2025
d91c1ae
update
amirakbarnejad Jun 4, 2025
3cde1d2
update
amirakbarnejad Jun 4, 2025
48f1108
update
amirakbarnejad Jun 4, 2025
e834fcc
update
amirakbarnejad Jun 4, 2025
d6f4639
update
amirakbarnejad Jun 4, 2025
0aa3170
update
amirakbarnejad Jun 4, 2025
dc9ee57
update
amirakbarnejad Jun 4, 2025
d471c25
update
amirakbarnejad Jun 4, 2025
d47c76b
update
amirakbarnejad Jun 4, 2025
82ff0b9
update
amirakbarnejad Jun 4, 2025
6bfb964
update
amirakbarnejad Jun 4, 2025
65ee051
update
amirakbarnejad Jun 4, 2025
c569ce8
update
amirakbarnejad Jun 7, 2025
44e8a9d
update
amirakbarnejad Jun 7, 2025
a59e38f
update
amirakbarnejad Jun 7, 2025
95292b7
update
amirakbarnejad Jun 7, 2025
1052189
update
amirakbarnejad Jun 7, 2025
d998dec
update
amirakbarnejad Jun 7, 2025
36e9130
update
amirakbarnejad Jun 7, 2025
5a64e3c
update
amirakbarnejad Jun 7, 2025
74b74f0
update
amirakbarnejad Jun 7, 2025
d36fe92
update
amirakbarnejad Jun 7, 2025
e247e3e
update
amirakbarnejad Jun 7, 2025
f6841c6
update
amirakbarnejad Jun 7, 2025
23a565b
update
amirakbarnejad Jun 7, 2025
b0f2d18
update
amirakbarnejad Jun 7, 2025
1a88788
update
amirakbarnejad Jun 7, 2025
2721828
update
amirakbarnejad Jun 7, 2025
5ab202f
update
amirakbarnejad Jun 7, 2025
23d6d63
update
amirakbarnejad Jun 7, 2025
a745bc5
update
amirakbarnejad Jun 7, 2025
63c45cd
update
amirakbarnejad Jun 7, 2025
22296ec
update
amirakbarnejad Jun 7, 2025
4eaa000
update
amirakbarnejad Jun 7, 2025
a97aade
update
amirakbarnejad Jun 7, 2025
7ceb536
update
amirakbarnejad Jun 7, 2025
a6a93eb
update
amirakbarnejad Jun 7, 2025
9cde70d
update
amirakbarnejad Jun 7, 2025
1a9cafa
update
amirakbarnejad Jun 7, 2025
3c9cc9c
update
amirakbarnejad Jun 7, 2025
7fefa23
update
amirakbarnejad Jun 7, 2025
689c3ab
update
amirakbarnejad Jun 7, 2025
6fc4225
update
amirakbarnejad Jun 7, 2025
84ab0f3
update
amirakbarnejad Jun 8, 2025
3af69dc
update
amirakbarnejad Jun 8, 2025
d65c093
update
amirakbarnejad Jun 8, 2025
348e43c
update
amirakbarnejad Jun 8, 2025
6727084
update
amirakbarnejad Jun 8, 2025
36a2d35
update
amirakbarnejad Jun 8, 2025
f2c78b8
update
amirakbarnejad Jun 8, 2025
d83519c
update
amirakbarnejad Jun 8, 2025
6f4d5eb
update
amirakbarnejad Jun 8, 2025
49552a4
update
amirakbarnejad Jun 8, 2025
b80442c
update
amirakbarnejad Jun 8, 2025
05bd40b
update
amirakbarnejad Jun 8, 2025
020c814
update
amirakbarnejad Jun 8, 2025
bf165c7
update
amirakbarnejad Jun 8, 2025
71986aa
update
amirakbarnejad Jun 8, 2025
7438bc8
update
amirakbarnejad Jun 8, 2025
9de9ebf
update
amirakbarnejad Jun 8, 2025
10c69b1
update
amirakbarnejad Jun 8, 2025
8fd0287
update
amirakbarnejad Jun 8, 2025
87b237a
update
amirakbarnejad Jun 8, 2025
2bf9d57
update
amirakbarnejad Jun 8, 2025
37a3bdb
update
amirakbarnejad Jun 8, 2025
3663d2c
update
amirakbarnejad Jun 8, 2025
131759a
update
amirakbarnejad Jun 8, 2025
8cd232a
update
amirakbarnejad Jun 8, 2025
127d416
update
amirakbarnejad Jun 8, 2025
28fbe67
update
amirakbarnejad Jun 8, 2025
156599a
update
amirakbarnejad Jun 8, 2025
16a5ab3
update
amirakbarnejad Jun 8, 2025
e84de12
update
amirakbarnejad Jun 8, 2025
9774a56
update
amirakbarnejad Jun 8, 2025
17e5015
update
amirakbarnejad Jun 8, 2025
e6ef7b4
update
amirakbarnejad Jun 8, 2025
f9e5c94
update
amirakbarnejad Jun 8, 2025
bfc4d5c
update
amirakbarnejad Jun 8, 2025
8ebbc50
update
amirakbarnejad Jun 8, 2025
eab7893
update
amirakbarnejad Jun 9, 2025
eb490f6
update
amirakbarnejad Jun 9, 2025
48b436b
update
amirakbarnejad Jun 9, 2025
5cf3ee3
update
amirakbarnejad Jun 9, 2025
837783f
update
amirakbarnejad Jun 9, 2025
df3e0fd
update
amirakbarnejad Jun 9, 2025
2863396
update
amirakbarnejad Jun 9, 2025
94c2788
update
amirakbarnejad Jun 9, 2025
0e6e54d
update
amirakbarnejad Jun 9, 2025
2eeef08
update
amirakbarnejad Jun 9, 2025
9bcbca7
update
amirakbarnejad Jun 9, 2025
a2a3136
update
amirakbarnejad Jun 9, 2025
9b97140
update
amirakbarnejad Jun 9, 2025
9ed3476
update
amirakbarnejad Jun 9, 2025
6a89f1b
update
amirakbarnejad Jun 9, 2025
0acfa1f
update
amirakbarnejad Jun 9, 2025
e06bc6d
update
amirakbarnejad Jun 9, 2025
cee1db8
update
amirakbarnejad Jun 9, 2025
5bad118
update
amirakbarnejad Jun 9, 2025
44df2bc
update
amirakbarnejad Jun 9, 2025
39a4f42
update
amirakbarnejad Jun 9, 2025
efbc77f
update
amirakbarnejad Jun 9, 2025
5feb168
update
amirakbarnejad Jun 9, 2025
d56c464
update
amirakbarnejad Jun 10, 2025
85633ed
update
amirakbarnejad Jun 10, 2025
4633e48
update
amirakbarnejad Jun 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@

"""
After running the CLI via `python mintflow_cli.py ... ` the code may crash due to, e.g., memory issue before some results are dumpued in the specified `path_output`.
As long as
- the checkpoint file is available in `path_output/CheckpointAndPredictions`
- and the old config files are available in `path_output/ConfigFilesCopiedOver` (which should be normally true)
The current script should be able to recover the outputs as usual.
The current script goes over the checkpoints in `CheckpointsAndPredictions` in the output path and creates the predictions as well as
"""

#use inflow or inflow_synth
STR_INFLOW_OR_INFLOW_SYNTH = "inflow" # in ['inflow', 'inflow_synth']
assert(
STR_INFLOW_OR_INFLOW_SYNTH == 'inflow' # in ['inflow', 'inflow_synth']
STR_INFLOW_OR_INFLOW_SYNTH == 'inflow' # in ['inflow', 'inflow_synth']
)

import os, sys
Expand Down Expand Up @@ -148,7 +145,15 @@
'--original_CLI_run_path_output',
type=str,
help='The original output path specified when running mintflow CLI.\n' +\
'In other words, the `path_output` passed to the CLI when running `python mintflow_cli.py ....`.'
'In other words, the `path_output` passed to the CLI when running `python mintflow_cli_train_model.py ....`.'
)

parser.add_argument(
'--anndata_varkey_gene_ensembleID',
type=str,
help='Optional:the column in `adata.var` that specifies gene ens IDs.\n' +\
'To be used when evaluating the disentanglement based on gene scores.\n' +\
'If set to None, this script assumes ensemble IDs are not available and instead uses gene names in `adata.var.index`'
)

parser.add_argument(
Expand All @@ -157,6 +162,13 @@
help="A string in ['True', 'False']"
)

# parser.add_argument(
# '--flag_dump_anndata_objects',
# type=str,
# help="If set to True, this scripot dumps the anndata objects with MintFlow predictions in its .obsm field. A string in ['True', 'False']\n"+\
# "If you face out of memory issues, you can set this argument to 'False'."
# )

parser.add_argument(
'--flag_verbose',
type=str,
Expand All @@ -181,6 +193,10 @@ def try_mkdir(path_in):
assert args.flag_use_cuda in ['True', 'False']
args.flag_use_cuda = (args.flag_use_cuda == 'True')

# assert isinstance(args.flag_dump_anndata_objects, str)
# assert args.flag_dump_anndata_objects in ['True', 'False']
# args.flag_dump_anndata_objects = (args.flag_dump_anndata_objects == 'True')

# find the mapping of the config file names (important when the config files have been modified and are potentially irrelevant names)
with open(
os.path.join(
Expand Down Expand Up @@ -230,43 +246,6 @@ def try_mkdir(path_in):
)


# TODO: parse other config files ===


# check if the provided anndata-s share the same gene panel and they all contain count values ===========
fname_adata0, adata0 = config_data_train[0]['file'], sc.read_h5ad(config_data_train[0]['file'])
for config_temp in config_data_train + config_data_test:
if args.flag_verbose:
print("checking if {} and {} share the same gene panel".format(
fname_adata0,
config_temp['file']
))

fname_adata_temp, adata_temp = config_temp['file'], sc.read_h5ad(config_temp['file'])
if adata_temp.var_names.tolist() != adata0.var_names.tolist():
raise Exception(
"Anndata-s {} and {} do not have the same gene panel.".format(
fname_adata0,
fname_adata_temp
)
)

if not sc._utils.check_nonnegative_integers(adata_temp.X): # grabbed from https://github.com/scverse/scanpy/blob/0cfd0224f8b0a90317b0f1a61562f62eea2c2927/src/scanpy/preprocessing/_highly_variable_genes.py#L74
raise Exception(
"Inflow requires count data, but the anndata in {} seems to have non-count values in adata.X".format(
fname_adata_temp
)
)
else:
if args.flag_verbose:
print(" also checked that the 2nd anndata has count data in adata.X")

del fname_adata_temp, adata_temp
gc.collect()

del fname_adata0, adata0, config_temp
gc.collect()

# set device ===
if args.flag_use_cuda: #config_training['flag_use_GPU']:
if torch.cuda.is_available():
Expand Down Expand Up @@ -357,7 +336,7 @@ def _convert_TrueFalse_to_bool(dict_input):
)

if args.flag_verbose:
print("\n\ncreated list_slice for training.")
print("\n\nLoaded the training list of tissue.")
for sl in list_slice.list_slice:
print("Tissue {} --> {} cells".format(
set(sl.adata.obs[sl.dict_obskey['sliceid_to_checkUnique']]),
Expand Down Expand Up @@ -436,7 +415,7 @@ def _convert_TrueFalse_to_bool(dict_input):

if args.flag_verbose:
print("\n\n\n")
print("The provided cell types are aggregated/mapped to inflow cell types as follow:")
print("The provided cell types are aggregated/mapped to inflow cell types as follows:")
pprint(list_slice.map_CT_to_inflowCT)
print("\n\n")

Expand All @@ -447,6 +426,7 @@ def _convert_TrueFalse_to_bool(dict_input):
print("\n\n")

# Note: due to the implementation in `utils_multislice.py` the assigned cell type and batchIDs do not vary in different runs.
# TODO: double-check it via the dumped general info in the output path

if args.flag_verbose:
with torch.no_grad():
Expand All @@ -462,178 +442,92 @@ def _convert_TrueFalse_to_bool(dict_input):

# TODO: assert that the 1st tissue is assigned batch ID '0' ===


# check if the inflow checkpoint is dumped
path_dump_checkpoint = os.path.join(
args.original_CLI_run_path_output,
'CheckpointAndPredictions'
)
if (not os.path.isdir(path_dump_checkpoint)) or (not os.path.isfile(os.path.join(path_dump_checkpoint, 'inflow_model.pt'))):
raise Exception(
"The file 'CheckpointAndPredictions/inflow_model.pt' was not found in the output path: \n {}".format(args.original_CLI_run_path_output)
)

module_vardist = torch.load(
# get list of epochs for which the checkpoint is dumped in the output path
list_epochs_dumped = []
for fname_checkpoint in os.listdir(
os.path.join(
path_dump_checkpoint,
'inflow_model.pt'
),
map_location=device
)['module_inflow']

print("Loaded the mintflow module on device {} from checkpiont {}".format(
device,
os.path.join(path_dump_checkpoint, 'inflow_model.pt')
))

torch.cuda.empty_cache()
gc.collect()
args.original_CLI_run_path_output,
'CheckpointAndPredictions'
)
):
flag_ischeckpoint = False
if len(fname_checkpoint) >= len('mintflow_checkpoint_epoch_'):
if fname_checkpoint.endswith(".pt"):
if fname_checkpoint[0:len('mintflow_checkpoint_epoch_')] == 'mintflow_checkpoint_epoch_':
flag_ischeckpoint = True

if flag_ischeckpoint:
list_epochs_dumped.append(
int(fname_checkpoint.split('_')[-1][0:-3])
)

list_epochs_dumped.sort()
print("Checkpoints for the following epochs are dumped in the output path:")
for u in list_epochs_dumped:
print(" {}".format(u))

# dump predictions per-tissue
with torch.no_grad():
for epoch in tqdm(list_epochs_dumped, desc="Computing metrics over different epoch checkpoints"):
# Loop over the testing tissue sections
for idx_sl, sl in enumerate(test_list_slice.list_slice):
print("\n\n")

anal_dict_varname_to_output_slice = module_vardist.eval_on_pygneighloader_dense(
dl=test_list_slice.list_slice[idx_sl].pyg_dl_test,
ten_xy_absolute=test_list_slice.list_slice[idx_sl].ten_xy_absolute,
tqdm_desc="Evaluating on tissue {}".format(idx_sl+1)
# load predictions for sl
dict_content = torch.load(
os.path.join(
args.original_CLI_run_path_output,
'CheckpointAndPredictions',
'Predictions_And_Evaluation_mintflow_checkpoint_epoch_{}'.format(epoch),
'perdictions_tissue_section_{}.pt'.format(idx_sl+1)
),
weights_only=False
)
'''
anal_dict_varname_to_output_slice is a dict with the following keys:
['output_imputer',
'muxint',
'muxspl',
'muxbar_int',
'muxbar_spl',
'mu_sin',
'mu_sout',
'mu_z',
'x_int',
'x_spl']
'''

# remove redundant fields ===
anal_dict_varname_to_output_slice.pop('output_imputer', None)
anal_dict_varname_to_output_slice.pop('x_int', None)
anal_dict_varname_to_output_slice.pop('x_spl', None)


# get pred_Xspl and pred_Xint before row normalisation on adata.X
rowcoef_correct4scppnormtotal = (np.array(sl.adata_before_scppnormalize_total.X.sum(1).tolist()) + 0.0) / (config_training['val_scppnorm_total'] + 0.0)
if len(rowcoef_correct4scppnormtotal.shape) == 1:
rowcoef_correct4scppnormtotal = np.expand_dims(rowcoef_correct4scppnormtotal, -1) # [N x 1]

assert rowcoef_correct4scppnormtotal.shape[0] == sl.adata_before_scppnormalize_total.shape[0]
assert rowcoef_correct4scppnormtotal.shape[1] == 1

anal_dict_varname_to_output_slice['muxint_before_sc_pp_normalize_total'] = anal_dict_varname_to_output_slice['muxint'] * rowcoef_correct4scppnormtotal + 0.0
anal_dict_varname_to_output_slice['muxspl_before_sc_pp_normalize_total'] = anal_dict_varname_to_output_slice['muxspl'] * rowcoef_correct4scppnormtotal + 0.0

'''
Sparsify the following vars
- muxint
- muxspl
- muxint_before_sc_pp_normalize_total
- muxspl_before_sc_pp_normalize_total
-
'''
tmp_mask = test_list_slice.list_slice[idx_sl].adata.X + 0
if issparse(tmp_mask):
tmp_mask = tmp_mask.toarray()
tmp_mask = ((tmp_mask > 0) + 0).astype(int)

for var in [
'muxint',
'muxspl',
'muxint_before_sc_pp_normalize_total',
'muxspl_before_sc_pp_normalize_total'
]:
anal_dict_varname_to_output_slice[var] = coo_matrix(anal_dict_varname_to_output_slice[var] * tmp_mask)

# TODO: modify when sparsification is added inside `eval_on_pygneighloader_dense`

'''
The sparse format may have more 0-s than tmp_mask, so the check below was removed.
if len(anal_dict_varname_to_output_slice[var].data) == tmp_mask.sum():
path_debug_output = os.path.join(
args.path_output,
'DebugInfo'
)
try_mkdir(path_debug_output)

# dump the anndata ===
test_list_slice.list_slice[idx_sl].adata.write(
os.path.join(
path_debug_output,
'adata.h5ad'
)
)

# dump `tmp_mask` ===
with open(os.path.join(path_debug_output, 'tmp_mask.pkl'), 'wb') as f:
pickle.dump(tmp_mask, f)

# dump anal_dict_varname_to_output_slice[var]
with open(os.path.join(path_debug_output, 'var_{}.pkl'.format(var)), 'wb') as f:
pickle.dump(
anal_dict_varname_to_output_slice[var],
f
)

raise Exception(
"Something went wrong when trying to sparsify {}".format(var)
)
'''

gc.collect()
Xint = dict_content['MintFlow_Xint (before_sc_pp_normalize_total)']
Xmic = dict_content['MintFlow_Xmic (before_sc_pp_normalize_total)']

assert Xint.shape[0] == sl.adata.shape[0]
assert Xmic.shape[0] == sl.adata.shape[0]

# dump the predictions
torch.save(
anal_dict_varname_to_output_slice,
os.path.join(path_dump_checkpoint, 'predictions_slice_{}.pt'.format(idx_sl + 1)),
pickle_protocol=4
)

# loop over MCP collections
dict_fnamecoll_to_df_eval = {}
for fname_collection in os.listdir("./Files2Use_CLI/Evaluation/MCC_Predictability_GeneScores/"):
if fname_collection.endswith('.pkl'):
with open(
os.path.join(
"./Files2Use_CLI/Evaluation/MCC_Predictability_GeneScores/",
fname_collection
),
'rb'
) as f:
collection = pickle.load(f)

dict_fnamecoll_to_df_eval[fname_collection] = collection.score_Xmic_Xint(
list_ens_ID=None if(args.anndata_varkey_gene_ensembleID.lower() == 'none') else sl.adata.var[args.anndata_varkey_gene_ensembleID],
list_gene_name=sl.adata.var.index.tolist(),
Xint_before_scppnormalizetotal=Xint.copy(),
Xmic_before_scppnormalizetotal=Xmic.copy()
)

del anal_dict_varname_to_output_slice
gc.collect()
# dump the result
with open(
os.path.join(
args.original_CLI_run_path_output,
'CheckpointAndPredictions',
'Predictions_And_Evaluation_mintflow_checkpoint_epoch_{}'.format(epoch),
'evaluationresult_MCC_predictabiliyt_gene_scores_tissuesection_{}.pkl'.format(idx_sl + 1)
),
'wb'
) as f:
pickle.dump(
dict_fnamecoll_to_df_eval,
f
)


# dump the tissue samples ===
path_dump_training_listtissue = os.path.join(
args.original_CLI_run_path_output,
"TrainingListTissue"
)
path_dump_testing_listtissue = os.path.join(
args.original_CLI_run_path_output,
"TestingListTissue"
)
try_mkdir(path_dump_training_listtissue)
try_mkdir(path_dump_testing_listtissue)

for idx_sl, sl in enumerate(list_slice.list_slice):
# with open(os.path.join(path_dump_training_listtissue, 'tissue_tr_{}.pkl'.format(idx_sl+1)), 'wb') as f:
# pickle.dump(sl, f)

torch.save(
sl,
os.path.join(path_dump_training_listtissue, 'tissue_tr_{}.pt'.format(idx_sl + 1)),
pickle_protocol=4
)

for idx_sl, sl in enumerate(test_list_slice.list_slice):

# with open(os.path.join(path_dump_testing_listtissue, 'tissue_test_{}.pkl'.format(idx_sl+1)), 'wb') as f:
# pickle.dump(sl, f)

torch.save(
sl,
os.path.join(path_dump_testing_listtissue, 'tissue_test_{}.pt'.format(idx_sl + 1)),
pickle_protocol=4
)


print("Finished running the script successfully!")
Expand Down
Loading
Loading