diff --git a/Snakefile b/Snakefile index 1825b6d1..9ceb0122 100644 --- a/Snakefile +++ b/Snakefile @@ -105,8 +105,9 @@ def make_final_input(wildcards): final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms)) if _config.config.analysis_include_evaluation: - final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params)) - + final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}precision-recall-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params)) + if _config.config.analysis_include_evaluation_aggregate_algo: + final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}{algorithm}-precision-recall-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm=algorithms)) if len(final_input) == 0: # No analysis added yet, so add reconstruction output files if they exist. # (if analysis is specified, these should be implicitly run). 
@@ -372,16 +373,36 @@ def get_dataset_label(wildcards): dataset = parts[0] return dataset -# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard + +# Run evaluation for each ensemble.txt for a dataset against its paired gold standard rule evaluation: input: gold_standard_file = get_gold_standard_pickle_file, - pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label), - output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"]) + ensemble_file=lambda wildcards: f"{out_dir}{SEP}{get_dataset_label(wildcards)}-ml{SEP}ensemble-pathway.txt", + output: + pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'precision-recall-curve-ensemble-nodes.png']), + run: + node_table = Evaluation.from_file(input.gold_standard_file).node_table + node_ensemble = Evaluation.edge_frequency_node_ensemble(input.ensemble_file) + Evaluation.precision_recall_curve_node_ensemble(node_ensemble, node_table, output.pr_curve_png) + +# Returns the ensemble file path for a specific algorithm and dataset +def collect_ensemble_per_algo_per_dataset(wildcards): + dataset_label = get_dataset_label(wildcards) + return f"{out_dir}{SEP}{dataset_label}-ml{SEP}{wildcards.algorithm}-ensemble-pathway.txt" + +# Run evaluation per algorithm for each ensemble.txt for a dataset against its paired gold standard +rule evaluation_per_algo_ensemble_pr_curve: + input: + gold_standard_file = get_gold_standard_pickle_file, + ensemble_file = collect_ensemble_per_algo_per_dataset, + output: + pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', '{algorithm}-precision-recall-curve-ensemble-nodes.png']), run: node_table = Evaluation.from_file(input.gold_standard_file).node_table - Evaluation.precision(input.pathways, node_table, output.eval_file) + node_ensemble = 
Evaluation.edge_frequency_node_ensemble(input.ensemble_file) + Evaluation.precision_recall_curve_node_ensemble(node_ensemble, node_table, output.pr_curve_png) # Remove the output directory rule clean: - shell: f'rm -rf {out_dir}' \ No newline at end of file + shell: f'rm -rf {out_dir}' diff --git a/spras/evaluation.py b/spras/evaluation.py index 5d00e7d4..cec11abe 100644 --- a/spras/evaluation.py +++ b/spras/evaluation.py @@ -3,8 +3,15 @@ from pathlib import Path from typing import Dict, Iterable +import matplotlib.pyplot as plt +import numpy as np import pandas as pd -from sklearn.metrics import precision_score +from sklearn.metrics import ( + average_precision_score, + precision_recall_curve, + precision_score, + recall_score, +) class Evaluation: @@ -71,30 +78,78 @@ def load_files_from_dict(self, gold_standard_dict: Dict): # TODO: later iteration - chose between node and edge file, or allow both - @staticmethod - def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str): + def select_max_freq_and_node(row: pd.Series): + """ + Selects the node and frequency with the highest frequency value from two potential nodes in a row. + Handles cases where one of the nodes or frequencies may be missing and returns a (node, max_freq) tuple for the node associated with the maximum frequency. + """ + max_freq = 0 + node = "" + if pd.isna(row['Node2']) and pd.isna(row['Freq2']): + max_freq = row['Freq1'] + node = row['Node1'] + elif pd.isna(row['Node1']) and pd.isna(row['Freq1']): + max_freq = row['Freq2'] + node = row['Node2'] + else: + max_freq = max(row['Freq1'], row['Freq2']) + node = row['Node1'] + return node, max_freq + + def edge_frequency_node_ensemble(ensemble_file: str): + """ + Processes an ensemble of edge frequencies to identify the highest frequency associated with each node. + Reads ensemble_file, separates frequencies by node, and then calculates the maximum frequency for each node. 
+ Returns a DataFrame with Node and max_freq columns sorted by descending max_freq, or an empty DataFrame with those columns if ensemble_file has no data rows. + @param ensemble_file: path to the tab-separated edge ensemble file with Node1, Node2, Frequency, and Direction columns + """ + ensemble_df = pd.read_table(ensemble_file, sep="\t", header=0) + + if not ensemble_df.empty: + node1_freq = ensemble_df.drop(columns = ['Node2', 'Direction']) + node2_freq = ensemble_df.drop(columns = ['Node1', 'Direction']) + + max_node1_freq = node1_freq.groupby(['Node1']).max().reset_index() + max_node1_freq.rename(columns = {'Frequency': 'Freq1'}, inplace = True) + max_node2_freq = node2_freq.groupby(['Node2']).max().reset_index() + max_node2_freq.rename(columns = {'Frequency': 'Freq2'}, inplace = True) + + node_ensemble = max_node1_freq.merge(max_node2_freq, left_on='Node1', right_on='Node2', how='outer') + node_ensemble[['Node', 'max_freq']] = node_ensemble.apply(Evaluation.select_max_freq_and_node, axis=1, result_type='expand') + node_ensemble.drop(columns = ['Node1', 'Node2', 'Freq1', 'Freq2'], inplace = True) + node_ensemble.sort_values('max_freq', ascending= False, inplace = True) + return node_ensemble + else: + return pd.DataFrame(columns = ['Node', 'max_freq']) + + def precision_recall_curve_node_ensemble(node_ensemble:pd.DataFrame, node_table:pd.DataFrame, output_png: str): """ - Takes in file paths for a specific dataset and an associated gold standard node table. - Calculates precision for each pathway file - Returns output back to output_file - @param file_paths: file paths of pathway reconstruction algorithm outputs + Takes in a node ensemble for a specific dataset or a specific algorithm in a dataset, and an associated gold standard node table. 
+ Plots a precision and recall curve for the node ensemble against its associated gold standard node table. + Saves the figure to output_png; an empty plot titled 'Empty Ensemble File' is saved if node_ensemble is empty. + @param node_ensemble: the pre-computed node ensemble DataFrame with Node and max_freq columns @param node_table: the gold standard nodes - @param output_file: the filename to save the precision of each pathway + @param output_png: the filename to save the precision and recall curve """ - y_true = set(node_table['NODEID']) - results = [] - - for file in file_paths: - df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"]) - y_pred = set(df['Node1']).union(set(df['Node2'])) - all_nodes = y_true.union(y_pred) - y_true_binary = [1 if node in y_true else 0 for node in all_nodes] - y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes] - - # default to 0.0 if there is a divide by 0 error - precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0) - - results.append({"Pathway": file, "Precision": precision}) - - precision_df = pd.DataFrame(results) - precision_df.to_csv(output_file, sep="\t", index=False) + gold_standard_nodes = set(node_table['NODEID']) + + if not node_ensemble.empty: + y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']] + y_scores = node_ensemble['max_freq'].tolist() + precision, recall, thresholds = precision_recall_curve(y_true, y_scores) + auc_precision_recall = average_precision_score(y_true, y_scores) + + plt.figure() + plt.plot(recall, precision, marker='o', label='Precision-Recall curve') + plt.axhline(y=auc_precision_recall, color='r', linestyle='--', label=f'Avg Precision: {auc_precision_recall:.4f}') + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.title('Precision-Recall Curve') + plt.legend() + plt.grid(True) + plt.savefig(output_png) + else: + plt.figure() + plt.plot([], []) + plt.title("Empty Ensemble File") + plt.savefig(output_png) diff --git a/test/evaluate/expected/expected-node-ensemble.csv b/test/evaluate/expected/expected-node-ensemble.csv new 
file mode 100644 index 00000000..ba467d55 --- /dev/null +++ b/test/evaluate/expected/expected-node-ensemble.csv @@ -0,0 +1,13 @@ +Node max_freq +C 0.75 +E 0.75 +D 0.75 +F 0.75 +A 0.5 +B 0.5 +L 0.5 +M 0.5 +O 0.25 +P 0.25 +N 0.25 +Q 0.25 diff --git a/test/evaluate/input/ensemble-network.tsv b/test/evaluate/input/ensemble-network.tsv new file mode 100644 index 00000000..293ec3f5 --- /dev/null +++ b/test/evaluate/input/ensemble-network.tsv @@ -0,0 +1,10 @@ +Node1 Node2 Frequency Direction +A B 0.5 U +C D 0.75 U +E F 0.75 U +L M 0.5 U +M N 0.25 U +O P 0.25 U +P Q 0.25 U +A B 0.25 D +B A 0.25 D \ No newline at end of file diff --git a/test/evaluate/input/node-ensemble-empty.csv b/test/evaluate/input/node-ensemble-empty.csv new file mode 100644 index 00000000..e488f56a --- /dev/null +++ b/test/evaluate/input/node-ensemble-empty.csv @@ -0,0 +1,2 @@ +Node max_freq + diff --git a/test/evaluate/input/node-ensemble.csv b/test/evaluate/input/node-ensemble.csv new file mode 100644 index 00000000..ba467d55 --- /dev/null +++ b/test/evaluate/input/node-ensemble.csv @@ -0,0 +1,13 @@ +Node max_freq +C 0.75 +E 0.75 +D 0.75 +F 0.75 +A 0.5 +B 0.5 +L 0.5 +M 0.5 +O 0.25 +P 0.25 +N 0.25 +Q 0.25 diff --git a/test/evaluate/input/node_table.csv b/test/evaluate/input/node_table.csv new file mode 100644 index 00000000..690ad094 --- /dev/null +++ b/test/evaluate/input/node_table.csv @@ -0,0 +1,4 @@ +NODEID +A +B +Q diff --git a/test/evaluate/test_evaluate.py b/test/evaluate/test_evaluate.py new file mode 100644 index 00000000..c2496399 --- /dev/null +++ b/test/evaluate/test_evaluate.py @@ -0,0 +1,40 @@ +import filecmp +from pathlib import Path + +import pandas as pd +import pytest + +import spras.analysis.ml as ml +from spras.evaluation import Evaluation + +INPUT_DIR = 'test/evaluate/input/' +OUT_DIR = 'test/evaluate/output/' +EXPECT_DIR = 'test/evaluate/expected/' +NODE_TABLE = pd.read_csv(INPUT_DIR + "node_table.csv", header=0) +class TestEvaluate: + @classmethod + def setup_class(cls): + """ 
+ Create the test output directory (OUT_DIR) if it does not already exist + """ + Path(OUT_DIR).mkdir(parents=True, exist_ok=True) + + def test_node_ensemble(self): + ensemble_file = INPUT_DIR + 'ensemble-network.tsv' + edge_freq = Evaluation.edge_frequency_node_ensemble(ensemble_file) + edge_freq.to_csv(OUT_DIR + 'node-ensemble.csv', sep="\t", index=False) + assert filecmp.cmp(OUT_DIR + 'node-ensemble.csv', EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False) + + def test_precision_recal_curve_ensemble_nodes(self): + out_path = Path(OUT_DIR+"test-precision-recall-curve-ensemble-nodes.png") + out_path.unlink(missing_ok=True) + ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep="\t", header=0) + Evaluation.precision_recall_curve_node_ensemble(ensemble_file, NODE_TABLE, out_path) + assert out_path.exists() + + def test_precision_recal_curve_ensemble_nodes_empty(self): + out_path = Path(OUT_DIR+"test-precision-recall-curve-ensemble-nodes-empty.png") + out_path.unlink(missing_ok=True) + ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep="\t", header=0) + Evaluation.precision_recall_curve_node_ensemble(ensemble_file, NODE_TABLE, out_path) + assert out_path.exists()