-
Notifications
You must be signed in to change notification settings - Fork 20
Param tuning: ensembling (version 2 but all the same code as version 1) #212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
63f1f80
05d4979
4e60656
7c94c2f
156753c
e07d0a1
2b7424b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -105,8 +105,9 @@ def make_final_input(wildcards): | |||||||
final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms)) | ||||||||
|
||||||||
if _config.config.analysis_include_evaluation: | ||||||||
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params)) | ||||||||
|
||||||||
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}precision-recall-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params)) | ||||||||
if _config.config.analysis_include_evaluation_aggregate_algo: | ||||||||
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}{algorithm}-precision-recall-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm=algorithms)) | ||||||||
Comment on lines
+108
to
+110
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we call both of these |
||||||||
if len(final_input) == 0: | ||||||||
# No analysis added yet, so add reconstruction output files if they exist. | ||||||||
# (if analysis is specified, these should be implicitly run). | ||||||||
|
@@ -372,16 +373,36 @@ def get_dataset_label(wildcards): | |||||||
dataset = parts[0] | ||||||||
return dataset | ||||||||
|
||||||||
# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard | ||||||||
|
||||||||
# Run evaluation for each ensemble.txt for a dataset against its paired gold standard | ||||||||
Comment on lines
+376
to
+377
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
rule evaluation: | ||||||||
input: | ||||||||
gold_standard_file = get_gold_standard_pickle_file, | ||||||||
pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label), | ||||||||
output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"]) | ||||||||
ensemble_file=lambda wildcards: f"{out_dir}{SEP}{get_dataset_label(wildcards)}-ml{SEP}ensemble-pathway.txt", | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use single quotes in most of the rest of the Snakefile. Also space before and after the |
||||||||
output: | ||||||||
pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'precision-recall-curve-ensemble-nodes.png']), | ||||||||
run: | ||||||||
node_table = Evaluation.from_file(input.gold_standard_file).node_table | ||||||||
node_ensemble = Evaluation.edge_frequency_node_ensemble(input.ensemble_file) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My IDE gives a warning here (and in the same line below)
|
||||||||
Evaluation.precision_recall_curve_node_ensemble(node_ensemble, node_table, output.pr_curve_png) | ||||||||
|
||||||||
# Returns ensemble file for a specific algorithm and dataset | ||||||||
def collect_ensemble_per_algo_per_dataset(wildcards): | ||||||||
dataset_label = get_dataset_label(wildcards) | ||||||||
return f"{out_dir}{SEP}{dataset_label}-ml{SEP}{wildcards.algorithm}-ensemble-pathway.txt" | ||||||||
|
||||||||
# Run evaluation per algortihm for each ensemble.txt for a dataset against its paired gold standard | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
rule evaluation_per_algo_ensemble_pr_curve: | ||||||||
input: | ||||||||
gold_standard_file = get_gold_standard_pickle_file, | ||||||||
ensemble_file = collect_ensemble_per_algo_per_dataset, | ||||||||
output: | ||||||||
pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', '{algorithm}-precision-recall-curve-ensemble-nodes.png']), | ||||||||
run: | ||||||||
node_table = Evaluation.from_file(input.gold_standard_file).node_table | ||||||||
Evaluation.precision(input.pathways, node_table, output.eval_file) | ||||||||
node_ensemble = Evaluation.edge_frequency_node_ensemble(input.ensemble_file) | ||||||||
Evaluation.precision_recall_curve_node_ensemble(node_ensemble, node_table, output.pr_curve_png) | ||||||||
|
||||||||
# Remove the output directory | ||||||||
rule clean: | ||||||||
shell: f'rm -rf {out_dir}' | ||||||||
shell: f'rm -rf {out_dir}' |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3,8 +3,15 @@ | |||||
from pathlib import Path | ||||||
from typing import Dict, Iterable | ||||||
|
||||||
import matplotlib.pyplot as plt | ||||||
import numpy as np | ||||||
import pandas as pd | ||||||
from sklearn.metrics import precision_score | ||||||
from sklearn.metrics import ( | ||||||
average_precision_score, | ||||||
precision_recall_curve, | ||||||
precision_score, | ||||||
recall_score, | ||||||
) | ||||||
|
||||||
|
||||||
class Evaluation: | ||||||
|
@@ -71,30 +78,78 @@ def load_files_from_dict(self, gold_standard_dict: Dict): | |||||
|
||||||
# TODO: later iteration - chose between node and edge file, or allow both | ||||||
|
||||||
@staticmethod | ||||||
def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we dropping support for precision calculations? |
||||||
def select_max_freq_and_node(row: pd.Series): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IDE warning for all of these new functions
|
||||||
""" | ||||||
Selects the node and frequency with the highest frequency value from two potential nodes in a row. | ||||||
Handles cases where one of the nodes or frequencies may be missing and returns the node associated with the maximum frequency. | ||||||
""" | ||||||
max_freq = 0 | ||||||
node = "" | ||||||
Comment on lines
+86
to
+87
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These aren't used currently |
||||||
if pd.isna(row['Node2']) and pd.isna(row['Freq2']): | ||||||
max_freq = row['Freq1'] | ||||||
node = row['Node1'] | ||||||
elif pd.isna(row['Node1']) and pd.isna(row['Freq1']): | ||||||
max_freq = row['Freq2'] | ||||||
node = row['Node2'] | ||||||
else: | ||||||
max_freq = max(row['Freq1'], row['Freq2']) | ||||||
node = row['Node1'] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why Node1 if Freq2 is the max? Can we ever have N/A for both nodes? |
||||||
return node, max_freq | ||||||
|
||||||
def edge_frequency_node_ensemble(ensemble_file: str): | ||||||
""" | ||||||
Processes an ensemble of edge frequencies to identify the highest frequency associated with each node | ||||||
Reads ensemble_file, separates frequencies by node, and then calculates the maximum frequency for each node. | ||||||
Returns a DataFrame of nodes with their respective maximum frequencies, or an empty DataFrame if ensemble_file is empty. | ||||||
@param ensemble_file: the pre-computed node_ensemble | ||||||
""" | ||||||
ensemble_df = pd.read_table(ensemble_file, sep="\t", header=0) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
if not ensemble_df.empty: | ||||||
node1_freq = ensemble_df.drop(columns = ['Node2', 'Direction']) | ||||||
node2_freq = ensemble_df.drop(columns = ['Node1', 'Direction']) | ||||||
|
||||||
max_node1_freq = node1_freq.groupby(['Node1']).max().reset_index() | ||||||
max_node1_freq.rename(columns = {'Frequency': 'Freq1'}, inplace = True) | ||||||
max_node2_freq = node2_freq.groupby(['Node2']).max().reset_index() | ||||||
max_node2_freq.rename(columns = {'Frequency': 'Freq2'}, inplace = True) | ||||||
|
||||||
node_ensemble = max_node1_freq.merge(max_node2_freq, left_on='Node1', right_on='Node2', how='outer') | ||||||
node_ensemble[['Node', 'max_freq']] = node_ensemble.apply(Evaluation.select_max_freq_and_node, axis=1, result_type='expand') | ||||||
node_ensemble.drop(columns = ['Node1', 'Node2', 'Freq1', 'Freq2'], inplace = True) | ||||||
node_ensemble.sort_values('max_freq', ascending= False, inplace = True) | ||||||
Comment on lines
+112
to
+120
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic feels overly complicated. Can you comment it to explain all the steps? My understanding of what is needed is that we want the max frequency for all nodes but want to account for a node appearing at Node1 or Node2. Could we do:
|
||||||
return node_ensemble | ||||||
else: | ||||||
return pd.DataFrame(columns = ['Node', 'max_freq']) | ||||||
|
||||||
def precision_recall_curve_node_ensemble(node_ensemble:pd.DataFrame, node_table:pd.DataFrame, output_png: str): | ||||||
""" | ||||||
Takes in file paths for a specific dataset and an associated gold standard node table. | ||||||
Calculates precision for each pathway file | ||||||
Returns output back to output_file | ||||||
@param file_paths: file paths of pathway reconstruction algorithm outputs | ||||||
Takes in an node ensemble for specific dataset or specific algorithm in a dataset, and an associated gold standard node table. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
Plots a precision and recall curve for the node ensemble against its associated gold standard node table | ||||||
Returns output back to output_png | ||||||
@param node_ensemble: the pre-computed node_ensemble | ||||||
@param node_table: the gold standard nodes | ||||||
@param output_file: the filename to save the precision of each pathway | ||||||
@param output_file: the filename to save the precision and recall curves | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
""" | ||||||
y_true = set(node_table['NODEID']) | ||||||
results = [] | ||||||
|
||||||
for file in file_paths: | ||||||
df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"]) | ||||||
y_pred = set(df['Node1']).union(set(df['Node2'])) | ||||||
all_nodes = y_true.union(y_pred) | ||||||
y_true_binary = [1 if node in y_true else 0 for node in all_nodes] | ||||||
y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes] | ||||||
|
||||||
# default to 0.0 if there is a divide by 0 error | ||||||
precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0) | ||||||
|
||||||
results.append({"Pathway": file, "Precision": precision}) | ||||||
|
||||||
precision_df = pd.DataFrame(results) | ||||||
precision_df.to_csv(output_file, sep="\t", index=False) | ||||||
gold_standard_nodes = set(node_table['NODEID']) | ||||||
|
||||||
if not node_ensemble.empty: | ||||||
y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']] | ||||||
y_scores = node_ensemble['max_freq'].tolist() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the node ensemble does not include all of the gold standard nodes, this will not achieve full recall. That will be biased toward some methods that make more conservative predictions, right? |
||||||
precision, recall, thresholds = precision_recall_curve(y_true, y_scores) | ||||||
auc_precision_recall = average_precision_score(y_true, y_scores) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's call this average precision. AUC of a precision recall curve is different, and I don't want there to be confusion later when reading the code. |
||||||
|
||||||
plt.figure() | ||||||
plt.plot(recall, precision, marker='o', label='Precision-Recall curve') | ||||||
plt.axhline(y=auc_precision_recall, color='r', linestyle='--', label=f'Avg Precision: {auc_precision_recall:.4f}') | ||||||
Comment on lines
+143
to
+144
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These lines aren't what I thought they were at first. I thought the red horizontal line was the baseline PR curve, which is (number positives)/(number instances). I don't think we should plot the average precision as a horizontal line. Typically we would plot the baseline instead and put the average precision in the legend by the blue line. |
||||||
plt.xlabel('Recall') | ||||||
plt.ylabel('Precision') | ||||||
plt.title('Precision-Recall Curve') | ||||||
plt.legend() | ||||||
plt.grid(True) | ||||||
plt.savefig(output_png) | ||||||
else: | ||||||
plt.figure() | ||||||
plt.plot([], []) | ||||||
plt.title("Empty Ensemble File") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the title be the same? Should we print a warning message instead? |
||||||
plt.savefig(output_png) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
Node max_freq | ||
C 0.75 | ||
E 0.75 | ||
D 0.75 | ||
F 0.75 | ||
A 0.5 | ||
B 0.5 | ||
L 0.5 | ||
M 0.5 | ||
O 0.25 | ||
P 0.25 | ||
N 0.25 | ||
Q 0.25 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
Node1 Node2 Frequency Direction | ||
A B 0.5 U | ||
C D 0.75 U | ||
E F 0.75 U | ||
L M 0.5 U | ||
M N 0.25 U | ||
O P 0.25 U | ||
P Q 0.25 U | ||
A B 0.25 D | ||
B A 0.25 D |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Node max_freq | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
Node max_freq | ||
C 0.75 | ||
E 0.75 | ||
D 0.75 | ||
F 0.75 | ||
A 0.5 | ||
B 0.5 | ||
L 0.5 | ||
M 0.5 | ||
O 0.25 | ||
P 0.25 | ||
N 0.25 | ||
Q 0.25 |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is the gold standard, let's indicate that in the filename. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
NODEID | ||
A | ||
B | ||
Q |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import filecmp | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
import spras.analysis.ml as ml | ||
from spras.evaluation import Evaluation | ||
|
||
INPUT_DIR = 'test/evaluate/input/' | ||
OUT_DIR = 'test/evaluate/output/' | ||
EXPECT_DIR = 'test/evaluate/expected/' | ||
NODE_TABLE = pd.read_csv(INPUT_DIR + "node_table.csv", header=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
class TestEvaluate: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did we not have any test code for the precision functionality before? I'm wondering if there are other expected files we need to clean up. |
||
@classmethod | ||
def setup_class(cls): | ||
""" | ||
Create the expected output directory | ||
""" | ||
Path(OUT_DIR).mkdir(parents=True, exist_ok=True) | ||
|
||
def test_node_ensemble(self): | ||
ensemble_file = INPUT_DIR + 'ensemble-network.tsv' | ||
edge_freq = Evaluation.edge_frequency_node_ensemble(ensemble_file) | ||
edge_freq.to_csv(OUT_DIR + 'node-ensemble.csv', sep="\t", index=False) | ||
assert filecmp.cmp(OUT_DIR + 'node-ensemble.csv', EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False) | ||
|
||
def test_precision_recal_curve_ensemble_nodes(self): | ||
out_path = Path(OUT_DIR+"test-precision-recall-curve-ensemble-nodes.png") | ||
out_path.unlink(missing_ok=True) | ||
ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep="\t", header=0) | ||
Evaluation.precision_recall_curve_node_ensemble(ensemble_file, NODE_TABLE, out_path) | ||
assert out_path.exists() | ||
|
||
def test_precision_recal_curve_ensemble_nodes_empty(self): | ||
out_path = Path(OUT_DIR+"test-precision-recall-curve-ensemble-nodes-empty.png") | ||
out_path.unlink(missing_ok=True) | ||
ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep="\t", header=0) | ||
Evaluation.precision_recall_curve_node_ensemble(ensemble_file, NODE_TABLE, out_path) | ||
assert out_path.exists() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
algorithm_params=algorithms_with_params
part isn't being used and can be deleted