Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"scanpy",
"biothings_client<0.4.0",
"memelite",
"lxml"
]

[project.optional-dependencies]
Expand Down
11 changes: 10 additions & 1 deletion src/crested/pl/patterns/_contribution_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def contribution_scores(
zoom_n_bases: int | None = None,
highlight_positions: list[tuple[int, int]] | None = None,
ylim: tuple | None = None,
x_shift: int = 0,
method: str | None = None,
**kwargs,
):
Expand All @@ -73,6 +74,8 @@ def contribution_scores(
List of tuples with start and end positions to highlight. Default is None.
ylim
Y-axis limits. Default is None.
x_shift
Number of base pairs to shift left or right for visualizing specific subsets of the region. Only use when combined with zooming in. Default is zero.
method
Method used for calculating contribution scores. If mutagenesis, you can either set this to mutagenesis to visualize
in legacy way, or mutagenesis_letters to visualize an average of changes.
Expand Down Expand Up @@ -102,7 +105,13 @@ def contribution_scores(
if class_labels and not isinstance(class_labels, list):
class_labels = [str(class_labels)]
center = int(scores.shape[2] / 2)
start_idx = center - int(zoom_n_bases / 2)
start_idx = center - int(zoom_n_bases / 2) + x_shift
if start_idx < 0 or (start_idx + zoom_n_bases > scores.shape[2]):
raise ValueError(
f"Parameter x_shift={x_shift} with zoom={zoom_n_bases} "
f"gives invalid coordinates (start_idx={start_idx}, "
f"max={scores.shape[2]})."
)
scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :]

total_classes = scores.shape[1]
Expand Down
38 changes: 33 additions & 5 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import seaborn as sns
from loguru import logger
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Patch
from scipy.cluster.hierarchy import dendrogram, leaves_list, linkage

Expand Down Expand Up @@ -724,6 +725,7 @@ def clustermap_with_pwm_logos(
def selected_instances(
pattern_dict: dict,
idcs: list[int],
save_path: str = None,
) -> None:
"""
Plot the patterns specified by the indices in `idcs` from the `pattern_dict`.
Expand All @@ -735,6 +737,8 @@ def selected_instances(
contribution scores and metadata for the pattern. Refer to the output of `crested.tl.modisco.process_patterns`.
idcs
A list of indices specifying which patterns to plot. The indices correspond to keys in the `pattern_dict`.
save_path
File to save plot to.

See Also
--------
Expand All @@ -761,11 +765,13 @@ def selected_instances(
ax.set_title(pattern_dict[str(idx)]["pattern"]["id"])

plt.tight_layout()
if save_path:
plt.savefig(save_path, bbox_inches='tight')
plt.show()


def class_instances(
pattern_dict: dict, idx: int, class_representative: bool = False
pattern_dict: dict, idx: int, class_representative: bool = False, save_path: str = None,
) -> None:
"""
Plot instances of a specific pattern, either the representative pattern per class or all instances for a given pattern index.
Expand All @@ -780,6 +786,8 @@ def class_instances(
class_representative
If True, only the best representative instance of each class is plotted. If False (default), all instances of the pattern
within each class are plotted.
save_path
File to save plot to.

See Also
--------
Expand Down Expand Up @@ -814,6 +822,8 @@ def class_instances(
ax.set_title(pattern_dict[str(idx)][key][cl]["id"])

plt.tight_layout()
if save_path:
plt.savefig(save_path, bbox_inches='tight')
plt.show()


Expand Down Expand Up @@ -932,6 +942,7 @@ def clustermap_tf_motif(
save_path: str | None = None,
cluster_rows: bool = True,
cluster_columns: bool = True,
cbar_pad: float = 0.05,
) -> None:
"""
Generate a heatmap where one modality is represented as color, and the other as dot size.
Expand All @@ -958,6 +969,8 @@ def clustermap_tf_motif(
Whether to cluster the rows (classes). Default is True.
cluster_columns : bool
Whether to cluster the columns (patterns). Default is True.
cbar_pad : float
Horizontal padding between heatmap and colorbar in figure coordinates.

Examples
--------
Expand Down Expand Up @@ -1055,11 +1068,17 @@ def clustermap_tf_motif(
vmax=max(np.abs(heatmap_data.min()), np.abs(heatmap_data.max())),
)


# Define custom light-centered colormap
light_centered_cmap = LinearSegmentedColormap.from_list(
"light_coolwarm", ["blue", "#f0f0f0", "red"]
)

# Plot heatmap
heatmap = ax_heatmap.imshow(
heatmap_data,
aspect="auto",
cmap="coolwarm",
cmap=light_centered_cmap,
norm=norm,
)

Expand All @@ -1075,14 +1094,23 @@ def clustermap_tf_motif(
edgecolor="none",
)

# Add colorbar
cbar = plt.colorbar(heatmap, ax=ax_heatmap)
# Colorbar manual position
heat_pos = ax_heatmap.get_position()
cbar_width = 0.005
cbar_height = 0.25
cbar_x = heat_pos.x1 + cbar_pad
cbar_y = heat_pos.y0 + (heat_pos.height - cbar_height) / 2
cax = fig.add_axes([cbar_x, cbar_y, cbar_width, cbar_height])

# Colorbar draw
cbar = fig.colorbar(heatmap, cax=cax)
label = (
"Average pattern contribution score"
if heatmap_dim == "contrib"
else "Average TF expression, signed by activation/repression"
)
cbar.set_label(label)
cbar.set_label(label, labelpad=10)
cbar.ax.yaxis.set_tick_params(pad=5)

# Set axis labels and ticks
ax_heatmap.set_xticks(np.arange(data.shape[1]))
Expand Down
2 changes: 1 addition & 1 deletion src/crested/pl/scatter/_class_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _check_input_params():
scatter.set_rasterized(True) # Rasterize only the scatter points
plt.colorbar(scatter, ax=ax, label="Density")
else:
scatter = ax.scatter(x, y, edgecolor="k", alpha=alpha)
scatter = ax.scatter(x, y, edgecolor="k", alpha=alpha, rasterized=True)

ax.annotate(
f"Pearson: {pearson_corr:.2f}",
Expand Down
119 changes: 109 additions & 10 deletions src/crested/tl/modisco/_tfmodisco.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def match_to_patterns(
pattern_id: str,
pos_pattern: bool,
all_patterns: dict[str, dict[str, str | list[float]]],
sim_threshold: float = 0.5,
sim_threshold: float = 7.0,
ic_threshold: float = 0.15,
verbose: bool = False,
) -> dict:
Expand Down Expand Up @@ -395,13 +395,15 @@ def match_to_patterns(

p["class"] = cell_type

for pat_idx, pattern in enumerate(all_patterns.keys()):
sim = match_score_patterns(p, all_patterns[pattern]["pattern"])
if sim > sim_threshold:
match = True
if sim > max_sim:
max_sim = sim
match_idx = pat_idx
all_patterns_list = [pat['pattern'] for pat in all_patterns.values()]
sim_matrix1 = match_score_patterns(p, all_patterns_list)
sim_matrix2 = match_score_patterns(all_patterns_list, p).T # for some reason changing the order can give different results
sim_matrix = np.maximum(sim_matrix1, sim_matrix2)

max_sim = np.max(sim_matrix)
if max_sim > sim_threshold:
match = True
match_idx = np.argmax(sim_matrix[0])

if not match:
pattern_idx = len(all_patterns.keys())
Expand Down Expand Up @@ -442,6 +444,103 @@ def post_hoc_merging(
sim_threshold: float = 0.5,
ic_discard_threshold: float = 0.15,
verbose: bool = False,
return_info: bool = False,
) -> dict | tuple[dict, list[tuple[str, str, float]]]:
"""
Double-checks the similarity of all patterns and merges them if they exceed the threshold.

Filters out patterns with IC below the discard threshold at the last step and updates the keys.

Parameters
----------
all_patterns
Dictionary of all patterns with metadata. Each pattern must include 'pattern', 'ic', and 'classes'.
sim_threshold
Similarity threshold for merging patterns.
ic_discard_threshold
IC threshold below which patterns are discarded unless they belong to multiple classes.
verbose
Flag to enable verbose output of merged patterns.
return_info
If True, also return a list of all performed merges as (pattern_id_1, pattern_id_2, similarity).

Returns
-------
Updated patterns after merging and filtering with sequential keys.
If `return_info=True`, also returns a list of performed merges.
"""
current_meta = list(all_patterns.values())
all_merges = []
iteration = 0

while True:
iteration += 1
N = len(current_meta)

raw_patterns = [m["pattern"] for m in current_meta]
raw_ids = [m["pattern"]["id"] for m in current_meta]

S = match_score_patterns(raw_patterns, raw_patterns)
S = np.maximum(S, S.T)
np.fill_diagonal(S, -np.inf)

candidates = np.argwhere(S > sim_threshold)
candidates = [(i, j, S[i, j]) for i, j in candidates if i < j]

if not candidates:
if verbose:
print(f"Iteration {iteration}: no more matches above {sim_threshold}")
break

candidates.sort(key=lambda x: x[2], reverse=True)

matched = set()
merges = []
for i, j, score in candidates:
if i in matched or j in matched:
continue
matched.add(i)
matched.add(j)
merges.append((i, j, score))
all_merges.append((raw_ids[i], raw_ids[j], score))

if verbose:
print(f"Iteration {iteration}: performing {len(merges)} merges")
for i, j, score in merges:
print(f" -> merging {raw_ids[i]} + {raw_ids[j]} (sim={score:.3f})")

new_meta = []
used = set()
for i, j, _ in merges:
merged = merge_patterns(current_meta[i], current_meta[j])
new_meta.append(merged)
used.update({i, j})

for idx in range(N):
if idx not in used:
new_meta.append(current_meta[idx])

current_meta = new_meta

final = {}
idx = 0
for m in current_meta:
if m["ic"] >= ic_discard_threshold or len(m["classes"]) > 1:
final[str(idx)] = m
idx += 1
elif verbose:
print(f"Dropping {m['pattern']['id']} (IC={m['ic']:.3f})")

if verbose:
print(f"Done after {iteration} iterations; {len(final)} patterns remain.")

return (final, all_merges) if return_info else final

def post_hoc_merging_old(
all_patterns: dict,
sim_threshold: float = 0.5,
ic_discard_threshold: float = 0.15,
verbose: bool = False,
) -> dict:
"""
Double-checks the similarity of all patterns and merges them if they exceed the threshold.
Expand Down Expand Up @@ -867,8 +966,8 @@ def calculate_tomtom_similarity_per_pattern(

def process_patterns(
matched_files: dict[str, str | list[str] | None],
sim_threshold: float = 3.0,
trim_ic_threshold: float = 0.05,
sim_threshold: float = 6.0,
trim_ic_threshold: float = 0.025,
discard_ic_threshold: float = 0.1,
verbose: bool = False,
) -> dict[str, dict[str, str | list[float]]]:
Expand Down
Loading