Skip to content

Commit 5e1822e

Browse files
Merge pull request #15 from Multiomics-Analytics-Group/benchmark-full-stat
Feat: Add AQS, Precision metric, and iterative Refinement logic - ref…
2 parents 9687eb1 + 9823c1a commit 5e1822e

File tree

1 file changed

+154
-22
lines changed

1 file changed

+154
-22
lines changed

src/instanexus/assembly.py

Lines changed: 154 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
logging.basicConfig(level=logging.INFO)
4343
logger = logging.getLogger(__name__)
4444

45+
MAX_REFINE_ROUNDS = 10
46+
4547

4648
def find_peptide_overlaps(peptides, min_overlap):
4749
"""Finds overlaps between peptide sequences using a greedy approach."""
@@ -526,6 +528,111 @@ def rank_contigs_by_score(
526528
return sorted(scored, key=lambda s: (s.score, s.length), reverse=True)
527529

528530

531+
def build_overlap_graph(contigs: List[str], min_overlap: int) -> nx.DiGraph:
532+
"""
533+
Builds a directed graph where nodes are contigs and edges represent
534+
valid suffix-prefix overlaps.
535+
"""
536+
G = nx.DiGraph()
537+
for i, seq in enumerate(contigs):
538+
G.add_node(i, seq=seq, length=len(seq))
539+
540+
n_contigs = len(contigs)
541+
for i in range(n_contigs):
542+
for j in range(n_contigs):
543+
if i == j:
544+
continue
545+
546+
seq_a = contigs[i]
547+
seq_b = contigs[j]
548+
549+
max_ov = min(len(seq_a), len(seq_b))
550+
if max_ov < min_overlap:
551+
continue
552+
553+
best_overlap = 0
554+
for k in range(max_ov, min_overlap - 1, -1):
555+
if seq_a.endswith(seq_b[:k]):
556+
best_overlap = k
557+
break
558+
559+
if best_overlap > 0:
560+
G.add_edge(i, j, weight=best_overlap)
561+
562+
return G
563+
564+
565+
def merge_paths_from_overlap_graph(G: nx.DiGraph) -> List[str]:
566+
"""
567+
Traverses the overlap graph to find and merge the optimal paths
568+
(heaviest overlaps), resolving branches by prioritizing longer overlaps.
569+
"""
570+
merged_contigs = []
571+
G_work = G.copy()
572+
573+
while G_work.number_of_nodes() > 0:
574+
start_nodes = [n for n in G_work.nodes if G_work.in_degree(n) == 0]
575+
576+
if not start_nodes:
577+
start_node = max(G_work.nodes, key=lambda n: len(G_work.nodes[n]["seq"]))
578+
else:
579+
start_node = max(start_nodes, key=lambda n: len(G_work.nodes[n]["seq"]))
580+
581+
path = [start_node]
582+
current = start_node
583+
584+
while True:
585+
if G_work.out_degree(current) == 0:
586+
break
587+
588+
neighbors = list(G_work.successors(current))
589+
best_next = max(neighbors, key=lambda n: G_work[current][n]["weight"])
590+
591+
if best_next in path:
592+
break
593+
594+
path.append(best_next)
595+
current = best_next
596+
597+
if len(path) == 1:
598+
merged_contigs.append(G_work.nodes[start_node]["seq"])
599+
else:
600+
first_idx = path[0]
601+
super_seq = G_work.nodes[first_idx]["seq"]
602+
603+
for i in range(len(path) - 1):
604+
u, v = path[i], path[i + 1]
605+
overlap_len = G_work[u][v]["weight"]
606+
seq_v = G_work.nodes[v]["seq"]
607+
super_seq += seq_v[overlap_len:]
608+
609+
merged_contigs.append(super_seq)
610+
611+
G_work.remove_nodes_from(path)
612+
613+
return merged_contigs
614+
615+
616+
def refine_using_overlap_graph(contigs: List[str], min_overlap: int) -> List[str]:
617+
"""
618+
Wrapper function: Builds graph -> Merges paths -> Cleans up substrings.
619+
"""
620+
if not contigs:
621+
return []
622+
623+
G = build_overlap_graph(contigs, min_overlap)
624+
625+
refined = merge_paths_from_overlap_graph(G)
626+
627+
refined = sorted(list(set(refined)), key=len, reverse=True)
628+
final_set = []
629+
for seq in refined:
630+
if not any(seq in other and seq != other for other in refined):
631+
final_set.append(seq)
632+
633+
return final_set
634+
635+
529636
def scaffold_iterative_dbgx(
530637
seqs: List[str],
531638
kmer_size: int,
@@ -588,7 +695,6 @@ def extend_path_dbg(G, contig, k, min_weight=1):
588695
seq = contig
589696
extended = True
590697

591-
# Extension forwards
592698
while extended:
593699
extended = False
594700
suffix = seq[-(k - 1) :]
@@ -597,7 +703,6 @@ def extend_path_dbg(G, contig, k, min_weight=1):
597703

598704
successors = list(G.successors(suffix))
599705
if len(successors) > 1:
600-
# choose only if there is a clear dominant edge
601706
best_succ, best_w = None, 0
602707
for s in successors:
603708
w = G[suffix][s].get("weight", 0)
@@ -615,7 +720,6 @@ def extend_path_dbg(G, contig, k, min_weight=1):
615720
seq += nxt[-1]
616721
extended = True
617722

618-
# Extension backwards
619723
extended = True
620724
while extended:
621725
extended = False
@@ -779,10 +883,9 @@ def assemble_dbg(self, sequences):
779883
return scaffolds
780884

781885
def assemble_dbg_weighted(self, sequences: List[str]) -> List[str]:
782-
logger.info(
783-
f"[Assembler] Running DBG weighted (k={self.kmer_size}, min_weight={self.min_weight}, refine_rounds={self.refine_rounds})"
784-
)
886+
logger.info(f"[Assembler] Running DBG weighted (k={self.kmer_size}, min_weight={self.min_weight})")
785887

888+
# 1. Standard Weighted DBG Assembly
786889
kmers = get_kmers(sequences, self.kmer_size)
787890
if not kmers:
788891
logger.warning("No kmers generated; returning empty result.")
@@ -799,22 +902,38 @@ def assemble_dbg_weighted(self, sequences: List[str]) -> List[str]:
799902
ranked = rank_contigs_by_score(contigs_cp, self.alpha_len, self.alpha_cov, self.alpha_min)
800903
contigs = [r.seq for r in ranked]
801904

802-
if self.refine_rounds and self.refine_rounds > 0:
803-
contigs = scaffold_iterative_dbgx(
804-
contigs,
805-
kmer_size=self.kmer_size,
806-
size_threshold=self.size_threshold,
807-
min_weight=self.min_weight,
808-
max_rounds=self.refine_rounds,
809-
patience=self.refine_patience,
810-
alpha_len=self.alpha_len,
811-
alpha_cov=self.alpha_cov,
812-
alpha_min=self.alpha_min,
813-
)
905+
logger.info(f"DBG produced {len(contigs)} initial contigs.")
814906

815-
scaffolds = list(contigs)
907+
# 2. OVERLAP GRAPH REFINEMENT (The new logic)
908+
# Only runs if refine_rounds is > 0
909+
if self.refine_rounds > 0:
910+
logger.info("Refining contigs using Overlap Graph (Bird's Eye View)...")
816911

817-
return scaffolds
912+
# Use a slightly safer/larger overlap for this final merge to avoid false positives
913+
# e.g., max(min_overlap, 5) or just self.min_overlap
914+
safe_overlap = max(self.min_overlap, 5)
915+
iteration = 0
916+
917+
while iteration < self.refine_rounds:
918+
iteration += 1
919+
prev_count = len(contigs)
920+
921+
# Core logic
922+
contigs = refine_using_overlap_graph(contigs, min_overlap=safe_overlap)
923+
924+
new_count = len(contigs)
925+
926+
# CONVERGENZA: Se non abbiamo fuso nulla, ci fermiamo.
927+
if new_count >= prev_count:
928+
logger.info(f"Refinement converged at round {iteration}.")
929+
break
930+
931+
logger.info(f"Round {iteration}: reduced to {new_count} scaffolds.")
932+
933+
if iteration >= self.refine_rounds:
934+
logger.warning(f"Refinement stopped hit max rounds limit ({self.refine_rounds}).")
935+
936+
return contigs
818937

819938
def assemble_dbgX(self, sequences):
820939
logger.info(f"[Assembler] Running DBG-Extension (k={self.kmer_size})")
@@ -1096,6 +1215,7 @@ def main(
10961215
chain: str,
10971216
min_identity: float,
10981217
max_mismatches: int,
1218+
refine_rounds: int = 0,
10991219
):
11001220
"""Main function for standalone assembly."""
11011221

@@ -1138,6 +1258,7 @@ def main(
11381258
kmer_size=kmer_size,
11391259
min_identity=min_identity,
11401260
max_mismatches=max_mismatches,
1261+
refine_rounds=refine_rounds,
11411262
)
11421263

11431264
scaffolds = assembler.run(sequences=sequences, df_full=df)
@@ -1232,7 +1353,11 @@ def cli():
12321353
action="store_true",
12331354
help="Enables reference-based statistics.",
12341355
)
1235-
1356+
parser.add_argument(
1357+
"--refine",
1358+
action="store_true",
1359+
help="Enables iterative refinement (Overlap Graph) until convergence.",
1360+
)
12361361
parser.add_argument(
12371362
"--chain",
12381363
type=str,
@@ -1254,6 +1379,8 @@ def cli():
12541379

12551380
args = parser.parse_args()
12561381

1382+
refine_rounds_val = MAX_REFINE_ROUNDS if args.refine else 0
1383+
12571384
# in case of non-DBG mode, ignore kmer_size
12581385
if args.assembly_mode == "greedy":
12591386
args.kmer_size = 0
@@ -1262,7 +1389,12 @@ def cli():
12621389
if args.reference and not args.metadata_json_path:
12631390
parser.error("--metadata-json-path is required when --reference is enabled.")
12641391

1265-
main(**vars(args))
1392+
args_dict = vars(args)
1393+
1394+
if "refine" in args_dict:
1395+
del args_dict["refine"]
1396+
1397+
main(refine_rounds=refine_rounds_val, **args_dict)
12661398

12671399

12681400
if __name__ == "__main__":

0 commit comments

Comments
 (0)