Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions colabfold/plot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from pathlib import Path
import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def plot_predicted_alignment_error(
jobname: str, num_models: int, outs: dict, result_dir: Path, show: bool = False
):
from matplotlib import pyplot as plt
plt.figure(figsize=(3 * num_models, 2), dpi=100)
for n, (model_name, value) in enumerate(outs.items()):
plt.subplot(1, num_models, n + 1)
Expand All @@ -18,17 +22,16 @@ def plot_predicted_alignment_error(


def plot_msa_v2(feature_dict, sort_lines=True, dpi=100):
from matplotlib import pyplot as plt
seq = feature_dict["msa"][0]
if "asym_id" in feature_dict:
Ls = [0]
k = feature_dict["asym_id"][0]
for i in feature_dict["asym_id"]:
if i == k: Ls[-1] += 1
else: Ls.append(1)
k = i
Ls = [0]
k = feature_dict["asym_id"][0]
for i in feature_dict["asym_id"]:
if i == k: Ls[-1] += 1
else: Ls.append(1)
k = i
else:
Ls = [len(seq)]
Ls = [len(seq)]
Ln = np.cumsum([0] + Ls)

try:
Expand Down Expand Up @@ -58,6 +61,7 @@ def plot_msa_v2(feature_dict, sort_lines=True, dpi=100):

Nn = np.cumsum(np.append(0,Nn))
lines = np.concatenate(lines,0)

plt.figure(figsize=(8,5), dpi=dpi)
plt.title("Sequence coverage")
plt.imshow(lines,
Expand All @@ -78,11 +82,11 @@ def plot_msa_v2(feature_dict, sort_lines=True, dpi=100):
return plt

def plot_msa(msa, query_sequence, seq_len_list, total_seq_len, dpi=100):
from matplotlib import pyplot as plt
# gather MSA info
prev_pos = 0
msa_parts = []
Ln = np.cumsum(np.append(0, [len for len in seq_len_list]))
# Corrección: Cambiado 'len' a 'l' para no sobrescribir la función integrada
Ln = np.cumsum(np.append(0, [l for l in seq_len_list]))
for id, l in enumerate(seq_len_list):
chain_seq = np.array(query_sequence[prev_pos : prev_pos + l])
chain_msa = np.array(msa[:, prev_pos : prev_pos + l])
Expand All @@ -97,6 +101,7 @@ def plot_msa(msa, query_sequence, seq_len_list, total_seq_len, dpi=100):
non_gaps[non_gaps == 0] = np.nan
msa_parts.append((non_gaps[:] * seqid[:, None]).tolist())
prev_pos += l

lines = []
lines_to_sort = []
prev_has_seq = [True] * len(seq_len_list)
Expand All @@ -120,12 +125,11 @@ def plot_msa(msa, query_sequence, seq_len_list, total_seq_len, dpi=100):
line += msa_parts[id][line_num]
lines_to_sort.append(line)
prev_has_seq = has_seq

lines_to_sort = np.array(lines_to_sort)
lines_to_sort = lines_to_sort[np.argsort(-np.nanmax(lines_to_sort, axis=1))]
lines += lines_to_sort.tolist()

# Nn = np.cumsum(np.append(0, Nn))
# lines = np.concatenate(lines, 1)
xaxis_size = len(lines[0])
yaxis_size = len(lines)

Expand All @@ -143,10 +147,6 @@ def plot_msa(msa, query_sequence, seq_len_list, total_seq_len, dpi=100):
)
for i in Ln[1:-1]:
plt.plot([i, i], [0, yaxis_size], color="black")
# for i in Ln_dash[1:-1]:
# plt.plot([i, i], [0, lines.shape[0]], "--", color="black")
# for j in Nn[1:-1]:
# plt.plot([0, lines.shape[1]], [j, j], color="black")

plt.plot((np.isnan(lines) == False).sum(0), color="black")
plt.xlim(0, xaxis_size)
Expand Down