diff --git a/Snakefile b/Snakefile index 6e0b8ae8..b6f52bbd 100644 --- a/Snakefile +++ b/Snakefile @@ -108,8 +108,11 @@ def make_final_input(wildcards): final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-jaccard-heatmap.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms)) if _config.config.analysis_include_evaluation: - final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params)) - + final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs)) + final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs)) + if _config.config.analysis_include_evaluation_aggregate_algo: + final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes-per-algorithm.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs)) + final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes-per-algorithm.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs)) if len(final_input) == 0: # No analysis added yet, so add reconstruction output files if they exist. # (if analysis is specified, these should be implicitly run). 
@@ -392,23 +395,56 @@ def get_gold_standard_pickle_file(wildcards): parts = wildcards.dataset_gold_standard_pairs.split('-') gs = parts[1] return SEP.join([out_dir, f'{gs}-merged.pickle']) - + # Returns the dataset corresponding to the gold standard pair def get_dataset_label(wildcards): parts = wildcards.dataset_gold_standard_pairs.split('-') dataset = parts[0] return dataset -# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard -rule evaluation: +# Return the dataset pickle file for a specific dataset +def get_dataset_pickle_file(wildcards): + dataset_label = get_dataset_label(wildcards) + return SEP.join([out_dir, f'{dataset_label}-merged.pickle']) + +# Returns ensemble file for each dataset +def collect_ensemble_per_dataset(wildcards): + dataset_label = get_dataset_label(wildcards) + return expand('{out_dir}{sep}{dataset}-ml{sep}ensemble-pathway.txt', out_dir=out_dir, sep=SEP, dataset=dataset_label) + +# Run precision-recall curves for each ensemble pathway within a dataset evaluated against its corresponding gold standard +rule evaluation_ensemble_pr_curve: input: gold_standard_file = get_gold_standard_pickle_file, - pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label), - output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"]) + dataset_file = get_dataset_pickle_file, + ensemble_file = collect_ensemble_per_dataset + output: + pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.png']), + pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.txt']), + run: + node_table = Evaluation.from_file(input.gold_standard_file).node_table + node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input.dataset_file) + 
Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, output.pr_curve_png, output.pr_curve_file) + +# Returns list of algorithm specific ensemble files per dataset +def collect_ensemble_per_algo_per_dataset(wildcards): + dataset_label = get_dataset_label(wildcards) + return expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt', out_dir=out_dir, sep=SEP, dataset=dataset_label, algorithm=algorithms) + +# Run precision-recall curves for each algorithm's ensemble pathway within a dataset evaluated against its corresponding gold standard +rule evaluation_per_algo_ensemble_pr_curve: + input: + gold_standard_file = get_gold_standard_pickle_file, + dataset_file = get_dataset_pickle_file, + ensemble_files = collect_ensemble_per_algo_per_dataset + output: + pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.png']), + pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.txt']), run: node_table = Evaluation.from_file(input.gold_standard_file).node_table - Evaluation.precision(input.pathways, node_table, output.eval_file) + node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input.dataset_file) + Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, output.pr_curve_png, output.pr_curve_file) # Remove the output directory rule clean: - shell: f'rm -rf {out_dir}' \ No newline at end of file + shell: f'rm -rf {out_dir}' diff --git a/config/egfr.yaml b/config/egfr.yaml index b8c5138b..7604d9e4 100644 --- a/config/egfr.yaml +++ b/config/egfr.yaml @@ -13,6 +13,7 @@ algorithms: k: - 10 - 20 + - 70 - name: omicsintegrator1 params: include: true @@ -55,6 +56,16 @@ algorithms: - 3 rand_restarts: - 10 + run2: + local_search: + - 'No' + max_path_length: + - 2 + rand_restarts: + - 10 + - name: allpairs + params: + include: true - name: domino params: include: 
true @@ -63,6 +74,24 @@ algorithms: - 0.3 module_threshold: - 0.05 + - name: mincostflow + params: + include: true + run1: + capacity: + - 15 + flow: + - 80 + run2: + capacity: + - 1 + flow: + - 6 + run3: + capacity: + - 5 + flow: + - 60 datasets: - data_dir: input edge_files: @@ -71,6 +100,13 @@ datasets: node_files: - tps-egfr-prizes.txt other_files: [] +gold_standards: + - label: gs_egfr + node_files: + - gs-egfr.txt + data_dir: input + dataset_labels: + - tps_egfr reconstruction_settings: locations: reconstruction_dir: output/egfr @@ -83,6 +119,9 @@ analysis: summary: include: true ml: - include: false + include: true + aggregate_per_algorithm: true + labels: true evaluation: - include: false + include: true + aggregate_per_algorithm: true diff --git a/spras/evaluation.py b/spras/evaluation.py index 21d0372e..6063ac24 100644 --- a/spras/evaluation.py +++ b/spras/evaluation.py @@ -1,10 +1,17 @@ import os import pickle as pkl from pathlib import Path -from typing import Dict, Iterable +from typing import Dict +import matplotlib.pyplot as plt +import numpy as np import pandas as pd -from sklearn.metrics import precision_score +from sklearn.metrics import ( + average_precision_score, + precision_recall_curve, +) + +from spras.analysis.ml import create_palette class Evaluation: @@ -71,29 +78,164 @@ def load_files_from_dict(self, gold_standard_dict: Dict): # TODO: later iteration - chose between node and edge file, or allow both @staticmethod - def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str): + def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list, dataset_file: str) -> dict: """ - Takes in file paths for a specific dataset and an associated gold standard node table. 
- Calculates precision for each pathway file - Returns output back to output_file - @param file_paths: file paths of pathway reconstruction algorithm outputs - @param node_table: the gold standard nodes - @param output_file: the filename to save the precision of each pathway + Generates a dictionary of node ensembles using edge frequency data from a list of ensemble files. + A list of ensemble files can contain an aggregated ensemble or algorithm-specific ensembles per dataset. + + 1. Prepare a set of default nodes (from the interactome and gold standard) with frequency 0, + ensuring all nodes are represented in the ensemble. + - Answers "Did the algorithm(s) select the correct nodes from the entire network?" + - It measures whether the algorithm(s) can distinguish relevant gold standard nodes + from the full 'universe' of possible nodes present in the input network. + 2. For each edge ensemble file: + a. Read edges and their frequencies. + b. Convert edge frequencies into node-level frequencies for Node1 and Node2. + c. Merge with the default node set and group by node, taking the maximum frequency per node. + 3. Store the resulting node-frequency ensemble under the corresponding ensemble source (label). + + If the interactome or gold standard table is empty, a ValueError is raised.
+ + @param node_table: DataFrame of gold standard nodes (column: NODEID) + @param ensemble_files: list of file paths containing edge ensemble outputs + @param dataset_file: path to the dataset file used to load the interactome + @return: dictionary mapping each ensemble source to its node ensemble DataFrame """ - y_true = set(node_table['NODEID']) - results = [] - for file in file_paths: - df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"]) - y_pred = set(df['Node1']).union(set(df['Node2'])) - all_nodes = y_true.union(y_pred) - y_true_binary = [1 if node in y_true else 0 for node in all_nodes] - y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes] + node_ensembles_dict = dict() + + dataset = Evaluation.from_file(dataset_file) + interactome = dataset.get_interactome() + + if interactome.empty: + raise ValueError( + f"Cannot compute PR curve or generate node ensemble. Input network for dataset '{Path(dataset_file).name.split('-')[0]}' is empty." + ) + if node_table.empty: + raise ValueError( + f"Cannot compute PR curve or generate node ensemble. Gold standard associated with dataset '{Path(dataset_file).name.split('-')[0]}' is empty."
+ ) + + # set the initial default frequencies to 0 for all interactome and gold standard nodes + node1_interactome = interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'}) + node1_interactome['Frequency'] = 0.0 + node2_interactome = interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'}) + node2_interactome['Frequency'] = 0.0 + gs_nodes = node_table[[Evaluation.NODE_ID]].rename(columns={Evaluation.NODE_ID: 'Node'}) + gs_nodes['Frequency'] = 0.0 + + # combine gold standard and network nodes + other_nodes = pd.concat([node1_interactome, node2_interactome, gs_nodes]) + + for ensemble_file in ensemble_files: + label = Path(ensemble_file).name.split('-')[0] + ensemble_df = pd.read_table(ensemble_file, sep='\t', header=0) - # default to 0.0 if there is a divide by 0 error - precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0) + if not ensemble_df.empty: + node1 = ensemble_df[['Node1', 'Frequency']].rename(columns={'Node1': 'Node'}) + node2 = ensemble_df[['Node2', 'Frequency']].rename(columns={'Node2': 'Node'}) + all_nodes = pd.concat([node1, node2, other_nodes]) + node_ensemble = all_nodes.groupby(['Node']).max().reset_index() + else: + node_ensemble = other_nodes.groupby(['Node']).max().reset_index() - results.append({"Pathway": file, "Precision": precision}) + node_ensembles_dict[label] = node_ensemble - precision_df = pd.DataFrame(results) - precision_df.to_csv(output_file, sep="\t", index=False) + return node_ensembles_dict + + @staticmethod + def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str, + output_file: str): + """ + Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard. + + Takes in a dictionary containing either algorithm-specific node ensembles or an aggregated node ensemble + for a given dataset, along with the corresponding gold standard node table. 
Computes PR curves for + each ensemble and plots all curves on a single figure. + + @param node_ensembles: dict of the pre-computed node_ensemble(s) + @param node_table: gold standard nodes + @param output_png: filename to save the precision and recall curves as a .png image + @param output_file: filename to save the precision, recall, threshold values, average precision, and baseline precision + """ + gold_standard_nodes = set(node_table[Evaluation.NODE_ID]) + + # make color palette per ensemble label name + label_names = list(node_ensembles) + color_palette = create_palette(label_names) + + plt.figure(figsize=(10, 7)) + + prc_dfs = [] + metric_dfs = [] + + baseline = None + + for label, node_ensemble in node_ensembles.items(): + if not node_ensemble.empty: + y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']] + y_scores = node_ensemble['Frequency'].tolist() + precision, recall, thresholds = precision_recall_curve(y_true, y_scores) + # avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold + avg_precision = average_precision_score(y_true, y_scores) + + # only set baseline precision once + # the same for every algorithm per dataset/goldstandard pair + if baseline is None: + baseline = np.sum(y_true) / len(y_true) + plt.axhline(y=baseline, color="black", linestyle='--', label=f'Baseline: {baseline:.4f}') + + plt.plot(recall, precision, color=color_palette[label], marker='o', + label=f'{label.capitalize()} (AP: {avg_precision:.4f})') + + # Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds + # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve:~:text=Returns%3A-,precision,predictions%20with%20score%20%3E%3D%20thresholds%5Bi%5D%20and%20the%20last%20element%20is%200.,-thresholds + prc_data = { + 'Threshold': thresholds, + 'Precision': 
precision[:-1], + 'Recall': recall[:-1], + } + + metric_data = { + 'Average_Precision': [avg_precision], + } + + ensemble_source = label.capitalize() if label != 'ensemble' else "Aggregated" + prc_data = {'Ensemble_Source': [ensemble_source] * len(thresholds), **prc_data} + metric_data = {'Ensemble_Source': [ensemble_source], **metric_data} + + prc_df = pd.DataFrame.from_dict(prc_data) + prc_dfs.append(prc_df) + metric_df = pd.DataFrame.from_dict(metric_data) + metric_dfs.append(metric_df) + + else: + raise ValueError( + "Cannot compute PR curve: the ensemble network is empty. " + "This should not happen unless the input network for pathway reconstruction is empty." + ) + + + if 'ensemble' not in label_names: + plt.title('Precision-Recall Curve Per Algorithm Specific Ensemble') + else: + plt.title('Precision-Recall Curve for Aggregated Ensemble Across Algorithms') + + plt.xlim(0, 1) + plt.ylim(0, 1) + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.legend(loc='lower left', bbox_to_anchor=(1, 0.5)) + plt.grid(True) + plt.savefig(output_png, bbox_inches='tight') + plt.close() + + combined_prc_df = pd.concat(prc_dfs, ignore_index=True) + combined_metrics_df = pd.concat(metric_dfs, ignore_index=True) + combined_metrics_df["Baseline"] = baseline + + # merge dfs and NaN out metric values except for first row of each Ensemble_Source + complete_df = combined_prc_df.merge(combined_metrics_df, on="Ensemble_Source", how="left") + not_first_rows = complete_df.duplicated(subset="Ensemble_Source", keep='first') + complete_df.loc[not_first_rows, ["Average_Precision", "Baseline"]] = None + complete_df.to_csv(output_file, index=False, sep="\t") diff --git a/test/evaluate/expected/expected-empty-node-ensemble.csv b/test/evaluate/expected/expected-empty-node-ensemble.csv new file mode 100644 index 00000000..15581e4e --- /dev/null +++ b/test/evaluate/expected/expected-empty-node-ensemble.csv @@ -0,0 +1,27 @@ +Node Frequency +A 0.0 +B 0.0 +C 0.0 +D 0.0 +E 0.0 +F 0.0 +G 0.0 +H 0.0 +I
0.0 +J 0.0 +K 0.0 +L 0.0 +M 0.0 +N 0.0 +O 0.0 +P 0.0 +Q 0.0 +R 0.0 +S 0.0 +T 0.0 +U 0.0 +V 0.0 +W 0.0 +X 0.0 +Y 0.0 +Z 0.0 diff --git a/test/evaluate/expected/expected-node-ensemble.csv b/test/evaluate/expected/expected-node-ensemble.csv new file mode 100644 index 00000000..c431a0c0 --- /dev/null +++ b/test/evaluate/expected/expected-node-ensemble.csv @@ -0,0 +1,27 @@ +Node Frequency +A 0.5 +B 0.5 +C 0.0 +D 0.0 +E 0.0 +F 0.0 +G 0.0 +H 0.0 +I 0.0 +J 0.0 +K 0.0 +L 0.0 +M 0.0 +N 0.0 +O 0.0 +P 0.0 +Q 0.01 +R 0.01 +S 0.0 +T 0.0 +U 0.0 +V 0.0 +W 0.0 +X 0.66 +Y 0.0 +Z 0.66 diff --git a/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt b/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt new file mode 100644 index 00000000..c9f6561c --- /dev/null +++ b/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt @@ -0,0 +1,2 @@ +Ensemble_Source Threshold Precision Recall Average_Precision Baseline +Aggregated 0.0 0.15384615384615385 1.0 0.15384615384615385 0.15384615384615385 diff --git a/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt b/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt new file mode 100644 index 00000000..b0e50594 --- /dev/null +++ b/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt @@ -0,0 +1,5 @@ +Ensemble_Source Threshold Precision Recall Average_Precision Baseline +Aggregated 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385 +Aggregated 0.01 0.6666666666666666 1.0 +Aggregated 0.5 0.75 0.75 +Aggregated 0.66 0.5 0.25 diff --git a/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt b/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt new file mode 100644 index 00000000..630a89ce --- /dev/null +++ b/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt @@ -0,0 +1,10 @@ +Ensemble_Source Threshold Precision Recall Average_Precision Baseline +Ensemble1 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385 
+Ensemble1 0.01 0.6666666666666666 1.0 +Ensemble1 0.5 0.75 0.75 +Ensemble1 0.66 0.5 0.25 +Ensemble2 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385 +Ensemble2 0.01 0.6666666666666666 1.0 +Ensemble2 0.5 0.75 0.75 +Ensemble2 0.66 0.5 0.25 +Ensemble3 0.0 0.15384615384615385 1.0 0.15384615384615385 0.15384615384615385 diff --git a/test/evaluate/input/data.pickle b/test/evaluate/input/data.pickle new file mode 100644 index 00000000..ad2fee27 Binary files /dev/null and b/test/evaluate/input/data.pickle differ diff --git a/test/evaluate/input/empty-ensemble-network.tsv b/test/evaluate/input/empty-ensemble-network.tsv new file mode 100644 index 00000000..754d8377 --- /dev/null +++ b/test/evaluate/input/empty-ensemble-network.tsv @@ -0,0 +1 @@ +Node1 Node2 Frequency Direction diff --git a/test/evaluate/input/ensemble-network.tsv b/test/evaluate/input/ensemble-network.tsv new file mode 100644 index 00000000..e53c63ec --- /dev/null +++ b/test/evaluate/input/ensemble-network.tsv @@ -0,0 +1,5 @@ +Node1 Node2 Frequency Direction +A B 0.5 U +B A 0.5 U +Q R 0.01 U +Z X 0.66 U \ No newline at end of file diff --git a/test/evaluate/input/gs_node_table.csv b/test/evaluate/input/gs_node_table.csv new file mode 100644 index 00000000..6069522f --- /dev/null +++ b/test/evaluate/input/gs_node_table.csv @@ -0,0 +1,5 @@ +NODEID +A +B +Q +Z \ No newline at end of file diff --git a/test/evaluate/input/node-ensemble-empty.csv b/test/evaluate/input/node-ensemble-empty.csv new file mode 100644 index 00000000..15581e4e --- /dev/null +++ b/test/evaluate/input/node-ensemble-empty.csv @@ -0,0 +1,27 @@ +Node Frequency +A 0.0 +B 0.0 +C 0.0 +D 0.0 +E 0.0 +F 0.0 +G 0.0 +H 0.0 +I 0.0 +J 0.0 +K 0.0 +L 0.0 +M 0.0 +N 0.0 +O 0.0 +P 0.0 +Q 0.0 +R 0.0 +S 0.0 +T 0.0 +U 0.0 +V 0.0 +W 0.0 +X 0.0 +Y 0.0 +Z 0.0 diff --git a/test/evaluate/input/node-ensemble.csv b/test/evaluate/input/node-ensemble.csv new file mode 100644 index 00000000..c431a0c0 --- /dev/null +++ 
b/test/evaluate/input/node-ensemble.csv @@ -0,0 +1,27 @@ +Node Frequency +A 0.5 +B 0.5 +C 0.0 +D 0.0 +E 0.0 +F 0.0 +G 0.0 +H 0.0 +I 0.0 +J 0.0 +K 0.0 +L 0.0 +M 0.0 +N 0.0 +O 0.0 +P 0.0 +Q 0.01 +R 0.01 +S 0.0 +T 0.0 +U 0.0 +V 0.0 +W 0.0 +X 0.66 +Y 0.0 +Z 0.66 diff --git a/test/evaluate/test_evaluate.py b/test/evaluate/test_evaluate.py new file mode 100644 index 00000000..72962f48 --- /dev/null +++ b/test/evaluate/test_evaluate.py @@ -0,0 +1,85 @@ +import filecmp +from pathlib import Path + +import pandas as pd + +from spras.evaluation import Evaluation + +INPUT_DIR = 'test/evaluate/input/' +OUT_DIR = 'test/evaluate/output/' +EXPECT_DIR = 'test/evaluate/expected/' +GS_NODE_TABLE = pd.read_csv(INPUT_DIR + "gs_node_table.csv", header=0) + + +class TestEvaluate: + @classmethod + def setup_class(cls): + """ + Create the expected output directory + """ + Path(OUT_DIR).mkdir(parents=True, exist_ok=True) + + def test_node_ensemble(self): + out_path_file = Path(OUT_DIR + 'node-ensemble.csv') + out_path_file.unlink(missing_ok=True) + ensemble_network = [INPUT_DIR + 'ensemble-network.tsv'] + input_network = INPUT_DIR + 'data.pickle' + node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_network) + node_ensemble_dict['ensemble'].to_csv(out_path_file, sep="\t", index=False) + assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False) + + def test_empty_node_ensemble(self): + out_path_file = Path(OUT_DIR + 'empty-node-ensemble.csv') + out_path_file.unlink(missing_ok=True) + empty_ensemble_network = [INPUT_DIR + 'empty-ensemble-network.tsv'] + input_network = INPUT_DIR + 'data.pickle' + node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network, input_network) + node_ensemble_dict['empty'].to_csv(out_path_file, sep="\t", index=False) + assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-empty-node-ensemble.csv', shallow=False) + + def 
test_multiple_node_ensemble(self): + out_path_file = Path(OUT_DIR + 'node-ensemble.csv') + out_path_file.unlink(missing_ok=True) + out_path_empty_file = Path(OUT_DIR + 'empty-node-ensemble.csv') + out_path_empty_file.unlink(missing_ok=True) + ensemble_networks = [INPUT_DIR + 'ensemble-network.tsv', INPUT_DIR + 'empty-ensemble-network.tsv'] + input_network = INPUT_DIR + 'data.pickle' + node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_network) + node_ensemble_dict['ensemble'].to_csv(out_path_file, sep="\t", index=False) + assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False) + node_ensemble_dict['empty'].to_csv(out_path_empty_file, sep="\t", index=False) + assert filecmp.cmp(out_path_empty_file, EXPECT_DIR + 'expected-empty-node-ensemble.csv', shallow=False) + + def test_precision_recall_curve_ensemble_nodes(self): + out_path_png = Path(OUT_DIR + "pr-curve-ensemble-nodes.png") + out_path_png.unlink(missing_ok=True) + out_path_file = Path(OUT_DIR + "pr-curve-ensemble-nodes.txt") + out_path_file.unlink(missing_ok=True) + ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep="\t", header=0) + node_ensembles_dict = {'ensemble': ensemble_file} + Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, str(out_path_png), str(out_path_file)) + assert out_path_png.exists() + assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes.txt', shallow=False) + + def test_precision_recall_curve_ensemble_nodes_empty(self): + out_path_png = Path(OUT_DIR+"pr-curve-ensemble-nodes-empty.png") + out_path_png.unlink(missing_ok=True) + out_path_file = Path(OUT_DIR+"pr-curve-ensemble-nodes-empty.txt") + out_path_file.unlink(missing_ok=True) + empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep="\t", header=0) + node_ensembles_dict = {'ensemble': empty_ensemble_file} + 
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, str(out_path_png), str(out_path_file)) + assert out_path_png.exists() + assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes-empty.txt', shallow=False) + + def test_precision_recall_curve_multiple_ensemble_nodes(self): + out_path_png = Path(OUT_DIR+"pr-curve-multiple-ensemble-nodes.png") + out_path_png.unlink(missing_ok=True) + out_path_file = Path(OUT_DIR+"pr-curve-multiple-ensemble-nodes.txt") + out_path_file.unlink(missing_ok=True) + ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep="\t", header=0) + empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep="\t", header=0) + node_ensembles_dict = {'ensemble1': ensemble_file, 'ensemble2': ensemble_file, 'ensemble3': empty_ensemble_file} + Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, str(out_path_png), str(out_path_file)) + assert out_path_png.exists() + assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-multiple-ensemble-nodes.txt', shallow=False)