Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 154 additions & 22 deletions src/instanexus/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MAX_REFINE_ROUNDS = 10


def find_peptide_overlaps(peptides, min_overlap):
"""Finds overlaps between peptide sequences using a greedy approach."""
Expand Down Expand Up @@ -526,6 +528,111 @@ def rank_contigs_by_score(
return sorted(scored, key=lambda s: (s.score, s.length), reverse=True)


def build_overlap_graph(contigs: List[str], min_overlap: int) -> nx.DiGraph:
    """Build a directed suffix/prefix overlap graph over *contigs*.

    Each contig becomes a node keyed by its list index, carrying ``seq``
    and ``length`` attributes.  A directed edge ``i -> j`` is created when
    a suffix of ``contigs[i]`` equals a prefix of ``contigs[j]`` of length
    at least *min_overlap*; the edge ``weight`` stores the longest such
    overlap found.

    Args:
        contigs: Sequences to connect.
        min_overlap: Minimum suffix/prefix match length required for an edge.

    Returns:
        nx.DiGraph with one node per contig and weighted overlap edges.
    """
    graph = nx.DiGraph()
    for idx, sequence in enumerate(contigs):
        graph.add_node(idx, seq=sequence, length=len(sequence))

    for i, seq_a in enumerate(contigs):
        for j, seq_b in enumerate(contigs):
            if i == j:
                continue

            # An overlap can never be longer than the shorter sequence.
            upper = min(len(seq_a), len(seq_b))
            if upper < min_overlap:
                continue

            # Scan candidate lengths from longest to shortest; the first
            # match is therefore the maximal overlap.
            overlap = next(
                (
                    k
                    for k in range(upper, min_overlap - 1, -1)
                    if seq_a.endswith(seq_b[:k])
                ),
                0,
            )
            if overlap:
                graph.add_edge(i, j, weight=overlap)

    return graph


def merge_paths_from_overlap_graph(G: nx.DiGraph) -> List[str]:
    """
    Traverses the overlap graph to find and merge the optimal paths
    (heaviest overlaps), resolving branches by prioritizing longer overlaps.

    Works on a copy of *G*: nodes consumed by a merged path are removed, and
    traversal repeats until no nodes remain.  Ties in node/edge selection are
    broken by ``max``'s first-wins rule over the graph's iteration order.

    Args:
        G: Directed graph whose nodes carry a ``seq`` attribute and whose
           edges carry an integer ``weight`` (overlap length), as produced
           by ``build_overlap_graph``.

    Returns:
        List of merged super-sequences, one per extracted path (isolated
        nodes contribute their own sequence unchanged).
    """
    merged_contigs = []
    # Work on a copy so the caller's graph is left untouched.
    G_work = G.copy()

    while G_work.number_of_nodes() > 0:
        # Prefer true path starts (no incoming overlap edge).
        start_nodes = [n for n in G_work.nodes if G_work.in_degree(n) == 0]

        if not start_nodes:
            # Only cycles remain: fall back to the longest remaining contig.
            start_node = max(G_work.nodes, key=lambda n: len(G_work.nodes[n]["seq"]))
        else:
            # Among starts, begin from the longest sequence.
            start_node = max(start_nodes, key=lambda n: len(G_work.nodes[n]["seq"]))

        path = [start_node]
        current = start_node

        # Greedy walk: always follow the heaviest outgoing overlap.
        while True:
            if G_work.out_degree(current) == 0:
                break

            neighbors = list(G_work.successors(current))
            best_next = max(neighbors, key=lambda n: G_work[current][n]["weight"])

            # Stop before revisiting a node, which would create a cycle.
            if best_next in path:
                break

            path.append(best_next)
            current = best_next

        if len(path) == 1:
            # No extension possible: emit the contig as-is.
            merged_contigs.append(G_work.nodes[start_node]["seq"])
        else:
            first_idx = path[0]
            super_seq = G_work.nodes[first_idx]["seq"]

            # Append each successor minus its overlapping prefix.
            for i in range(len(path) - 1):
                u, v = path[i], path[i + 1]
                overlap_len = G_work[u][v]["weight"]
                seq_v = G_work.nodes[v]["seq"]
                super_seq += seq_v[overlap_len:]

            merged_contigs.append(super_seq)

        # Consume the path so each contig is used at most once.
        G_work.remove_nodes_from(path)

    return merged_contigs


def refine_using_overlap_graph(contigs: List[str], min_overlap: int) -> List[str]:
    """Refine assembled contigs with a single overlap-graph merge pass.

    Pipeline: build the suffix/prefix overlap graph, greedily merge the
    heaviest paths into super-sequences, then drop exact duplicates and any
    sequence fully contained in a longer surviving one.

    Args:
        contigs: Candidate contig sequences (may be empty).
        min_overlap: Minimum overlap length used when building the graph.

    Returns:
        De-duplicated, substring-free merged sequences, longest first.
        Empty list when *contigs* is empty.
    """
    if not contigs:
        return []

    G = build_overlap_graph(contigs, min_overlap)

    refined = merge_paths_from_overlap_graph(G)

    # Deduplicate and order longest-first; sorted() accepts the set
    # directly (no intermediate list needed).
    refined = sorted(set(refined), key=len, reverse=True)

    # Keep a sequence only if it is not a proper substring of another one.
    final_set = []
    for seq in refined:
        if not any(seq in other and seq != other for other in refined):
            final_set.append(seq)

    return final_set


def scaffold_iterative_dbgx(
seqs: List[str],
kmer_size: int,
Expand Down Expand Up @@ -588,7 +695,6 @@ def extend_path_dbg(G, contig, k, min_weight=1):
seq = contig
extended = True

# Extension forwards
while extended:
extended = False
suffix = seq[-(k - 1) :]
Expand All @@ -597,7 +703,6 @@ def extend_path_dbg(G, contig, k, min_weight=1):

successors = list(G.successors(suffix))
if len(successors) > 1:
# choose only if there is a clear dominant edge
best_succ, best_w = None, 0
for s in successors:
w = G[suffix][s].get("weight", 0)
Expand All @@ -615,7 +720,6 @@ def extend_path_dbg(G, contig, k, min_weight=1):
seq += nxt[-1]
extended = True

# Extension backwards
extended = True
while extended:
extended = False
Expand Down Expand Up @@ -779,10 +883,9 @@ def assemble_dbg(self, sequences):
return scaffolds

def assemble_dbg_weighted(self, sequences: List[str]) -> List[str]:
logger.info(
f"[Assembler] Running DBG weighted (k={self.kmer_size}, min_weight={self.min_weight}, refine_rounds={self.refine_rounds})"
)
logger.info(f"[Assembler] Running DBG weighted (k={self.kmer_size}, min_weight={self.min_weight})")

# 1. Standard Weighted DBG Assembly
kmers = get_kmers(sequences, self.kmer_size)
if not kmers:
logger.warning("No kmers generated; returning empty result.")
Expand All @@ -799,22 +902,38 @@ def assemble_dbg_weighted(self, sequences: List[str]) -> List[str]:
ranked = rank_contigs_by_score(contigs_cp, self.alpha_len, self.alpha_cov, self.alpha_min)
contigs = [r.seq for r in ranked]

if self.refine_rounds and self.refine_rounds > 0:
contigs = scaffold_iterative_dbgx(
contigs,
kmer_size=self.kmer_size,
size_threshold=self.size_threshold,
min_weight=self.min_weight,
max_rounds=self.refine_rounds,
patience=self.refine_patience,
alpha_len=self.alpha_len,
alpha_cov=self.alpha_cov,
alpha_min=self.alpha_min,
)
logger.info(f"DBG produced {len(contigs)} initial contigs.")

scaffolds = list(contigs)
# 2. OVERLAP GRAPH REFINEMENT (The new logic)
# Only runs if refine_rounds is > 0
if self.refine_rounds > 0:
logger.info("Refining contigs using Overlap Graph (Bird's Eye View)...")

return scaffolds
# Use a slightly safer/larger overlap for this final merge to avoid false positives
# e.g., max(min_overlap, 5) or just self.min_overlap
safe_overlap = max(self.min_overlap, 5)
iteration = 0

while iteration < self.refine_rounds:
iteration += 1
prev_count = len(contigs)

# Core logic
contigs = refine_using_overlap_graph(contigs, min_overlap=safe_overlap)

new_count = len(contigs)

# CONVERGENZA: Se non abbiamo fuso nulla, ci fermiamo.
if new_count >= prev_count:
logger.info(f"Refinement converged at round {iteration}.")
break

logger.info(f"Round {iteration}: reduced to {new_count} scaffolds.")

if iteration >= self.refine_rounds:
logger.warning(f"Refinement stopped hit max rounds limit ({self.refine_rounds}).")

return contigs

def assemble_dbgX(self, sequences):
logger.info(f"[Assembler] Running DBG-Extension (k={self.kmer_size})")
Expand Down Expand Up @@ -1096,6 +1215,7 @@ def main(
chain: str,
min_identity: float,
max_mismatches: int,
refine_rounds: int = 0,
):
"""Main function for standalone assembly."""

Expand Down Expand Up @@ -1138,6 +1258,7 @@ def main(
kmer_size=kmer_size,
min_identity=min_identity,
max_mismatches=max_mismatches,
refine_rounds=refine_rounds,
)

scaffolds = assembler.run(sequences=sequences, df_full=df)
Expand Down Expand Up @@ -1232,7 +1353,11 @@ def cli():
action="store_true",
help="Enables reference-based statistics.",
)

parser.add_argument(
"--refine",
action="store_true",
help="Enables iterative refinement (Overlap Graph) until convergence.",
)
parser.add_argument(
"--chain",
type=str,
Expand All @@ -1254,6 +1379,8 @@ def cli():

args = parser.parse_args()

refine_rounds_val = MAX_REFINE_ROUNDS if args.refine else 0

# in case of non-DBG mode, ignore kmer_size
if args.assembly_mode == "greedy":
args.kmer_size = 0
Expand All @@ -1262,7 +1389,12 @@ def cli():
if args.reference and not args.metadata_json_path:
parser.error("--metadata-json-path is required when --reference is enabled.")

main(**vars(args))
args_dict = vars(args)

if "refine" in args_dict:
del args_dict["refine"]

main(refine_rounds=refine_rounds_val, **args_dict)


if __name__ == "__main__":
Expand Down