Param tuning: ensembling (version 2 but all the same code as version 1) #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 32 commits on Jul 11, 2025
Commits
63f1f80
test cases for dup edges
ntalluri Mar 3, 2025
05d4979
removed dup edge cases from master
ntalluri Mar 3, 2025
4e60656
Merge branch 'master' of github.com:ntalluri/spras
ntalluri Mar 10, 2025
7c94c2f
Merge branch 'master' of github.com:ntalluri/spras
ntalluri Mar 24, 2025
156753c
updated SnakeFile
ntalluri Mar 24, 2025
e07d0a1
added test cases for ensemble node PR curves
ntalluri Mar 24, 2025
2b7424b
Line spacing fixes
agitter May 2, 2025
ee8d921
updated Snakefile based on review
ntalluri May 28, 2025
248ebd4
update eval code based on review
ntalluri May 28, 2025
721e5e0
updated test case based on review
ntalluri May 28, 2025
af92fc3
Merge branch 'param-tuning-ensembling-2.0' of github.com:ntalluri/spr…
ntalluri May 28, 2025
e231a17
update ensemble visualization snakemake logic
ntalluri May 29, 2025
48f6df2
update ensemble logic
ntalluri May 29, 2025
5a7c4d3
update test cases
ntalluri May 29, 2025
50dea36
update ensemble eval code to include file of the values
ntalluri May 29, 2025
1c7a604
updated test cases
ntalluri May 29, 2025
d9e4400
clean up code and comments
ntalluri Jun 5, 2025
89c2508
update test cases
ntalluri Jun 5, 2025
c98f0df
Clean up imports and formatting
agitter Jun 7, 2025
b3a7ac0
update to evaluation code
ntalluri Jun 18, 2025
c3fe362
updated test cases
ntalluri Jun 18, 2025
8bd6133
precommit
ntalluri Jun 18, 2025
5a45f03
Formatting and doc changes
agitter Jun 20, 2025
6f51672
update to evaluation code based on review
ntalluri Jun 24, 2025
7f32a0b
update to test cases and updated based on review
ntalluri Jun 24, 2025
b4d6ef3
precommit
ntalluri Jun 24, 2025
5f8a991
added eval dataset and added a couple more parameters
ntalluri Jun 25, 2025
db011f0
Merge branch 'main' into param-tuning-ensembling-2.0
ntalluri Jun 25, 2025
3d22a0a
Update egfr.yaml
ntalluri Jun 25, 2025
1dea550
precommit
ntalluri Jun 26, 2025
cd538e1
Sync EGFR config file with general changes
agitter Jul 11, 2025
b6a86df
Formatting
agitter Jul 11, 2025
54 changes: 45 additions & 9 deletions Snakefile
@@ -108,8 +108,11 @@ def make_final_input(wildcards):
        final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-jaccard-heatmap.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms))

    if _config.config.analysis_include_evaluation:
        final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params))

        final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
        final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
    if _config.config.analysis_include_evaluation_aggregate_algo:
        final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes-per-algorithm.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
        final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes-per-algorithm.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
    if len(final_input) == 0:
        # No analysis added yet, so add reconstruction output files if they exist.
        # (if analysis is specified, these should be implicitly run).
@@ -392,23 +395,56 @@ def get_gold_standard_pickle_file(wildcards):
    parts = wildcards.dataset_gold_standard_pairs.split('-')
    gs = parts[1]
    return SEP.join([out_dir, f'{gs}-merged.pickle'])

# Returns the dataset corresponding to the gold standard pair
def get_dataset_label(wildcards):
    parts = wildcards.dataset_gold_standard_pairs.split('-')
    dataset = parts[0]
    return dataset

# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard
rule evaluation:
# Return the dataset pickle file for a specific dataset
def get_dataset_pickle_file(wildcards):
    dataset_label = get_dataset_label(wildcards)
    return SEP.join([out_dir, f'{dataset_label}-merged.pickle'])

# Returns ensemble file for each dataset
def collect_ensemble_per_dataset(wildcards):
    dataset_label = get_dataset_label(wildcards)
    return expand('{out_dir}{sep}{dataset}-ml{sep}ensemble-pathway.txt', out_dir=out_dir, sep=SEP, dataset=dataset_label)

# Run precision-recall curves for each ensemble pathway within a dataset evaluated against its corresponding gold standard
rule evaluation_ensemble_pr_curve:
    input:
        gold_standard_file = get_gold_standard_pickle_file,
        pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label),
    output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"])
        dataset_file = get_dataset_pickle_file,
        ensemble_file = collect_ensemble_per_dataset
    output:
        pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.png']),
        pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.txt']),
    run:
        node_table = Evaluation.from_file(input.gold_standard_file).node_table
        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input.dataset_file)
        Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, output.pr_curve_png, output.pr_curve_file)

# Returns list of algorithm specific ensemble files per dataset
def collect_ensemble_per_algo_per_dataset(wildcards):
    dataset_label = get_dataset_label(wildcards)
    return expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt', out_dir=out_dir, sep=SEP, dataset=dataset_label, algorithm=algorithms)

# Run precision-recall curves for each algorithm's ensemble pathway within a dataset evaluated against its corresponding gold standard
rule evaluation_per_algo_ensemble_pr_curve:
    input:
        gold_standard_file = get_gold_standard_pickle_file,
        dataset_file = get_dataset_pickle_file,
        ensemble_files = collect_ensemble_per_algo_per_dataset
    output:
        pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.png']),
        pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.txt']),
    run:
        node_table = Evaluation.from_file(input.gold_standard_file).node_table
        Evaluation.precision(input.pathways, node_table, output.eval_file)
        node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input.dataset_file)
        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, output.pr_curve_png, output.pr_curve_file)

# Remove the output directory
rule clean:
    shell: f'rm -rf {out_dir}'
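
Taken together, the two new rules wire Evaluation.edge_frequency_node_ensemble and Evaluation.precision_recall_curve_node_ensemble into the workflow. A minimal standalone sketch of the same call chain outside Snakemake, with hypothetical file paths standing in for the rule-resolved inputs and outputs (the labels follow the EGFR config below; ensemble files are tab-separated with Node1, Node2, and Frequency columns, as read by the evaluation code):

# Hypothetical paths for illustration; Snakemake resolves the real ones from the rules above.
from spras.evaluation import Evaluation

gold_standard_file = 'output/egfr/gs_egfr-merged.pickle'
dataset_file = 'output/egfr/tps_egfr-merged.pickle'
ensemble_files = ['output/egfr/tps_egfr-ml/ensemble-pathway.txt']

node_table = Evaluation.from_file(gold_standard_file).node_table
node_ensembles = Evaluation.edge_frequency_node_ensemble(node_table, ensemble_files, dataset_file)
Evaluation.precision_recall_curve_node_ensemble(node_ensembles, node_table,
                                                'pr-curve-ensemble-nodes.png',
                                                'pr-curve-ensemble-nodes.txt')
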
70 changes: 63 additions & 7 deletions config/egfr.yaml
@@ -1,9 +1,28 @@
# The length of the hash used to identify a parameter combination
hash_length: 7

# If true, use Singularity instead of Docker
# Singularity support is only available on Unix
singularity: false
# Specify the container framework used by each PRM wrapper. Valid options include:
# - docker (default if not specified)
# - singularity -- Also known as apptainer, useful in HPC/HTC environments where docker isn't allowed
# - dsub -- experimental with limited support, used for running on Google Cloud
container_framework: docker

# Only used if container_framework is set to singularity, this will unpack the singularity containers
# to the local filesystem. This is useful when PRM containers need to run inside another container,
# such as would be the case in an HTCondor/OSPool environment.
# NOTE: This unpacks singularity containers to the local filesystem, which will take up space in a way
# that persists after the workflow is complete. To clean up the unpacked containers, the user must
# manually delete them.
unpack_singularity: false

# Allow the user to configure which container registry containers should be pulled from
# Note that this assumes container names are consistent across registries, and that the
# registry being passed doesn't require authentication for pull actions
container_registry:
  base_url: docker.io
  # The owner or project of the registry
  # For example, "reedcompbio" if the image is available as docker.io/reedcompbio/allpairs
  owner: reedcompbio

algorithms:
  - name: pathlinker
@@ -13,6 +32,7 @@ algorithms:
        k:
          - 10
          - 20
          - 70
  - name: omicsintegrator1
    params:
      include: true
@@ -55,6 +75,16 @@ algorithms:
          - 3
        rand_restarts:
          - 10
      run2:
        local_search:
          - "No"
        max_path_length:
          - 2
        rand_restarts:
          - 10
  - name: allpairs
    params:
      include: true
  - name: domino
    params:
      include: true
@@ -63,6 +93,24 @@ algorithms:
          - 0.3
        module_threshold:
          - 0.05
  - name: mincostflow
    params:
      include: true
      run1:
        capacity:
          - 15
        flow:
          - 80
      run2:
        capacity:
          - 1
        flow:
          - 6
      run3:
        capacity:
          - 5
        flow:
          - 60
datasets:
  - data_dir: input
    edge_files:
@@ -71,18 +119,26 @@ datasets:
    node_files:
      - tps-egfr-prizes.txt
    other_files: []
gold_standards:
  - label: gs_egfr
    node_files:
      - gs-egfr.txt
    data_dir: input
    dataset_labels:
      - tps_egfr
reconstruction_settings:
  locations:
    reconstruction_dir: output/egfr
  run: true
analysis:
  graphspace:
    include: false
  cytoscape:
    include: true
  summary:
    include: true
  ml:
    include: false
    include: true
    aggregate_per_algorithm: true
    labels: true
  evaluation:
    include: false
    include: true
    aggregate_per_algorithm: true
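
As a quick sanity check of the flags above, the evaluation switches that gate the new Snakefile targets can be read straight from the merged config file (the diff view above interleaves removed and added lines; the final file has one include per section). A sketch using PyYAML; mapping these keys to _config.config.analysis_include_evaluation and analysis_include_evaluation_aggregate_algo is inferred from the Snakefile changes above, not documented API:

import yaml  # PyYAML

with open('config/egfr.yaml') as f:
    config = yaml.safe_load(f)

evaluation = config['analysis']['evaluation']
print(evaluation['include'])                             # gates pr-curve-ensemble-nodes.{png,txt}
print(evaluation.get('aggregate_per_algorithm', False))  # gates the per-algorithm PR curve outputs
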
186 changes: 164 additions & 22 deletions spras/evaluation.py
@@ -1,10 +1,17 @@
import os
import pickle as pkl
from pathlib import Path
from typing import Dict, Iterable
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
)

from spras.analysis.ml import create_palette


class Evaluation:
@@ -71,29 +78,164 @@ def load_files_from_dict(self, gold_standard_dict: Dict):
        # TODO: later iteration - choose between node and edge file, or allow both

    @staticmethod
    def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):
    def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list, dataset_file: str) -> dict:
        """
        Takes in file paths for a specific dataset and an associated gold standard node table.
        Calculates precision for each pathway file
        Returns output back to output_file
        @param file_paths: file paths of pathway reconstruction algorithm outputs
        @param node_table: the gold standard nodes
        @param output_file: the filename to save the precision of each pathway
        Generates a dictionary of node ensembles using edge frequency data from a list of ensemble files.
        A list of ensemble files can contain an aggregated ensemble or algorithm-specific ensembles per dataset.

        1. Prepare a set of default nodes (from the interactome and gold standard) with frequency 0,
           ensuring all nodes are represented in the ensemble.
            - Answers "Did the algorithm(s) select the correct nodes from the entire network?"
            - It measures whether the algorithm(s) can distinguish relevant gold standard nodes
              from the full 'universe' of possible nodes present in the input network.
        2. For each edge ensemble file:
            a. Read edges and their frequencies.
            b. Convert edge frequencies into node-level frequencies for Node1 and Node2.
            c. Merge with the default node set and group by node, taking the maximum frequency per node.
        3. Store the resulting node-frequency ensemble under the corresponding ensemble source (label).

        If the interactome or gold standard table is empty, a ValueError is raised.

        @param node_table: DataFrame of gold standard nodes (column: NODEID)
        @param ensemble_files: list of file paths containing edge ensemble outputs
        @param dataset_file: path to the dataset file used to load the interactome
        @return: dictionary mapping each ensemble source to its node ensemble DataFrame
        """
        y_true = set(node_table['NODEID'])
        results = []

        for file in file_paths:
            df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
            y_pred = set(df['Node1']).union(set(df['Node2']))
            all_nodes = y_true.union(y_pred)
            y_true_binary = [1 if node in y_true else 0 for node in all_nodes]
            y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes]
        node_ensembles_dict = dict()

        pickle = Evaluation.from_file(dataset_file)
        interactome = pickle.get_interactome()

        if interactome.empty:
            raise ValueError(
                f"Cannot compute PR curve or generate node ensemble. Input network for dataset '{dataset_file.split('-')[0]}' is empty."
            )
        if node_table.empty:
            raise ValueError(
                f"Cannot compute PR curve or generate node ensemble. Gold standard associated with dataset '{dataset_file.split('-')[0]}' is empty."
            )

        # set the initial default frequencies to 0 for all interactome and gold standard nodes
        node1_interactome = interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'})
        node1_interactome['Frequency'] = 0.0
        node2_interactome = interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'})
        node2_interactome['Frequency'] = 0.0
        gs_nodes = node_table[[Evaluation.NODE_ID]].rename(columns={Evaluation.NODE_ID: 'Node'})
        gs_nodes['Frequency'] = 0.0

        # combine gold standard and network nodes
        other_nodes = pd.concat([node1_interactome, node2_interactome, gs_nodes])

        for ensemble_file in ensemble_files:
            label = Path(ensemble_file).name.split('-')[0]
            ensemble_df = pd.read_table(ensemble_file, sep='\t', header=0)

            # default to 0.0 if there is a divide by 0 error
            precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0)
            if not ensemble_df.empty:
                node1 = ensemble_df[['Node1', 'Frequency']].rename(columns={'Node1': 'Node'})
                node2 = ensemble_df[['Node2', 'Frequency']].rename(columns={'Node2': 'Node'})
                all_nodes = pd.concat([node1, node2, other_nodes])
                node_ensemble = all_nodes.groupby(['Node']).max().reset_index()
            else:
                node_ensemble = other_nodes.groupby(['Node']).max().reset_index()

            results.append({"Pathway": file, "Precision": precision})
            node_ensembles_dict[label] = node_ensemble

        precision_df = pd.DataFrame(results)
        precision_df.to_csv(output_file, sep="\t", index=False)
        return node_ensembles_dict
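
To make step 2c concrete, a toy example of the max-frequency reduction with made-up nodes; this is plain pandas, independent of any SPRAS files:

import pandas as pd

# Hypothetical edge ensemble: edge A-B appears in 75% of pathways, edge B-C in 25%
ensemble_df = pd.DataFrame({'Node1': ['A', 'B'], 'Node2': ['B', 'C'], 'Frequency': [0.75, 0.25]})
# Default frequency 0 for every network/gold standard node, including never-selected node D
other_nodes = pd.DataFrame({'Node': ['A', 'B', 'C', 'D'], 'Frequency': [0.0] * 4})

node1 = ensemble_df[['Node1', 'Frequency']].rename(columns={'Node1': 'Node'})
node2 = ensemble_df[['Node2', 'Frequency']].rename(columns={'Node2': 'Node'})
node_ensemble = pd.concat([node1, node2, other_nodes]).groupby(['Node']).max().reset_index()
print(node_ensemble)  # A -> 0.75, B -> 0.75 (max of 0.75 and 0.25), C -> 0.25, D -> 0.0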

    @staticmethod
    def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str,
                                             output_file: str):
        """
        Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard.

        Takes in a dictionary containing either algorithm-specific node ensembles or an aggregated node ensemble
        for a given dataset, along with the corresponding gold standard node table. Computes PR curves for
        each ensemble and plots all curves on a single figure.

        @param node_ensembles: dict of the pre-computed node ensemble(s)
        @param node_table: gold standard nodes
        @param output_png: filename to save the precision and recall curves as a .png image
        @param output_file: filename to save the precision, recall, threshold values, average precision, and baseline
            average precision
        """
        gold_standard_nodes = set(node_table[Evaluation.NODE_ID])

        # make a color palette per ensemble label name
        label_names = list(node_ensembles.keys())
        color_palette = create_palette(label_names)

        plt.figure(figsize=(10, 7))

        prc_dfs = []
        metric_dfs = []

        baseline = None

        for label, node_ensemble in node_ensembles.items():
            if not node_ensemble.empty:
                y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']]
                y_scores = node_ensemble['Frequency'].tolist()
                precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
                # avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold
                avg_precision = average_precision_score(y_true, y_scores)

                # only set the baseline precision once
                # it is the same for every algorithm per dataset/gold standard pair
                if baseline is None:
                    baseline = np.sum(y_true) / len(y_true)
                    plt.axhline(y=baseline, color="black", linestyle='--', label=f'Baseline: {baseline:.4f}')

                plt.plot(recall, precision, color=color_palette[label], marker='o',
                         label=f'{label.capitalize()} (AP: {avg_precision:.4f})')

                # Drop the last elements because scikit-learn adds a final (1, 0) precision/recall point for plotting that is not tied to a real threshold
                # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve:~:text=Returns%3A-,precision,predictions%20with%20score%20%3E%3D%20thresholds%5Bi%5D%20and%20the%20last%20element%20is%200.,-thresholds
                prc_data = {
                    'Threshold': thresholds,
                    'Precision': precision[:-1],
                    'Recall': recall[:-1],
                }

                metric_data = {
                    'Average_Precision': [avg_precision],
                }

                ensemble_source = label.capitalize() if label != 'ensemble' else "Aggregated"
                prc_data = {'Ensemble_Source': [ensemble_source] * len(thresholds), **prc_data}
                metric_data = {'Ensemble_Source': [ensemble_source], **metric_data}

                prc_df = pd.DataFrame.from_dict(prc_data)
                prc_dfs.append(prc_df)
                metric_df = pd.DataFrame.from_dict(metric_data)
                metric_dfs.append(metric_df)

            else:
                raise ValueError(
                    "Cannot compute PR curve: the ensemble network is empty. "
                    "This should not happen unless the input network for pathway reconstruction is empty."
                )

        if 'ensemble' not in label_names:
            plt.title('Precision-Recall Curve Per Algorithm-Specific Ensemble')
        else:
            plt.title('Precision-Recall Curve for Aggregated Ensemble Across Algorithms')

        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.legend(loc='lower left', bbox_to_anchor=(1, 0.5))
        plt.grid(True)
        plt.savefig(output_png, bbox_inches='tight')
        plt.close()

        combined_prc_df = pd.concat(prc_dfs, ignore_index=True)
        combined_metrics_df = pd.concat(metric_dfs, ignore_index=True)
        combined_metrics_df["Baseline"] = baseline

        # merge dfs and NaN out metric values except for the first row of each Ensemble_Source
        complete_df = combined_prc_df.merge(combined_metrics_df, on="Ensemble_Source", how="left")
        not_first_rows = complete_df.duplicated(subset="Ensemble_Source", keep='first')
        complete_df.loc[not_first_rows, ["Average_Precision", "Baseline"]] = None
        complete_df.to_csv(output_file, index=False, sep="\t")
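
The comment about dropping the last precision/recall elements reflects a scikit-learn convention that can be checked in isolation (a standalone snippet; the labels and scores are made up):

from sklearn.metrics import precision_recall_curve

y_true = [0, 1, 1, 0, 1]
y_scores = [0.1, 0.8, 0.6, 0.4, 0.9]
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# scikit-learn returns one more precision/recall value than thresholds;
# the final (precision=1, recall=0) point exists for plotting and has no threshold,
# which is why the code above writes precision[:-1] and recall[:-1] alongside thresholds
assert len(precision) == len(thresholds) + 1
assert precision[-1] == 1.0 and recall[-1] == 0.0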