diff --git a/config/config.yaml b/config/config.yaml index 94e098e..a44368f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,15 +4,6 @@ # a reasonable value might be 4, or 8. cpu_cores: 4 -# Used for scatter/gather processing, corresponds with the number of families within the taxon defined in -# fasta_filter that have records for the specified marker. In practice, this means that the input BCDM TSV -# file has to have exactly this many distinct (not empty) values for the `family` column where the column -# named under fasta_filter.filter_level has has value fasta_filter.filter_name (e.g. the `order` must be -# `Odonata`. TODO: make this so that it is done automatically from the data. At present, this needs to be -# calculated by the user from the input file, e.g. by first making the sqlite database, or by grepping the -# TSV file somehow. -nfamilies: 63 - # Number of outgroups to include in each family-level analysis. Minimum is 2. outgroups: 3 @@ -46,8 +37,13 @@ datatype: NT # Choose which records to use from the database for the pipeline. filter_name only takes one name, so does filter level. # filter levels: class, order, family, genus, all (no filter) fasta_filter: + # max level: family filter_level: kingdom filter_name: Animalia + # Maximum number of sequences per fasta + # if above, fasta files are generated not at the family level but at lower rank i.e., genus or, if still + # too many sequences, at species level + max_fasta_seq: 200 name: phylogeny @@ -61,6 +57,6 @@ file_names: bold_tsv: resources/BOLD_Public.21-Jun-2024.curated.NL.tsv open_tre: resources/opentree/opentree14.9_tree/labelled_supertree/labelled_supertree.tre hmm: resources/hmm/COI-5P.hmm - fasta_dir: results/fasta/family + fasta_dir: results/fasta/ blast_dir: results/blast diff --git a/results/fasta/family/README.md b/results/fasta/family/README.md deleted file mode 100644 index 00f60aa..0000000 --- a/results/fasta/family/README.md +++ /dev/null @@ -1 +0,0 @@ -This is a placeholder for a folder structure where FASTA files, alignments, and trees are written as intermediate results. diff --git a/workflow/Snakefile b/workflow/Snakefile index 7b192e5..9716ce0 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -1,3 +1,5 @@ +import os +from pathlib import Path configfile: "config/config.yaml" # TODO refactor the snakefile so that all dependencies are managed by a containerized mamba. @@ -5,9 +7,29 @@ configfile: "config/config.yaml" # container: "docker://condaforge/mambaforge:23.1.0-1" +wildcard_constraints: + taxon=r"[\w-]+" + +def get_all_taxon(wildcards): + checkpoint_output = checkpoints.family_fasta.get().output[0] + taxon_dir = os.path.join(Path(checkpoint_output).parent, "taxon") + taxon = [ + t for t in os.listdir(taxon_dir) + if os.path.isdir(os.path.join(taxon_dir, t)) + and os.stat(os.path.join(taxon_dir, t, "unaligned.fa")).st_size != 0 # ignore empty files + ] + return taxon + +def get_prep_raxml_backbone_input(wildcards): + return expand( + "results/fasta/taxon/{{taxon}}/exemplars.fa".format(config['datatype']), + taxon=get_all_taxon(wildcards) + ) + rule all: input: - "results/grafted.tre" + "results/passed_taxa.txt", + "results/fasta/raxml-ready.fa.raxml.bestTree" # Creates and populates a SQLite database with filtered sequence records. 
# Uses BOLD dump TSV as defined in config file @@ -82,25 +104,20 @@ rule megatree_loader: rm {params.tempsql} {params.tempdb} """ -# TODO either this is updated dynamically or scattergather is abandoned for dynamic() -scattergather: - split = config["nfamilies"] - -# Exports unaligned BIN sequences for each family within the total set. This task is parallelized as many times as -# specified by config["nfamilies"]. During this task, the longest raw sequence within each BIN is selected. -rule family_fasta: - input: rules.map_opentol.output +# Exports unaligned BIN sequences for each family within the total set. This task is parallelized for each taxon. +# During this task, the longest raw sequence within each BIN is selected. +checkpoint family_fasta: + input: 'results/databases/megatree_loader.ok' output: - fastas=scatter.split("results/fasta/family/{scatteritem}/unaligned.fa"), - status=f'{config["file_names"]["fasta_dir"]}/family_fasta.ok' + fasta_tsv="results/fasta/taxon_fasta.tsv", params: log_level=config['log_level'], fasta_dir=config["file_names"]["fasta_dir"], filter_level=config["fasta_filter"]["filter_level"], filter_name=config["fasta_filter"]["filter_name"], + max_seq_fasta=config["fasta_filter"]["max_fasta_seq"], marker=config["marker"], db="results/databases/BOLD_{}_barcodes.db".format(config["marker"]), - chunks=config["nfamilies"], conda: "envs/family_fasta.yml" log: "logs/family_fasta/family_fasta.log" benchmark: @@ -112,10 +129,9 @@ rule family_fasta: -f {params.fasta_dir} \ -l {params.filter_level} \ -n {params.filter_name} \ - -c {params.chunks} \ + -L {params.max_seq_fasta} \ -m {params.marker} \ -v {params.log_level} 2> {log} - touch {output.status} """ # Creates a local BLAST database from the exported sequences that have OTT IDs. Later on, this database will be used @@ -123,59 +139,106 @@ rule family_fasta: # hopefully mean that even monotypic families will have a constraint tree with >= 3 tips so all chunks can be treated # identically. rule makeblastdb: - input: rules.family_fasta.output.status - output: f'{config["file_names"]["blast_dir"]}/{config["blastdb"]}.nsq' + input: + fasta_dir="results/fasta/taxon" + output: + database="results/blast/blastdb.nsq" params: - chunks=config["nfamilies"], - fasta_dir=config["file_names"]["fasta_dir"], - concatenated=f'{config["file_names"]["blast_dir"]}/{config["blastdb"]}.fa', - blastdb=f'{config["file_names"]["blast_dir"]}/{config["blastdb"]}' + tmp_file="results/blast/tmp_sequences.fa", + database="results/blast/blastdb" conda: "envs/blast.yml" log: "logs/makeblastdb/makeblastdb.log" - benchmark: - "benchmarks/makeblastdb.benchmark.txt" + benchmark: "benchmarks/makeblastdb.benchmark.txt" shell: """ - sh workflow/scripts/makeblastdb.sh \ - {params.blastdb} {params.fasta_dir} {params.chunks} {params.concatenated} 2> {log} + python workflow/scripts/makeblastdb.py \ + --fasta_dir {input.fasta_dir} \ + --tmp_file {params.tmp_file} \ + --database {params.database} \ + > {log} 2>&1 """ - - # Gets the nearest outgroups by blasting. The number of outgroups is defined in config["outgroups"]. The outgroups are # selected by querying every sequence in the ingroup for the top 10 hits, then taking the most frequently occurring # hits across all queries. 
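The hit-counting strategy described in the comment above amounts to tallying BLAST subjects across all ingroup queries and keeping the most frequent ones. A minimal sketch of that idea, assuming BLAST tabular output (-outfmt 6, subject ID in column 2); the function and file names are illustrative only, since the rule below delegates the real work to workflow/scripts/get_outgroups.sh:

from collections import Counter

def pick_outgroups(blast_tsv, n_outgroups, ingroup_ids):
    # count how often each subject sequence is recovered across all ingroup queries
    hits = Counter()
    with open(blast_tsv) as fh:
        for line in fh:
            qseqid, sseqid = line.rstrip("\n").split("\t")[:2]
            if sseqid not in ingroup_ids:  # skip hits back into the ingroup itself
                hits[sseqid] += 1
    # the most frequently recovered subjects become the outgroup candidates
    return [sid for sid, _ in hits.most_common(n_outgroups)]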
rule get_outgroups: input: - unaligned = "results/fasta/family/{scatteritem}/unaligned.fa", + unaligned = "results/fasta/taxon/{taxon}/unaligned.fa", makeblastdb = rules.makeblastdb.output - output: "results/fasta/family/{scatteritem}/outgroups.fa", + output: "results/fasta/taxon/{taxon}/outgroups.fa", params: outgroups = config["outgroups"], blastdb=f'{config["file_names"]["blast_dir"]}/{config["blastdb"]}' conda: "envs/blast.yml" - log: "logs/get_outgroups/get_outgroups-{scatteritem}.log" + log: "logs/get_outgroups/get_outgroups-{taxon}.log" benchmark: - "benchmarks/get_outgroups/get_outgroups-{scatteritem}.benchmark.txt" + "benchmarks/get_outgroups/get_outgroups-{taxon}.benchmark.txt" shell: """ sh workflow/scripts/get_outgroups.sh {params.blastdb} {params.outgroups} {input.unaligned} {output} 2> {log} """ +# Count the number of sequences in each family.fasta file. +rule count_sequences: + input: + "results/fasta/taxon/{taxon}/unaligned.fa", + "results/fasta/taxon/{taxon}/outgroups.fa" + output: "results/fasta/taxon/{taxon}/sequences_count.txt" + run: + count = sum(1 for line in open(input[0]) if line.startswith(">")) + with open(output[0], "w") as f: + f.write(f"{count}\n") + +def get_list_taxa_file_input(wildcards): + return expand("results/fasta/taxon/{taxon}/sequences_count.txt", taxon=get_all_taxon(wildcards)) + +# Generate a list of taxa that passed the condition of having at least 3 sequences. +checkpoint list_taxa_files: + input: get_list_taxa_file_input + output: + passed="results/passed_taxa.txt" + run: + passed_taxa = [] + + for file in input: + taxon = file.split("/")[3] # Extract taxon from path + with open(file, "r") as f: + num_sequences = int(f.read().strip()) + if num_sequences >= 3: + passed_taxa.append(taxon) + + with open(output.passed, "w") as passed_f: + for taxon in passed_taxa: + passed_f.write(f"{taxon}\n") + +# Definition that gets the list of taxa that passed the condition from the checkpoint output. +def get_passed_taxa(wildcards): + checkpoint_output = checkpoints.list_taxa_files.get().output.passed + with open(checkpoint_output, "r") as passed_f: + passed_taxa = {line.strip() for line in passed_f.readlines()} + return passed_taxa + +# Definition that gets the list of taxa that failed the condition from the checkpoint output. +def get_failed_taxa(wildcards): + checkpoint_output = checkpoints.list_taxa_files.get().output.passed + with open(checkpoint_output, "r") as passed_f: + passed_taxa = {line.strip() for line in passed_f.readlines()} + failed_taxa = {taxon for taxon in get_all_taxon(wildcards) if taxon not in passed_taxa} + return failed_taxa # Exports OpenToL newick file for each unaligned BIN sequence file. This implementation uses the induced subtree web # service endpoint from OpenToL. The new implementation appears to achieve better coverage. 
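The induced-subtree lookup mentioned in the comment above is served by the public OpenToL v3 API. A rough illustration of the call, assuming the documented endpoint URL and JSON payload; the pipeline wraps this in its own opentol helper, which also deals with pruned or unknown OTT IDs:

import requests

def induced_subtree(ott_ids):
    # POST the OTT IDs to the induced_subtree endpoint; the response carries a Newick string
    resp = requests.post(
        "https://api.opentreeoflife.org/v3/tree_of_life/induced_subtree",
        json={"ott_ids": list(ott_ids)},
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["newick"]  # tip labels are ott<ID> strings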
rule family_constraint: input: - unaligned = "results/fasta/family/{scatteritem}/unaligned.fa", - outgroups = "results/fasta/family/{scatteritem}/outgroups.fa" - output: "results/fasta/family/{scatteritem}/constraint.tre" + unaligned = "results/fasta/taxon/{taxon}/unaligned.fa", + outgroups = "results/fasta/taxon/{taxon}/outgroups.fa" + output: "results/fasta/taxon/{taxon}/constraint.tre" params: log_level=config['log_level'], db="results/databases/BOLD_{}_barcodes.db".format(config["marker"]) conda: "envs/family_constraint.yml" - log: "logs/family_constraint/family_constraint-{scatteritem}.log" + log: "logs/family_constraint/family_constraint-{taxon}.log" benchmark: - "benchmarks/family_constraint/family_constraint-{scatteritem}.benchmark.txt" + "benchmarks/family_constraint/family_constraint-{taxon}.benchmark.txt" shell: """ python workflow/scripts/family_constraint.py \ @@ -190,17 +253,17 @@ rule family_constraint: # is not needed at all because so far 0 revcom sequences were observed in BOLD. rule msa_hmm: input: - ingroup="results/fasta/family/{scatteritem}/unaligned.fa", - outgroup="results/fasta/family/{scatteritem}/outgroups.fa" - output: "results/fasta/family/{scatteritem}/aligned.fa" + ingroup="results/fasta/taxon/{taxon}/unaligned.fa", + outgroup="results/fasta/taxon/{taxon}/outgroups.fa" + output: "results/fasta/taxon/{taxon}/aligned.fa" params: log_level=config['log_level'], hmm_file=config['file_names']['hmm'], db="results/databases/BOLD_{}_barcodes.db".format(config["marker"]) conda: "envs/msa_hmm.yml" - log: "logs/msa_hmm/msa_hmm-{scatteritem}.log" + log: "logs/msa_hmm/msa_hmm-{taxon}.log" benchmark: - "benchmarks/msa_hmm/msa_hmm-{scatteritem}.benchmark.txt" + "benchmarks/msa_hmm/msa_hmm-{taxon}.benchmark.txt" shell: """ python workflow/scripts/msa_hmm.py \ @@ -221,17 +284,17 @@ rule msa_hmm: # criterion should be that the candidate outgroups are in the OpenToL. rule prep_raxml: input: - alignment="results/fasta/family/{scatteritem}/aligned.fa", - tree="results/fasta/family/{scatteritem}/constraint.tre" + alignment="results/fasta/taxon/{taxon}/aligned.fa", + tree="results/fasta/taxon/{taxon}/constraint.tre" output: - tree="results/fasta/family/{scatteritem}/remapped.tre" + tree="results/fasta/taxon/{taxon}/remapped.tre" params: db="results/databases/BOLD_{}_barcodes.db".format(config["marker"]), log_level=config['log_level'] conda: "envs/prep_raxml.yml" - log: "logs/prep_raxml/prep_raxml-{scatteritem}.log" + log: "logs/prep_raxml/prep_raxml-{taxon}.log" benchmark: - "benchmarks/prep_raxml/prep_raxml-{scatteritem}.benchmark.txt" + "benchmarks/prep_raxml/prep_raxml-{taxon}.benchmark.txt" shell: """ python workflow/scripts/prep_raxml.py \ @@ -242,49 +305,31 @@ rule prep_raxml: -d {params.db} 2> {log} """ -# Runs raxml as a constrained tree search. Possibly this should instead take a two-step approach, where unplaced -# sequences are first placed, and then the branch lengths are estimated. This requires that there actually is a -# constraint tree. Currently, this rule deals with the fact that some of the constraint trees are 0-byte files due -# to incomplete OpenToL coverage by trapping raxml-ng errors and then re-running without the constraint. +# Runs raxml as a constrained tree search. This rule first checks if the alignment file exists and is non-empty. +# It then extracts the outgroup names from the alignment file. If the constraint tree file exists, is non-empty, +# contains inner nodes, and is not fully resolved, RAxML-NG is run with the constraint tree. 
If the constraint tree +# is fully-resolved, RAxML-NG is run without it. The output is the best tree from the RAxML-NG run. rule run_raxml: input: - alignment = "results/fasta/family/{scatteritem}/aligned.fa", - tree = "results/fasta/family/{scatteritem}/remapped.tre" + alignment="results/fasta/taxon/{taxon}/aligned.fa", + tree="results/fasta/taxon/{taxon}/remapped.tre" output: - tree = "results/fasta/family/{scatteritem}/aligned.fa.raxml.bestTree" + tree="results/fasta/taxon/{taxon}/aligned.fa.raxml.bestTree" params: - model = config['model'], - num_outgroups= config['outgroups'] - log: "logs/run_raxml/run_raxml-{scatteritem}.log" - benchmark: - "benchmarks/run_raxml/run_raxml-{scatteritem}.benchmark.txt" + model=config['model'], + num_outgroups=config['outgroups'] + log: "logs/run_raxml/run_raxml-{taxon}.log" + benchmark: "benchmarks/run_raxml/run_raxml-{taxon}.benchmark.txt" conda: "envs/raxml.yml" shell: """ - if [ -s {input.alignment} ]; then - OG=$(grep '>' {input.alignment} | tail -{params.num_outgroups} | sed -e 's/>//' | tr '\n' ',') - if [ -s {input.tree} ]; then - set -e - raxml-ng \ - --redo \ - --msa {input.alignment} \ - --outgroup $OG \ - --model {params.model} \ - --tree-constraint {input.tree} \ - --search1 > {log} 2>&1 \ - || \ - raxml-ng \ - --redo \ - --msa {input.alignment} \ - --outgroup $OG \ - --model {params.model} \ - --search1 >> {log} 2>&1 - else - raxml-ng --msa {input.alignment} --outgroup $OG --model {params.model} --search --redo > {log} 2>&1 - fi - else - touch {output.tree} - fi + python workflow/scripts/run_raxml.py \ + --alignment {input.alignment} \ + --tree {input.tree} \ + --output {output.tree} \ + --model {params.model} \ + --num_outgroups {params.num_outgroups} \ + --log_file {log} """ # Reroots the raxml output in and then prunes the outgroups. Rooting is done by finding the smallest clade subtended @@ -292,17 +337,17 @@ rule run_raxml: # clade members and then rooting on that bipartition branch. Subsequently, the outgroup taxa are removed. rule reroot_raxml_output: input: - tree = "results/fasta/family/{scatteritem}/aligned.fa.raxml.bestTree", - constraint = "results/fasta/family/{scatteritem}/remapped.tre", - alignment = "results/fasta/family/{scatteritem}/aligned.fa" + tree = "results/fasta/taxon/{taxon}/aligned.fa.raxml.bestTree", + constraint = "results/fasta/taxon/{taxon}/remapped.tre", + alignment = "results/fasta/taxon/{taxon}/aligned.fa" output: - outtree = "results/fasta/family/{scatteritem}/aligned.fa.raxml.bestTree.rooted" + outtree = "results/fasta/taxon/{taxon}/aligned.fa.raxml.bestTree.rooted" params: log_level = config['log_level'], num_outgroups = config['outgroups'] - log: "logs/reroot_raxml_output/reroot_raxml_output-{scatteritem}.log" + log: "logs/reroot_raxml_output/reroot_raxml_output-{taxon}.log" benchmark: - "benchmarks/reroot_raxml_output/reroot_raxml_output-{scatteritem}.benchmark.txt" + "benchmarks/reroot_raxml_output/reroot_raxml_output-{taxon}.benchmark.txt" conda: "envs/reroot_backbone.yml" shell: """ @@ -325,28 +370,44 @@ rule reroot_raxml_output: # 3. the ones closest to the median root-to-tip path length # Empirically, the first option yields rescaled branch lengths that are closest # to those optimized freely on a total tree. +# The exemplar choosing script is run only if the alignment and tree files exist. +# Otherwise, the outgroups are removed from the input aligned file, which is then used as the output. 
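For the third option in the list above, the 'median' strategy that the choose_exemplars rule below selects, the choice is driven by root-to-tip path lengths. A simplified sketch with dendropy, assuming a rooted Newick tree and ignoring the per-subtree bookkeeping the real choose_exemplars.py performs; names are illustrative:

import statistics
import dendropy

def median_depth_tips(tree_file, n=2):
    tree = dendropy.Tree.get(path=tree_file, schema="newick")
    depths = {}
    for leaf in tree.leaf_node_iter():
        # root-to-tip path length = sum of edge lengths on the walk up to the root
        d, node = 0.0, leaf
        while node.parent_node is not None:
            d += node.edge.length or 0.0
            node = node.parent_node
        depths[leaf.taxon.label] = d
    median = statistics.median(depths.values())
    # keep the tips whose depth sits closest to the median path length
    return sorted(depths, key=lambda t: abs(depths[t] - median))[:n]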
rule choose_exemplars: input: - alignment = "results/fasta/family/{scatteritem}/aligned.fa", - tree= "results/fasta/family/{scatteritem}/aligned.fa.raxml.bestTree.rooted" - output: "results/fasta/family/{scatteritem}/exemplars.fa" + alignment=lambda wildcards: f"results/fasta/taxon/{wildcards.taxon}/aligned.fa" if wildcards.taxon in get_passed_taxa(wildcards) else [], + tree=lambda wildcards: f"results/fasta/taxon/{wildcards.taxon}/aligned.fa.raxml.bestTree.rooted" if wildcards.taxon in get_passed_taxa(wildcards) else [], + failed_file=lambda wildcards: f"results/fasta/taxon/{wildcards.taxon}/aligned.fa" if wildcards.taxon in get_failed_taxa(wildcards) else [], + failed_outgroup=lambda wildcards: f"results/fasta/taxon/{wildcards.taxon}/outgroups.fa" if wildcards.taxon in get_failed_taxa(wildcards) else [] + output: + "results/fasta/taxon/{taxon}/exemplars.fa" params: - log_level = config['log_level'], - strategy = 'median' - log: "logs/choose_exemplars/choose_exemplars-{scatteritem}.log" + log_level=config['log_level'], + strategy='median' + log: + "logs/choose_exemplars/choose_exemplars-{taxon}.log" benchmark: - "benchmarks/choose_exemplars/choose_exemplars-{scatteritem}.benchmark.txt" - conda: "envs/choose_exemplars.yaml" + "benchmarks/choose_exemplars/choose_exemplars-{taxon}.benchmark.txt" + conda: + "envs/choose_exemplars.yaml" shell: """ - if [ -s {input.alignment} ]; then + # If the failed file exists and is not empty, run the process_failed_file.py script + if [ -n "{input.failed_file}" ] && [ -s {input.failed_file} ]; then + echo "Failed file exists, processing with process_failed_file.py." >> {log} + python workflow/scripts/process_failed_file.py \ + --failed {input.failed_file} \ + --outgroup {input.failed_outgroup} \ + --output {output} 2>> {log} + # If the failed file doesn't exist and the tree file is available, run the choose_exemplars.py script + elif [ -s {input.tree} ] && [ -s {input.alignment} ]; then python workflow/scripts/choose_exemplars.py \ -v {params.log_level} \ -t {input.tree} \ -i {input.alignment} \ -s {params.strategy} \ - -o {output} 2> {log} + -o {output} 2>> {log} else + echo "No valid input files found for taxon {wildcards.taxon}." >> {log} touch {output} fi """ @@ -358,13 +419,14 @@ rule choose_exemplars: # data on the command line, awk turns it into FASTA. rule prep_raxml_backbone: input: - fastas=gather.split("results/fasta/family/{scatteritem}/exemplars.fa"), + fastas=get_prep_raxml_backbone_input, opentol='results/databases/megatree_loader.ok' output: fasta="results/fasta/raxml-ready.fa", tree="results/fasta/raxml-ready.tre", extinct="results/fasta/extinct_pids.txt" params: + input_list="results/fasta/exemplars_list.txt", db="results/databases/BOLD_{}_barcodes.db".format(config["marker"]), log_level = config['log_level'], hmm_file=config['file_names']['hmm'] @@ -372,35 +434,25 @@ rule prep_raxml_backbone: conda: "envs/prep_raxml_backbone.yml" shell: """ - # Generate constraint tree and list of extinct PIDs - python workflow/scripts/backbone_constraint.py \ + ls results/fasta/taxon/*/exemplars.fa > results/fasta/exemplars_list.txt + + python workflow/scripts/prep_raxml_backbone.py \ -d {params.db} \ -v {params.log_level} \ - -i '{input.fastas}' \ + -i {params.input_list} \ -e {output.extinct} \ - -o {output.tree} 2> {log} - - # Clean the concatenated FASTA by removing gaps (dashes) - sed '/^>/! 
s/-//g' {input.fastas} > results/fasta/unaligned.fa - - # Align with hmmalign and output in Stockholm format - hmmalign --trim --dna --informat FASTA --outformat Stockholm {params.hmm_file} results/fasta/unaligned.fa > results/fasta/aligned.sto - - # Convert the Stockholm alignment to a non-interleaved FASTA format for RAxML - seqmagick convert results/fasta/aligned.sto {output.fasta} - - # Remove any extinct PIDs - [ -e {output.extinct} ] && seqmagick mogrify --exclude-from-file {output.extinct} {output.fasta} + -o {output.tree} \ + -hmm {params.hmm_file} \ + -f {output.fasta} 2> {log} """ - # Constructs the backbone topology under ML using raxml-ng with a tree constraint that # is intended to keep all pairs of exemplars monophyletic. Here, no outgroups are # specified, because in principle this step could be dealing with the entire # taxonomic width of BOLD. Instead, the tree will be rooted using the constraint. rule run_raxml_backbone: input: - alignment = "results/fasta/raxml-ready.fa", + alignment = "results/fasta/raxml-ready.fa.raxml.reduced.phy", tree = "results/fasta/raxml-ready.tre" output: tree = "results/fasta/raxml-ready.fa.raxml.bestTree" @@ -412,12 +464,11 @@ rule run_raxml_backbone: conda: "envs/raxml.yml" shell: """ - raxml-ng \ - --redo \ - --msa {input.alignment} \ - --model {params.model} \ - --tree-constraint {input.tree} \ - --search > {log} 2>&1 + python workflow/scripts/run_raxml_backbone.py \ + {input.alignment} \ + {input.tree} \ + {params.model} \ + {log} """ # Reroots the backbone using the constraint tree. The basic method is to find the @@ -446,7 +497,7 @@ rule reroot_backbone: -v {params.log_level} 2> {log} """ -# Grafts the individual family clades onto the backbone. This is done by taking the +# Grafts the individual taxon clades onto the backbone. This is done by taking the # backbone tree and replacing the tips that are in the clade with the clade tree. 
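The tip-replacement step described in the comment above can be pictured with dendropy: locate the backbone node subtended by a clade's exemplars and swap in that clade's subtree. This is only a sketch under simplifying assumptions (shared taxon namespace, clade not spanning the backbone root, placeholder file names and tip labels); the real logic, including extinct-PID handling, lives in workflow/scripts/graft_clades.py:

import dendropy

def graft(backbone, clade_tree, exemplar_labels):
    # find the backbone clade spanned by the exemplar tips and replace it wholesale
    mrca = backbone.mrca(taxon_labels=exemplar_labels)
    parent = mrca.parent_node  # assumes the exemplars do not span the backbone root
    parent.remove_child(mrca)
    parent.add_child(clade_tree.seed_node)
    return backbone

ns = dendropy.TaxonNamespace()  # both trees must share one namespace for mrca() lookups
backbone = dendropy.Tree.get(path="backbone.tre", schema="newick", taxon_namespace=ns)
clade = dendropy.Tree.get(path="clade.tre", schema="newick", taxon_namespace=ns)
graft(backbone, clade, ["exemplar_A", "exemplar_B"])  # placeholder labels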
rule graft_clades: input: @@ -457,7 +508,6 @@ rule graft_clades: tree = "results/grafted.tre" params: log_level = config['log_level'], - nfamilies = config["nfamilies"] log: "logs/graft_clades/graft_clades.log" benchmark: "benchmarks/graft_clades.benchmark.txt" @@ -469,6 +519,5 @@ rule graft_clades: -f {input.clades} \ -e {input.extinct} \ -o {output.tree} \ - -n {params.nfamilies} \ -v {params.log_level} 2> {log} - """ + """ \ No newline at end of file diff --git a/workflow/envs/blast.txt b/workflow/envs/blast.txt index 904040e..7d9e092 100644 --- a/workflow/envs/blast.txt +++ b/workflow/envs/blast.txt @@ -1 +1,4 @@ -biopython \ No newline at end of file +os +subprocess +argparse +re \ No newline at end of file diff --git a/workflow/envs/choose_exemplars.txt b/workflow/envs/choose_exemplars.txt index 2f988d8..7755d47 100644 --- a/workflow/envs/choose_exemplars.txt +++ b/workflow/envs/choose_exemplars.txt @@ -1,3 +1,4 @@ -db-sqlite3 biopython -dendropy \ No newline at end of file +argparse +dendropy +statistics \ No newline at end of file diff --git a/workflow/envs/create_database.txt b/workflow/envs/create_database.txt index 3cb1ac1..a6f8136 100644 --- a/workflow/envs/create_database.txt +++ b/workflow/envs/create_database.txt @@ -1 +1,3 @@ -db-sqlite3 \ No newline at end of file +util +argparse +subprocess \ No newline at end of file diff --git a/workflow/envs/family_constraint.txt b/workflow/envs/family_constraint.txt index 9e41ed0..8de6a20 100644 --- a/workflow/envs/family_constraint.txt +++ b/workflow/envs/family_constraint.txt @@ -1,2 +1,3 @@ -requests -dendropy \ No newline at end of file +argparse +util +opentol \ No newline at end of file diff --git a/workflow/envs/family_constraint.yml b/workflow/envs/family_constraint.yml index 23d493b..c135d9f 100644 --- a/workflow/envs/family_constraint.yml +++ b/workflow/envs/family_constraint.yml @@ -4,6 +4,7 @@ channels: - bioconda - defaults dependencies: + - sqlite - pip - pip: - -r ../../workflow/envs/family_constraint.txt diff --git a/workflow/envs/family_fasta.txt b/workflow/envs/family_fasta.txt index 547d7c1..36b0229 100644 --- a/workflow/envs/family_fasta.txt +++ b/workflow/envs/family_fasta.txt @@ -1,2 +1,5 @@ -db-sqlite3 -pandas \ No newline at end of file +pandas +argparse +util +re +os \ No newline at end of file diff --git a/workflow/envs/family_fasta.yml b/workflow/envs/family_fasta.yml index 11ed8ab..483873d 100644 --- a/workflow/envs/family_fasta.yml +++ b/workflow/envs/family_fasta.yml @@ -4,6 +4,7 @@ channels: - bioconda - defaults dependencies: + - sqlite - pip - pip: - -r ../../workflow/envs/family_fasta.txt diff --git a/workflow/envs/graft_clades.txt b/workflow/envs/graft_clades.txt index 4e6ab3f..02dbd63 100644 --- a/workflow/envs/graft_clades.txt +++ b/workflow/envs/graft_clades.txt @@ -1,2 +1,4 @@ +argparse dendropy -biopython \ No newline at end of file +os +util \ No newline at end of file diff --git a/workflow/envs/map_opentol.txt b/workflow/envs/map_opentol.txt index 34bd7ac..32bb499 100644 --- a/workflow/envs/map_opentol.txt +++ b/workflow/envs/map_opentol.txt @@ -1,4 +1,6 @@ -db-sqlite3 -pandas -numpy -requests \ No newline at end of file +biopython +logging +tempfile +argparse +os +util \ No newline at end of file diff --git a/workflow/envs/prep_raxml.txt b/workflow/envs/prep_raxml.txt index 5301f42..22a8a64 100644 --- a/workflow/envs/prep_raxml.txt +++ b/workflow/envs/prep_raxml.txt @@ -1,2 +1,4 @@ +argparse +os biopython -db-sqlite3 \ No newline at end of file +util \ No newline at end of file diff --git 
a/workflow/envs/prep_raxml.yml b/workflow/envs/prep_raxml.yml index 50b760e..33254b0 100644 --- a/workflow/envs/prep_raxml.yml +++ b/workflow/envs/prep_raxml.yml @@ -4,6 +4,7 @@ channels: - bioconda - defaults dependencies: + - sqlite - pip - pip: - - -r ../../workflow/envs/prep_raxml.txt + - -r ../../workflow/envs/prep_raxml.txt \ No newline at end of file diff --git a/workflow/envs/prep_raxml_backbone.txt b/workflow/envs/prep_raxml_backbone.txt index a8c2028..e913023 100644 --- a/workflow/envs/prep_raxml_backbone.txt +++ b/workflow/envs/prep_raxml_backbone.txt @@ -1,4 +1,6 @@ -db-sqlite3 -biopython -requests -dendropy \ No newline at end of file +argparse +subprocess +util +opentol +dendropy +biopython \ No newline at end of file diff --git a/workflow/envs/reroot_backbone.txt b/workflow/envs/reroot_backbone.txt index 9d9cb53..91c8719 100644 --- a/workflow/envs/reroot_backbone.txt +++ b/workflow/envs/reroot_backbone.txt @@ -1 +1,3 @@ -dendropy \ No newline at end of file +argparse +dendropy +util \ No newline at end of file diff --git a/workflow/scripts/check_reverse_sequences.py b/workflow/scripts/check_reverse_sequences.py new file mode 100644 index 0000000..b698c2d --- /dev/null +++ b/workflow/scripts/check_reverse_sequences.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +import argparse +import tempfile +import os +import logging +import sys +import concurrent.futures +from functools import partial +from copy import deepcopy +from Bio import SeqIO +from Bio.AlignIO import read as read_alignment +from subprocess import run +from tqdm import tqdm # Add this for progress bars + + +def setup_logger(name, level_str): + """Set up and return a logger with the specified name and level""" + level = getattr(logging, level_str.upper()) + logger = logging.getLogger(name) + logger.setLevel(level) + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + +def align_score(record, hmmfile, logger): + """ + Uses a Hidden Markov Model to align a sequence using hmmalign. + Returns the score based on posterior probabilities. + """ + try: + # Make a copy of the record to avoid modifying the original + clean_record = deepcopy(record) + + # Remove dashes from the sequence + clean_record.seq = clean_record.seq.replace('-', '') + + # Use unique prefix for temp files to ensure thread safety + prefix = f"seq_{record.id.replace('|', '_')}_" + + # Open temporary files for the sequence and alignment + with tempfile.NamedTemporaryFile(mode='w+', prefix=prefix) as temp_fasta, \ + tempfile.NamedTemporaryFile(mode='w+', prefix=prefix) as temp_stockholm: + + # Save the cleaned sequence to the temporary file + SeqIO.write(clean_record, temp_fasta.name, 'fasta') + + # Run hmm align, read the aligned sequence + run(['hmmalign', '--trim', '-o', temp_stockholm.name, hmmfile, temp_fasta.name], check=True) + alignment = read_alignment(temp_stockholm.name, "stockholm") + + # Get posterior probability string + quality_string = alignment.column_annotations.get('posterior_probability', '') + + count = 0 + # Count . and * characters + dot_count = quality_string.count('.') + star_count = quality_string.count('*') + + # Give value 0 to . 
and value 10 to * + count += star_count * 10 + + # Add all numbers + digit_sum = sum(int(char) for char in quality_string if char.isdigit()) + count += digit_sum + + # Calculate average count + average_count = count / len(quality_string) if quality_string else 0 + return average_count, alignment[0] + + except Exception as e: + logger.error(f"Error in align_score for {record.id}: {e}") + return 0, None + +def process_sequence(record, hmmfile, logger): + """ + Process a single sequence, testing both orientations. + Returns the record in the correct orientation and whether it was reversed. + """ + # Score original orientation + count_fwd, alignment_fwd = align_score(record, hmmfile, logger) + logger.debug(f'Forward alignment score {count_fwd} for {record.id}') + + # Make a copy for reverse complement + rev_record = deepcopy(record) + rev_record.seq = rev_record.seq.reverse_complement() + + # Score reverse orientation + count_rev, alignment_rev = align_score(rev_record, hmmfile, logger) + logger.debug(f'Reverse alignment score {count_rev} for {record.id}') + + # Keep the orientation with higher score + if count_fwd >= count_rev: + logger.debug(f'Keeping forward orientation for {record.id}') + return record, False + else: + logger.debug(f'Keeping reverse orientation for {record.id}') + return rev_record, True + +def correct_revcom(hmmfile, sequences, logger, threads=4): + """ + Check each sequence with hmmalign in both orientations + and keep the orientation with the higher score. + Multi-threaded version with progress reporting. + """ + corrected_seqs = [] + reversed_count = 0 + + # Create a partial function for processing sequences + process_func = partial(process_sequence, hmmfile=hmmfile, logger=logger) + + # Process sequences in parallel with progress bar + logger.info(f"Starting parallel processing with {threads} threads") + + # Using ThreadPoolExecutor with tqdm for progress tracking + with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor: + # Submit all tasks and get futures + futures = [executor.submit(process_func, seq) for seq in sequences] + + # Process results as they complete with progress bar + for i, future in enumerate(tqdm(concurrent.futures.as_completed(futures), + total=len(futures), + desc="Processing sequences")): + record, is_reversed = future.result() + corrected_seqs.append(record) + if is_reversed: + reversed_count += 1 + + # Log progress periodically (e.g., every 1000 sequences) + if (i + 1) % 1000 == 0 or i == 0: + logger.info(f"Processed {i + 1}/{len(sequences)} sequences, {reversed_count} reversed so far") + + logger.info(f'Corrected {reversed_count} reverse complemented sequences out of {len(sequences)}') + return corrected_seqs + +def main(): + parser = argparse.ArgumentParser(description='Check and correct reverse complemented sequences using HMM') + parser.add_argument('fasta', help='Input FASTA file') + parser.add_argument('hmm', help='HMM model file') + parser.add_argument('-o', '--output', help='Output FASTA file (default: corrected_output.fa)', + default="corrected_output.fa") + parser.add_argument('-v', '--verbosity', default='INFO', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + help='Log level (default: INFO)') + parser.add_argument('-t', '--threads', type=int, default=4, + help='Number of threads to use (default: 4)') + parser.add_argument('-l', '--log', help='Log file (optional, if not specified logs to stdout)') + + args = parser.parse_args() + + # Set up logging to file if specified + logger = setup_logger('check_reverse', 
args.verbosity) + if args.log: + file_handler = logging.FileHandler(args.log) + file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + logger.addHandler(file_handler) + + # Read sequences + sequences = list(SeqIO.parse(args.fasta, "fasta")) + logger.info(f"Read {len(sequences)} sequences from {args.fasta}") + + # Correct sequences with multiple threads + corrected_sequences = correct_revcom(args.hmm, sequences, logger, args.threads) + + # Write corrected sequences + SeqIO.write(corrected_sequences, args.output, "fasta") + logger.info(f"Wrote {len(corrected_sequences)} corrected sequences to {args.output}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/workflow/scripts/check_tree_alignment.py b/workflow/scripts/check_tree_alignment.py new file mode 100644 index 0000000..24c7877 --- /dev/null +++ b/workflow/scripts/check_tree_alignment.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 + +import sys +import os +import time +import concurrent.futures +import multiprocessing +from Bio import SeqIO +import dendropy +from tqdm import tqdm + +def process_alignment_chunk(chunk_file): + """Process a chunk of the alignment file and return the set of taxa IDs.""" + taxa = set() + try: + with open(chunk_file, "r") as f: + for record in SeqIO.parse(f, "fasta"): + taxa.add(record.id) + except Exception as e: + print(f"Error processing chunk {chunk_file}: {str(e)}") + return taxa + +def process_tree_nodes(nodes): + """Process a subset of tree nodes and return the set of taxa labels.""" + return {node.taxon.label for node in nodes if node.taxon is not None} + +def split_fasta_file(input_file, num_chunks): + """Split a FASTA file into approximately equal chunks.""" + # Count sequences and determine chunk size + seq_count = 0 + with open(input_file, "r") as f: + for line in f: + if line.startswith('>'): + seq_count += 1 + + if seq_count == 0: + return [] + + # Adjust number of chunks if there are fewer sequences than requested chunks + num_chunks = min(num_chunks, seq_count) + seqs_per_chunk = max(1, seq_count // num_chunks) + + # Create chunks + chunk_files = [] + current_chunk = [] + current_count = 0 + current_chunk_idx = 0 + + with open(input_file, "r") as f: + current_seq = [] + for line in f: + if line.startswith('>'): + if current_seq: + current_chunk.extend(current_seq) + current_seq = [] + current_count += 1 + + # Start a new chunk if needed + if current_count >= seqs_per_chunk and current_chunk_idx < num_chunks - 1: + chunk_file = f"{input_file}.chunk_{current_chunk_idx}" + with open(chunk_file, "w") as chunk_out: + chunk_out.writelines(current_chunk) + chunk_files.append(chunk_file) + current_chunk = [] + current_count = 0 + current_chunk_idx += 1 + + current_seq.append(line) + + # Add the last sequence if any + if current_seq: + current_chunk.extend(current_seq) + + # Write the final chunk + if current_chunk: + chunk_file = f"{input_file}.chunk_{current_chunk_idx}" + with open(chunk_file, "w") as chunk_out: + chunk_out.writelines(current_chunk) + chunk_files.append(chunk_file) + + return chunk_files + +def check_tree_alignment_compatibility(alignment_file, tree_file, log_file=None, threads=None): + """ + Check if the taxa in the constraint tree match those in the alignment file using multithreading. + Returns True if they match, False otherwise. 
+ """ + start_time = time.time() + + # Determine number of threads + if threads is None: + threads = multiprocessing.cpu_count() + threads = max(1, min(threads, multiprocessing.cpu_count())) + + # Prepare logging + if log_file: + log = open(log_file, "w") + write_log = lambda msg: log.write(f"{msg}\n") + else: + write_log = lambda msg: print(msg) + + write_log(f"Checking compatibility between tree and alignment...") + write_log(f"Alignment file: {alignment_file}") + write_log(f"Tree file: {tree_file}") + write_log(f"Using {threads} threads") + + # Process the alignment file in parallel + try: + write_log(f"Splitting alignment file into {threads} chunks for parallel processing...") + chunk_files = split_fasta_file(alignment_file, threads) + write_log(f"Created {len(chunk_files)} chunk files") + + # Process chunks in parallel + alignment_taxa = set() + with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor: + future_to_chunk = {executor.submit(process_alignment_chunk, chunk): chunk for chunk in chunk_files} + for future in tqdm(concurrent.futures.as_completed(future_to_chunk), total=len(chunk_files), desc="Processing alignment chunks"): + chunk = future_to_chunk[future] + try: + chunk_taxa = future.result() + alignment_taxa.update(chunk_taxa) + except Exception as e: + write_log(f"Error processing chunk {chunk}: {str(e)}") + + # Clean up chunk files + for chunk_file in chunk_files: + try: + os.remove(chunk_file) + except: + pass + + write_log(f"Found {len(alignment_taxa)} unique sequence IDs in alignment file.") + except Exception as e: + write_log(f"Error reading alignment file: {str(e)}") + if log_file: + log.close() + return False + + # Read taxa from the tree file + try: + write_log("Parsing tree file...") + tree = dendropy.Tree.get(path=tree_file, schema="newick") + + # Extract taxa with parallel processing + leaf_nodes = list(tree.leaf_nodes()) + write_log(f"Found {len(leaf_nodes)} leaf nodes in tree") + + # Divide leaf nodes into chunks + chunks = [] + chunk_size = max(1, len(leaf_nodes) // threads) + for i in range(0, len(leaf_nodes), chunk_size): + chunks.append(leaf_nodes[i:i+chunk_size]) + + # Process chunks in parallel + tree_taxa = set() + with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor: + future_to_chunk = {executor.submit(process_tree_nodes, chunk): chunk for chunk in chunks} + for future in tqdm(concurrent.futures.as_completed(future_to_chunk), total=len(chunks), desc="Processing tree nodes"): + try: + chunk_taxa = future.result() + tree_taxa.update(chunk_taxa) + except Exception as e: + write_log(f"Error processing tree nodes: {str(e)}") + + write_log(f"Found {len(tree_taxa)} unique taxa in tree file.") + except Exception as e: + write_log(f"Error reading tree file: {str(e)}") + if log_file: + log.close() + return False + + # Check for mismatches + write_log("Comparing taxa sets...") + alignment_only = alignment_taxa - tree_taxa + tree_only = tree_taxa - alignment_taxa + common_taxa = alignment_taxa.intersection(tree_taxa) + + write_log(f"Summary:") + write_log(f" - Total taxa in alignment: {len(alignment_taxa)}") + write_log(f" - Total taxa in tree: {len(tree_taxa)}") + write_log(f" - Taxa in both: {len(common_taxa)} ({(len(common_taxa)/max(1, len(alignment_taxa.union(tree_taxa))))*100:.1f}%)") + + if alignment_only: + write_log(f" - Taxa in alignment but not in tree: {len(alignment_only)} ({(len(alignment_only)/len(alignment_taxa))*100:.1f}% of alignment)") + sample = list(alignment_only)[:10] + write_log(f" Sample: {', 
'.join(sample)}{'...' if len(alignment_only) > 10 else ''}") + + # Save full list of alignment-only taxa to a file if log file is specified + if log_file: + alignment_only_file = log_file + ".alignment_only_taxa.txt" + with open(alignment_only_file, "w") as f: + for taxon in sorted(alignment_only): + f.write(f"{taxon}\n") + write_log(f" Full list saved to: {alignment_only_file}") + else: + write_log(f" - All alignment taxa are present in the tree") + + if tree_only: + write_log(f" - Taxa in tree but not in alignment: {len(tree_only)} ({(len(tree_only)/len(tree_taxa))*100:.1f}% of tree)") + sample = list(tree_only)[:10] + write_log(f" Sample: {', '.join(sample)}{'...' if len(tree_only) > 10 else ''}") + + # Save full list of tree-only taxa to a file if log file is specified + if log_file: + tree_only_file = log_file + ".tree_only_taxa.txt" + with open(tree_only_file, "w") as f: + for taxon in sorted(tree_only): + f.write(f"{taxon}\n") + write_log(f" Full list saved to: {tree_only_file}") + else: + write_log(f" - All tree taxa are present in the alignment") + + # Determine compatibility status + if len(common_taxa) == 0: + compatibility_status = "NOT compatible at all" + is_compatible = False + elif len(alignment_only) == 0 and len(tree_only) == 0: + compatibility_status = "FULLY compatible" + is_compatible = True + else: + compatibility_status = "NOT FULLY compatible" + is_compatible = False + + write_log(f"\nTree and alignment are {compatibility_status}.") + + # Print a summary of what needs to be done + if not is_compatible: + write_log("\nRecommendation:") + if compatibility_status == "NOT compatible at all": + write_log(" - The tree and alignment have no taxa in common!") + write_log(" - Check if the taxa naming schemes are different between the files") + write_log(" - Ensure you're using the correct input files") + else: # Partially compatible + if tree_only: + write_log(" - Prune the tree to remove taxa not in the alignment") + write_log(" (Use a tree manipulation tool like 'nw_prune' from Newick Utilities)") + write_log(f" Command example: nw_prune {tree_file} `cat {log_file}.tree_only_taxa.txt` > pruned_tree.tre") + if alignment_only: + write_log(" - Either add missing taxa to the tree or filter the alignment") + write_log(" (Use a sequence filtering tool to keep only sequences with IDs in the tree)") + else: + write_log("\nNo action needed - tree and alignment have identical taxa.") + + elapsed_time = time.time() - start_time + write_log(f"Compatibility check completed in {elapsed_time:.2f} seconds.") + + if log_file: + log.close() + + return is_compatible + +if __name__ == "__main__": + if len(sys.argv) < 3: + print("Usage: python check_tree_alignment.py [log_file] [threads]") + sys.exit(1) + + alignment_file = sys.argv[1] + tree_file = sys.argv[2] + log_file = sys.argv[3] if len(sys.argv) > 3 else None + threads = int(sys.argv[4]) if len(sys.argv) > 4 else None + + compatible = check_tree_alignment_compatibility(alignment_file, tree_file, log_file, threads) + + # Exit with appropriate code (0 for compatible, 1 for incompatible) + sys.exit(0 if compatible else 1) \ No newline at end of file diff --git a/workflow/scripts/create_database.py b/workflow/scripts/create_database.py index 3ed7e12..cc00d56 100644 --- a/workflow/scripts/create_database.py +++ b/workflow/scripts/create_database.py @@ -1,3 +1,4 @@ +import os.path import sqlite3 import util import argparse @@ -126,7 +127,7 @@ class TEXT NOT NULL, family TEXT NOT NULL, subfamily TEXT NOT NULL, genus TEXT NOT NULL, - species 
TEXT NOT NULL, + species TEXT NOT NULL, bin_uri TEXT NOT NULL, opentol_id INTEGER, UNIQUE(kingdom, phylum, class, "order", family, subfamily, genus, species, bin_uri)); @@ -200,10 +201,17 @@ def update_fk(): # Instantiate logger logger = util.get_formatted_logger('create_database', args.verbosity) + if os.path.isfile(args.outdb): + os.unlink(args.outdb) + # Connect to the database logger.info('Going to connect to database') connection = sqlite3.connect(args.outdb) database_cursor = connection.cursor() + database_cursor.execute('pragma journal_mode=OFF') + database_cursor.execute('PRAGMA synchronous=OFF') + database_cursor.execute('PRAGMA cache_size=100000') + database_cursor.execute('PRAGMA temp_store = MEMORY') # Create database tables create_taxon_table('taxon') diff --git a/workflow/scripts/family_constraint.py b/workflow/scripts/family_constraint.py index f1225a6..860f715 100644 --- a/workflow/scripts/family_constraint.py +++ b/workflow/scripts/family_constraint.py @@ -32,34 +32,47 @@ def extract_id_from_fasta(unaligned, outgroups): """ ids = [] - # process the ingroup file + # Process the ingroup file with open(unaligned, 'r') as file: for line in file: if line.startswith('>'): parts = line.strip().split('|') if len(parts) > 1: - # Extract the second element and remove 'ott' prefix id_with_prefix = parts[1] - id_number = id_with_prefix.replace('ott', '') - if id_number != 'None': + if id_with_prefix != "ottNone": + id_number = id_with_prefix.replace('ott', '') ids.append(int(id_number)) - # process the outgroup file + # Process the outgroup file with open(outgroups, 'r') as file: for line in file: if line.startswith('>'): pid = line.strip().removeprefix('>') sql = f"SELECT t.opentol_id FROM taxon t, barcode b WHERE t.taxon_id=b.taxon_id and b.processid='{pid}'" ott = conn.execute(sql).fetchone() - ids.append(ott[0]) + logger.info(f"Result for processid {pid} was {ott}") + if ott: + ids.append(ott[0]) - # Remove all 'None' entries - cleaned_list = [item for item in ids if item != 'None'] - return cleaned_list + return ids + + +def test_database_connection(database): + """ + Tests the connection to the SQLite database. + :param database: the location of the SQLite database file + :return: True if connection is successful, False otherwise + """ + try: + conn = sqlite3.connect(database) + conn.close() + return True + except sqlite3.Error as e: + print(f"Database connection failed: {e}") + return False if __name__ == '__main__': - # Define command line arguments parser = argparse.ArgumentParser(description='Required command line arguments.') parser.add_argument('-i', '--ingroup', required=True, help='FASTA file with the ingroup') parser.add_argument('-g', '--outgroups', required=True, help='FASTA file with outgroup taxa') @@ -68,29 +81,46 @@ def extract_id_from_fasta(unaligned, outgroups): parser.add_argument('-v', '--verbosity', required=True, help='Log level (e.g. 
DEBUG)') args = parser.parse_args() - # Configure logging logger = util.get_formatted_logger('family_constraint', args.verbosity) + logger.info(f"Connecting to database {args.database}") + + if test_database_connection(args.database): + logger.info("Database connection successful.") + else: + logger.error("Database connection failed.") + exit(1) - # Connect to the database (creates a new file if it doesn't exist) - logger.info(f"Going to connect to database {args.database}") conn = sqlite3.connect(args.database) - # Read input alignment, get ott IDs ott_ids = extract_id_from_fasta(args.ingroup, args.outgroups) - # If we have no IDs at all, we write a zero byte file for run_raxml if len(ott_ids) == 0: - logger.warning('There were zero OTT IDs in the input file') - with open(args.outtree, "a"): + logger.warning('No valid OTT IDs found in the input files.') + with open(args.outtree, "a"): # Create an empty output file pass else: - # Get subtree from OpenToL WS API - tree = opentol.get_subtree(ott_ids) - - # Write output - logger.info(f'Going to write tree to {args.outtree}') - with open(args.outtree, "w") as output_file: - output_file.write(tree.as_string(schema="newick")) - - + try: + # Try to fetch the subtree using the OpenToL API + logger.info("Going to query OpenToL API") + tree = opentol.get_subtree(ott_ids) + logger.info(f"Done querying, got {tree}") + if tree is None: + logger.error('The API returned None, indicating no tree could be generated.') + with open(args.outtree, "a"): # Create an empty file + pass + else: + # Confirm tree has correct format before writing + if hasattr(tree, 'as_string'): + logger.info(f'Writing tree to {args.outtree}') + with open(args.outtree, "w") as output_file: + output_file.write(tree.as_string(schema="newick")) + else: + logger.error('The tree object does not support "as_string"; skipping file writing.') + with open(args.outtree, "a"): + pass + + except Exception as e: + logger.error(f"Failed to fetch or write subtree: {e}") + with open(args.outtree, "a"): # Create an empty file on failure + pass diff --git a/workflow/scripts/family_fasta.py b/workflow/scripts/family_fasta.py index c567e57..a1d853c 100644 --- a/workflow/scripts/family_fasta.py +++ b/workflow/scripts/family_fasta.py @@ -1,10 +1,12 @@ -import errno import sqlite3 import os import pandas as pd import argparse import util +import re +from pathlib import Path +levels = ['kingdom', 'phylum', 'class', 'order', 'family', 'subfamily', 'genus', 'all'] """ This script, `family_fasta.py`, is responsible for generating FASTA files for each family of a specified higher taxon @@ -26,6 +28,14 @@ arguments in the rule `family_fasta`. """ +def sanitize_string(input_string): + """ + Replaces all characters in the input string that do not match [a-zA-Z0-9_-] with an underscore (_). 
+ :param input_string: The string to be sanitized + :return: The sanitized string + """ + return re.sub(r'[^a-zA-Z0-9_-]', '_', input_string) + def get_family_bins(q, conn): """ Gets distinct families and bins for the higher taxon defined in the query restrictions @@ -35,14 +45,14 @@ def get_family_bins(q, conn): """ # Check if filter_level in config.yaml is usable - if q['level'].lower() in ['kingdom', 'phylum', 'class', 'order', 'family', 'subfamily', 'genus', 'all']: + if q['level'].lower() in levels: # Select all distinct family names that match config.yaml filters level = q['level'] name = q['name'] marker_code = q['marker_code'] sql = f''' - SELECT DISTINCT family, bin_uri + SELECT DISTINCT family, genus, species, bin_uri FROM barcode WHERE marker_code = '{marker_code}' AND "{level}" = '{name}' @@ -61,7 +71,7 @@ def get_family_bins(q, conn): return fam -def write_bin(q, conn, fh): +def write_bin(q, conn, outfile): """ Writes the longest sequence for a BIN to file :param q: query object @@ -78,7 +88,7 @@ def write_bin(q, conn, fh): JOIN taxon t ON b.taxon_id = t.taxon_id WHERE t."{q["level"]}" = "{q["name"]}" AND - t.family = "{q["family"]}" AND + t."{q["rank"]}" = "{q["taxon"]}" AND t.bin_uri = "{q["bin_uri"]}" AND b.marker_code = "{q["marker_code"]}" AND t.species IS NOT NULL AND @@ -90,13 +100,14 @@ def write_bin(q, conn, fh): famseq = pd.read_sql_query(query, conn) # Append to file handle fh - for _, row in famseq.iterrows(): - defline = f'>{row["barcode_id"]}|ott{row["opentol_id"]}|{row["processid"]}|{row["bin_uri"]}|{row["species"]}\n' - fh.write(defline) + with open(outfile, "a") as fh: + for _, row in famseq.iterrows(): + defline = f'>{row["barcode_id"]}|ott{row["opentol_id"]}|{row["processid"]}|{row["bin_uri"]}|{row["species"]}\n' + fh.write(defline) - # Strip non-ACGT characters (dashes, esp.) because hmmer chokes on them - seq = row['nuc'].replace('-', '') + '\n' - fh.write(seq) + # Strip non-ACGT characters (dashes, esp.) because hmmer chokes on them + seq = row['nuc'].replace('-', '') + '\n' + fh.write(seq) if __name__ == '__main__': @@ -106,15 +117,28 @@ def write_bin(q, conn, fh): parser.add_argument('-f', '--fasta_dir', required=True, help='Directory to write FASTA files to') parser.add_argument('-l', '--level', required=True, help='Taxonomic level to filter (e.g. order)') parser.add_argument('-n', '--name', required=True, help='Taxon name to filter (e.g. Primates)') - parser.add_argument('-c', '--chunks', required=True, help="Number of chunks (families) to write to file") + parser.add_argument('-L', '--limit', required=True, type=int, help='Fasta sequence limit, switch to lower rank if above (e.g. 200)') parser.add_argument('-m', '--marker', required=True, help='Marker code, e.g. COI-5P') parser.add_argument('-v', '--verbosity', required=True, help='Log level (e.g. 
DEBUG)') args = parser.parse_args() database_file = args.database + level = args.level.lower() + if level not in levels: + raise Exception(f"Filter level {level} from config file does not exists as a column in the database") + + if levels.index(level) > levels.index("family"): + raise Exception("Level filter value must be 'family' or higher rank") + # Configure logger logger = util.get_formatted_logger('family_fasta', args.verbosity) + try: + os.makedirs(args.fasta_dir, exist_ok=True) + except OSError as error: + logger.error(error) + exit(1) + # Connect to the database (creates a new file if it doesn't exist) logger.info(f"Going to connect to database {args.database}") connection = sqlite3.connect(args.database) @@ -128,29 +152,75 @@ def write_bin(q, conn, fh): } df = get_family_bins(query, connection) - # Iterate over distinct families - index = 1 - for family in df['family'].unique(): - logger.info(f"Writing {family}") + def write_fasta(query, rank, taxon, family_bin_uris): + logger.info(f"Writing {taxon} ({rank})") + + # Replace "/" with "_" in taxon name + taxon = sanitize_string(taxon) # Make directory and open file handle - subdir = os.path.join(args.fasta_dir, f"{index}-of-{args.chunks}") + align_file = os.path.join(args.fasta_dir, "taxon", taxon, "unaligned.fa") try: - os.mkdir(subdir) + os.makedirs(Path(align_file).parent, exist_ok=True) except OSError as error: - logger.warning(error) - - with open(os.path.join(subdir, 'unaligned.fa'), 'w') as handle: - - # Iterate over bins in family + logger.error(error) + exit(1) + + # Iterate over bins in family + for bin_uri in family_bin_uris: + logger.debug(f"Writing {bin_uri}") + query['bin_uri'] = bin_uri + query["rank"] = rank + query["taxon"] = taxon + write_bin(query, connection, align_file) + + family_set = {} + genus_set = {} + # Iterate over distinct families + with open(os.path.join(args.fasta_dir, "taxon_fasta.tsv"), "w") as fw: + unique_families = df['family'].unique() + split_families = [] + for family in unique_families: family_bin_uris = df[df['family'] == family]['bin_uri'].unique() - for bin_uri in family_bin_uris: - logger.debug(f"Writing {bin_uri}") - query['bin_uri'] = bin_uri - query['family'] = family - write_bin(query, connection, handle) - - index += 1 + if len(family_bin_uris) > args.limit: + split_families.append(family) + fw.write(f"{family}\tfamily\t{len(family_bin_uris)}\tTrue\t\n") + continue + + write_fasta(query, "family", family, family_bin_uris) + fw.write(f"{family}\tfamily\t{len(family_bin_uris)}\tFalse\t\n") + + split_genera = [] + if split_families: + for family in split_families: + unique_genera = df[df['family'] == family]['genus'].unique() + for genus in unique_genera: + if not genus: + continue + genus_bin_uris = df[(df['family'] == family) & (df['genus'] == genus)]['bin_uri'].unique() + if len(genus_bin_uris) > args.limit: + # split_genera.append(genus) + # fw.write(f"{genus}\tgenus\t{len(genus_bin_uris)}\tTrue\t{family}\n") + logger.warning(f"Genus {genus} in family {family} ({len(genus_bin_uris)}) exceeds the limit of {args.limit}.") + # continue + + write_fasta(query, "genus", genus, genus_bin_uris) + fw.write(f"{genus}\tgenus\t{len(genus_bin_uris)}\tFalse\t{family}\n") + + # if split_genera: + # for genus in split_genera: + # if not genus: + # continue + # unique_species = df[df['genus'] == genus]['species'].unique() + # for species in unique_species: + # if not species: + # continue + # species_bin_uris = df[(df['genus'] == genus) & (df['species'] == species)]['bin_uri'].unique() + # if 
len(species_bin_uris) > args.limit: + # logger.error(f"Species {species} in genus {genus} exceeds the limit of {args.limit}. Skipping.") + # continue + # write_fasta(query, "species", species, species_bin_uris) + # fw.write(f"{species}\tspecies\t{len(species_bin_uris)}\tFalse\t{genus}\n") # Close the connection - connection.close() + connection.close() \ No newline at end of file diff --git a/workflow/scripts/filter_bcdm.py b/workflow/scripts/filter_bcdm.py new file mode 100644 index 0000000..6d5d786 --- /dev/null +++ b/workflow/scripts/filter_bcdm.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 + +import sys +import re +import os +import gzip +import time +import multiprocessing +from tqdm import tqdm +from functools import partial +from concurrent.futures import ProcessPoolExecutor, as_completed + +def read_process_ids_from_phylip(phylip_file): + """Extract process IDs from a Phylip file.""" + process_ids = set() + with open(phylip_file, 'r') as f: + # Skip the first line (contains dimensions) + next(f) + for line in f: + # In Phylip format, the sequence name comes before the first space + match = re.match(r'^(\S+)', line) + if match: + process_id = match.group(1) + process_ids.add(process_id) + return process_ids + +def identify_process_id_column(header): + """Identify the process ID column in the BCDM file header.""" + columns = header.strip().split('\t') + + # Check for common process ID column names + for i, col in enumerate(columns): + col_lower = col.lower() + if 'processid' in col_lower or 'process_id' in col_lower: + return i + + # If not found, check for other possible ID columns + for i, col in enumerate(columns): + col_lower = col.lower() + if col_lower == 'id' or 'sequenceid' in col_lower: + return i + + # Return None if no suitable column found + return None + +def get_file_line_count(file_path): + """Count the number of lines in a file efficiently.""" + is_gzipped = file_path.endswith('.gz') + open_func = gzip.open if is_gzipped else open + mode = 'rt' if is_gzipped else 'r' + + print(f"Counting lines in {file_path}...") + line_count = 0 + chunk_size = 1024 * 1024 # 1MB chunks + + with open_func(file_path, mode) as file: + # Skip header + next(file) + line_count = 1 # Start with 1 for the header + + with tqdm(unit='MB', desc="Counting lines") as pbar: + while True: + chunk = file.read(chunk_size) + if not chunk: + break + line_count += chunk.count('\n') + pbar.update(1) + + return line_count + +def split_file_into_chunks(file_path, num_chunks): + """Return file offsets for chunks.""" + is_gzipped = file_path.endswith('.gz') + if is_gzipped: + print("Warning: Multithreaded processing is not as efficient with gzipped files") + # For gzipped files, we'll return just one chunk for simplicity + return [(0, -1)] + + total_size = os.path.getsize(file_path) + chunk_size = total_size // num_chunks + + offsets = [] + with open(file_path, 'rb') as f: + # Read header + header = f.readline() + header_offset = len(header) + + # First chunk starts after header + start_offset = header_offset + + for i in range(num_chunks - 1): + # Jump to approximate chunk boundary + f.seek(start_offset + chunk_size) + + # Find the next newline + while f.read(1) != b'\n': + pass + + # Record end offset for current chunk + end_offset = f.tell() + offsets.append((start_offset, end_offset)) + + # Set start of next chunk + start_offset = end_offset + + # Last chunk goes to end of file + offsets.append((start_offset, -1)) + + return offsets + +def process_chunk(file_path, offset_range, process_ids, process_id_col, 
results_queue=None): + """Process a chunk of the BCDM file.""" + start_offset, end_offset = offset_range + is_gzipped = file_path.endswith('.gz') + open_func = gzip.open if is_gzipped else open + mode = 'rt' if is_gzipped else 'r' + + matching_records = [] + count = 0 + + with open_func(file_path, mode) as f: + # If not the first chunk, get header first + if start_offset > 0: + header = f.readline() # Read and discard header + f.seek(start_offset) + + # Read until end_offset or EOF + while True: + if end_offset > 0 and f.tell() >= end_offset: + break + + line = f.readline() + if not line: + break + + count += 1 + + fields = line.strip().split('\t') + if len(fields) > process_id_col: + pid = fields[process_id_col] + if pid in process_ids: + matching_records.append(line) + + if results_queue: + results_queue.put((matching_records, count)) + return matching_records, count + +def filter_bcdm_file(bcdm_file, process_ids, output_file, threads=None): + """Filter the BCDM file to only include rows with matching process IDs.""" + if threads is None: + threads = max(1, multiprocessing.cpu_count() - 1) + + # Determine if the file is gzipped + is_gzipped = bcdm_file.endswith('.gz') + open_func = gzip.open if is_gzipped else open + mode = 'rt' if is_gzipped else 'r' + + start_time = time.time() + print(f"Starting to filter {bcdm_file} using {threads} threads") + + # Read the header to identify the process ID column + with open_func(bcdm_file, mode) as f: + header = f.readline() + + process_id_col = identify_process_id_column(header) + if process_id_col is None: + print("Error: Could not identify process ID column in BCDM file") + return False + + print(f"Using column {process_id_col} for process IDs") + + # Get approximate number of records for progress tracking + total_lines = get_file_line_count(bcdm_file) + print(f"File has approximately {total_lines:,} records (including header)") + + # Split file into chunks + chunks = split_file_into_chunks(bcdm_file, threads) + print(f"Split file into {len(chunks)} processing chunks") + + # Process chunks in parallel + matching_records = [] + total_processed = 0 + + # Setup progress bar + pbar = tqdm(total=total_lines-1, desc="Filtering", unit="records") + + # Create a manager and queue for progress updates + manager = multiprocessing.Manager() + results_queue = manager.Queue() + + # Process chunks in parallel + chunk_processors = [] + with ProcessPoolExecutor(max_workers=threads) as executor: + for chunk in chunks: + processor = executor.submit( + process_chunk, bcdm_file, chunk, process_ids, process_id_col, results_queue + ) + chunk_processors.append(processor) + + # Monitor the queue and update progress + all_matches = [] + all_counts = 0 + completed = 0 + + while completed < len(chunks): + if not results_queue.empty(): + matches, count = results_queue.get() + all_matches.extend(matches) + all_counts += count + pbar.update(count) + completed += 1 + else: + time.sleep(0.1) + + pbar.close() + + # Write results + with open(output_file, 'w') as outfile: + outfile.write(header) + for record in all_matches: + outfile.write(record) + + elapsed = time.time() - start_time + print(f"Completed filtering in {elapsed:.2f} seconds") + print(f"Processed {all_counts:,} records, found {len(all_matches):,} matches") + + return True + +def main(): + if len(sys.argv) < 4: + print("Usage: python filter_bcdm.py [threads]") + sys.exit(1) + + phylip_file = sys.argv[1] + bcdm_file = sys.argv[2] + output_file = sys.argv[3] + threads = int(sys.argv[4]) if len(sys.argv) > 4 else None + 
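+    # Arguments (from sys.argv above): [1] Phylip alignment supplying the process IDs,
+    # [2] BCDM TSV (optionally gzipped), [3] output file, [4] optional thread count.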
+ print(f"Reading process IDs from {phylip_file}...") + process_ids = read_process_ids_from_phylip(phylip_file) + print(f"Found {len(process_ids):,} unique process IDs") + + success = filter_bcdm_file(bcdm_file, process_ids, output_file, threads) + + if success: + print(f"Filtered BCDM saved to {output_file}") + else: + print("Failed to filter BCDM file") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/workflow/scripts/find_polytomies.py b/workflow/scripts/find_polytomies.py new file mode 100644 index 0000000..f0aecbd --- /dev/null +++ b/workflow/scripts/find_polytomies.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +""" +Find Polytomies in Phylogenetic Trees + +This script identifies polytomies (nodes with >2 children) in a phylogenetic tree, +often used as a preprocessing step for tree refinement. Polytomies are identified +and filtered based on their depth in the tree and number of descendant tips. + +Input: + - Newick format phylogenetic tree file + +Output: + - Text file with a list of polytomies, including: + * Tips contained in each polytomy + * Depth of the polytomy in the tree + * Size (number of tips) + * Taxonomic name associated with the polytomy + +Usage: + python find_polytomies.py -t tree_file.tre [options] > polytomies.txt + +Dependencies: + - ETE3 (Python package for tree manipulation) +""" + +import argparse +from ete3 import Tree +import re + +def get_node_depth(node): + """ + Get the depth of a node in the tree (distance from root) + + Args: + node: An ETE3 TreeNode object + + Returns: + int: Depth of the node (0 for root, increases with distance from root) + """ + depth = 0 + current = node + while current.up: + depth += 1 + current = current.up + return depth + +def find_taxonomic_name(node): + """ + Extract taxonomic name from node or its parent nodes + + This function tries to find a taxonomic name (not formatted as a barcode ID) + associated with the node or its ancestors. 
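+    Labels that match the BOLD process-ID pattern (e.g. ABCD1234-21) are treated as
+    barcode tips and skipped.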
+ + Args: + node: An ETE3 TreeNode object + + Returns: + str: Taxonomic name if found, None otherwise + """ + if node.name and node.name.strip() and not re.match(r'^[A-Z]+\d+-\d+$', node.name): + return node.name + + # If no name at this node, try parent nodes until we find a taxonomic name + parent = node.up + while parent: + if parent.name and parent.name.strip() and not re.match(r'^[A-Z]+\d+-\d+$', parent.name): + return parent.name + parent = parent.up + + return None + +def find_polytomies(tree_file, min_depth=0, max_depth=None, min_size=3, max_size=None): + """ + Identify and report polytomies in a phylogenetic tree + + Args: + tree_file: Path to Newick format tree file + min_depth: Minimum depth to consider (filter shallower polytomies) + max_depth: Maximum depth to consider (filter deeper polytomies) + min_size: Minimum number of tips in a polytomy + max_size: Maximum number of tips in a polytomy + + Returns: + None: Results are printed to stdout + """ + # Load tree from the input file + tree = Tree(tree_file, format=1) # Adjust format if needed + + # Find polytomies (nodes with more than 2 children) + all_polytomies = [node for node in tree.traverse() if len(node.children) > 2] + + # Filter polytomies by depth and size + filtered_polytomies = [] + for node in all_polytomies: + depth = get_node_depth(node) + size = len(node.get_leaf_names()) + taxon = find_taxonomic_name(node) + + if (min_depth <= depth and (max_depth is None or depth <= max_depth) and + size >= min_size and (max_size is None or size <= max_size)): + filtered_polytomies.append((node, depth, size, taxon)) + + print(f"Found {len(filtered_polytomies)} polytomies (filtered from {len(all_polytomies)} total)") + + # Extract tip labels for each polytomy + for i, (node, depth, size, taxon) in enumerate(filtered_polytomies): + tips = node.get_leaf_names() + print(f"Polytomy {i+1}: {tips} (depth={depth}, size={size}, taxon={taxon})") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Find polytomies in a constraint tree.") + parser.add_argument("-t", "--tree", required=True, + help="Path to the constraint tree file (Newick format)") + parser.add_argument("--min-depth", type=int, default=0, + help="Minimum depth for polytomies to include (default: 0)") + parser.add_argument("--max-depth", type=int, default=None, + help="Maximum depth for polytomies to include (default: None)") + parser.add_argument("--min-size", type=int, default=3, + help="Minimum number of tips in polytomies (default: 3)") + parser.add_argument("--max-size", type=int, default=None, + help="Maximum number of tips in polytomies (default: None)") + + args = parser.parse_args() + + find_polytomies(args.tree, args.min_depth, args.max_depth, args.min_size, args.max_size) + +""" +Example usage: + +# Find all polytomies with at least 3 tips +python find_polytomies.py -t constraint_tree.tre > all_polytomies.txt + +# Find polytomies with 3-50 tips at taxonomic level depth 5 or deeper +python find_polytomies.py -t constraint_tree.tre --min-depth 5 --min-size 3 --max-size 50 > filtered_polytomies.txt + +""" \ No newline at end of file diff --git a/workflow/scripts/format_constraint_tree.py b/workflow/scripts/format_constraint_tree.py new file mode 100644 index 0000000..9487952 --- /dev/null +++ b/workflow/scripts/format_constraint_tree.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +""" +Format Constraint Tree with OTT IDs and Bifurcation + +Converts a taxonomic-labeled constraint tree to a fully bifurcating tree +with OTT IDs at internal nodes 
and barcode IDs at tips. + +Input: + - Constraint tree with taxonomic names at internal nodes + +Output: + - Fully bifurcating constraint tree with OTT IDs at internal nodes and no branch lengths +""" + +import argparse +import re +from ete3 import Tree +import logging +import requests +import os +import sys + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger('format_constraint_tree') + +# Cache for OTT IDs to avoid redundant API calls +ott_cache = {} + +def get_ott_id(taxon_name): + """Get OTT ID for a taxonomic name using OpenTree API""" + if taxon_name in ott_cache: + return ott_cache[taxon_name] + + # Clean up taxon name - replace underscores with spaces + clean_name = taxon_name.replace("_", " ") + + url = "https://api.opentreeoflife.org/v3/tnrs/match_names" + payload = { + "names": [clean_name], + "do_approximate_matching": True + } + + try: + logger.debug(f"Looking up OTT ID for '{clean_name}'") + response = requests.post(url, json=payload) + data = response.json() + + if 'results' in data and len(data['results']) > 0 and len(data['results'][0]['matches']) > 0: + ott_id = data['results'][0]['matches'][0]['taxon']['ott_id'] + ott_cache[taxon_name] = f"ott{ott_id}" + return f"ott{ott_id}" + else: + logger.warning(f"No OTT ID found for '{clean_name}'") + ott_cache[taxon_name] = None + return None + except Exception as e: + logger.error(f"Error querying OTT ID: {e}") + return None + +def is_barcode_id(name): + """Check if a name looks like a barcode ID""" + # Most barcode IDs follow patterns like ABCDE123-45 + return bool(re.match(r'^[A-Z]+\d+-\d+$', name)) + +def format_tree(input_tree, output_tree, force_bifurcate=True): + """ + Format tree by replacing taxonomic names with OTT IDs and ensuring bifurcation + + Args: + input_tree: Path to input tree file + output_tree: Path to output tree file + force_bifurcate: Whether to force the tree to be fully bifurcating + """ + # Load tree + tree = Tree(input_tree, format=1) + logger.info(f"Loaded tree with {len(tree)} tips") + + # Process internal nodes + internal_nodes = 0 + resolved_nodes = 0 + unnamed_count = 0 + + # First pass: get OTT IDs for all taxonomic names + for node in tree.traverse(): + # Skip tips - keep their barcode IDs + if node.is_leaf(): + continue + + if node.name and node.name.strip(): + internal_nodes += 1 + + # Skip nodes that already have OTT format + if node.name.startswith('ott'): + resolved_nodes += 1 + continue + + # Get OTT ID for this taxonomic name + ott_id = get_ott_id(node.name) + if ott_id: + node.name = ott_id + resolved_nodes += 1 + else: + # If no OTT ID found, create a placeholder + unnamed_count += 1 + node.name = f"unnamed{unnamed_count}" + + # Second pass: force bifurcation if requested + if force_bifurcate: + polytomies = 0 + resolved_polytomies = 0 + + # Find all polytomies + for node in tree.traverse(): + if len(node.children) > 2: + polytomies += 1 + + # Sort children by the number of descendants (smallest first) + node.children.sort(key=lambda n: len(n.get_leaves())) + + # Create a ladder-like structure with the sorted children + while len(node.children) > 2: + # Take the two smallest children + child1 = node.children[0] + child2 = node.children[1] + + # Remove them from the node + node.remove_child(child1) + node.remove_child(child2) + + # Create a new internal node to hold them + new_node = Tree() + new_node.name = f"unnamed{unnamed_count}" + unnamed_count += 1 + + # 
Add the two children to the new node + new_node.add_child(child1) + new_node.add_child(child2) + + # Add the new node back to the original node + node.add_child(new_node) + + resolved_polytomies += 1 + + logger.info(f"Resolved {resolved_polytomies} of {polytomies} polytomies") + + # Remove branch lengths + for node in tree.traverse(): + node.dist = 0 + + # Write the tree in Newick format + newick = tree.write(format=9) # format 9 is Newick with internal node names + + # Write to file + with open(output_tree, 'w') as f: + f.write(newick) + + # Verify the tree structure + bifurcating = all(len(node.children) <= 2 for node in tree.traverse() if not node.is_leaf()) + + logger.info(f"Processed {internal_nodes} internal nodes") + logger.info(f"Found OTT IDs for {resolved_nodes} nodes ({resolved_nodes/internal_nodes*100:.1f}% if internal_nodes else 0)") + logger.info(f"Created {unnamed_count} unnamed internal nodes") + logger.info(f"Tree is{'fully' if bifurcating else 'NOT'} bifurcating") + logger.info(f"Wrote formatted tree to {output_tree}") + +def main(): + parser = argparse.ArgumentParser(description="Format constraint tree with OTT IDs and ensure bifurcation") + parser.add_argument("-i", "--input", required=True, + help="Path to input tree file") + parser.add_argument("-o", "--output", required=True, + help="Path to output tree file") + parser.add_argument("--no-bifurcate", action="store_true", + help="Don't force bifurcation (keep polytomies)") + parser.add_argument("--batch-size", type=int, default=100, + help="Batch size for OTT ID lookups (default: 100)") + parser.add_argument("-v", "--verbose", action="store_true", + help="Enable verbose output") + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + format_tree(args.input, args.output, not args.no_bifurcate) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/workflow/scripts/graft_clades.py b/workflow/scripts/graft_clades.py index dda2df1..275cfab 100644 --- a/workflow/scripts/graft_clades.py +++ b/workflow/scripts/graft_clades.py @@ -53,7 +53,6 @@ def read_tree(filename, rooting='default-rooted', schema='newick'): parser.add_argument('-f', '--folder', required=True, help='Location of folder with subtree folders') parser.add_argument('-e', '--extinct', required=True, help='File with extinct PIDs to skip') parser.add_argument('-o', '--out', required=True, help="Output grafted newick") - parser.add_argument('-n', '--nfamilies', required=True, help='Number of families') parser.add_argument('-v', '--verbosity', required=True, help='Log level (e.g. 
DEBUG)') args = parser.parse_args() @@ -61,11 +60,11 @@ def read_tree(filename, rooting='default-rooted', schema='newick'): logger = util.get_formatted_logger('graft_clades', args.verbosity) # Read the extinct PIDs - extinct = [] + extinct = set() with open(args.extinct, 'r') as file: for line in file: clean_line = line.strip() - extinct.append(clean_line) + extinct.add(clean_line) logger.info(f"extinct: {extinct}") # Read the backbone tree as a dendropy object, calculate distances to root, and get its leaves backbone = read_tree(args.tree) @@ -74,15 +73,16 @@ def read_tree(filename, rooting='default-rooted', schema='newick'): # Iterate over folders base_folder = os.path.abspath(args.folder) - for i in range(1, int(args.nfamilies) + 1): - logger.info(f'Processing subtree {i}') + for taxon in os.listdir(base_folder): + logger.info(f'Processing subtree {taxon}') + subfolder = os.path.join(base_folder, taxon) - # Peprocess the focal family tree - subfolder = f'{i}-of-{args.nfamilies}' + # Preprocess the focal family tree subtree_file = os.path.join(base_folder, subfolder, 'aligned.fa.raxml.bestTree.rooted') try: subtree = read_tree(subtree_file) except: + logger.info(f'warning: subtree for taxon "{taxon}" could not be read') continue subtree.calc_node_root_distances() diff --git a/workflow/scripts/makeblastdb.py b/workflow/scripts/makeblastdb.py new file mode 100644 index 0000000..ad72dc8 --- /dev/null +++ b/workflow/scripts/makeblastdb.py @@ -0,0 +1,47 @@ +import os +import subprocess +import argparse +import re + +def process_taxon_files(fasta_dir, tmp_file): + i = 0 + with open(tmp_file, 'w') as tmp: + for taxon in os.listdir(fasta_dir): + taxon_path = os.path.join(fasta_dir, taxon) + if os.path.isdir(taxon_path): + i += 1 + infile = os.path.join(taxon_path, 'unaligned.fa') + + with open(infile, 'r') as f: + lines = f.readlines() + + filtered_lines = [] + for j in range(0, len(lines), 2): + parts = lines[j].split('|') + if len(parts) > 1 and re.match(r'ott\d+', parts[1].strip()): + header = parts[2] + sequence = lines[j + 1] + filtered_lines.append(f'>{header.strip()}\n{sequence}') + + if i == 1: + tmp.writelines(filtered_lines) + else: + with open(tmp_file, 'a') as tmp_append: + tmp_append.writelines(filtered_lines) + +def make_blast_db(tmp_file, database): + subprocess.run(['makeblastdb', '-in', tmp_file, '-dbtype', 'nucl', '-out', database, '-parse_seqids']) + +def main(): + parser = argparse.ArgumentParser(description='Create BLAST database from FASTA files.') + parser.add_argument('-f', '--fasta_dir', required=True, help='Directory containing FASTA files') + parser.add_argument('-t', '--tmp_file', required=True, help='Temporary file to store filtered sequences') + parser.add_argument('-d', '--database', required=True, help='Output BLAST database file') + args = parser.parse_args() + + process_taxon_files(args.fasta_dir, args.tmp_file) + make_blast_db(args.tmp_file, args.database) + os.remove(args.tmp_file) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/workflow/scripts/makeblastdb.sh b/workflow/scripts/makeblastdb.sh index 3267d49..e98fb82 100644 --- a/workflow/scripts/makeblastdb.sh +++ b/workflow/scripts/makeblastdb.sh @@ -1,7 +1,6 @@ DATABASE=$1 FASTADIR=$2 -ITERATIONS=$3 -TMP=$4 +TMP=$3 # This shell script, `makeblastdb.sh`, is responsible for creating a BLAST database from FASTA files generated by # `family_fasta.py`. The script performs the following steps: @@ -20,10 +19,10 @@ TMP=$4 # as a shell command in the rule `makeblastdb`. 
# doing this sequentially to avoid race conditions in BLAST indexing -for i in $(seq 1 $ITERATIONS); do - - # using the output from family_fasta - INFILE=${FASTADIR}/${i}-of-${ITERATIONS}/unaligned.fa +i=0 +for TAXON in $(ls -d $FASTADIR/taxon/*); do + i=$((i + 1)) + INFILE=${TAXON}/unaligned.fa # only keep records with ott IDs, reformat the headers to retain the process ID, write to $TMP # start new file on first iteration, then append diff --git a/workflow/scripts/map_opentol.py b/workflow/scripts/map_opentol.py index 9847601..317d499 100644 --- a/workflow/scripts/map_opentol.py +++ b/workflow/scripts/map_opentol.py @@ -89,6 +89,11 @@ def match_opentol(kingdom, chunksize, fuzzy): :return: """ + # NOTE: this query takes up a lot of memory/swap space. This is because we are chunking through the database + # doing exact matches for 10k records at a time, and those records are loaded in memory in full (e.g. including + # the sequences and all other columns) by doing 'SELECT *'. Instead, we can get away with doing 'SELECT species, taxon_id' + # so that the chunks have a smaller footprint. + # Load all unmatched records into df, iterate over it in chunks df = pd.read_sql("SELECT * FROM taxon WHERE opentol_id IS NULL", conn) for _, chunk_df in df.groupby(np.arange(len(df)) // chunksize): @@ -133,6 +138,10 @@ def postprocess_db(): logger.info(f'Going to connect to database {args.database}') conn = sqlite3.connect(args.database) cursor = conn.cursor() + cursor.execute('pragma journal_mode=OFF') + cursor.execute('PRAGMA synchronous=OFF') + cursor.execute('PRAGMA cache_size=100000') + cursor.execute('PRAGMA temp_store = MEMORY') # Infer taxonomic context from marker name if args.marker == "COI-5P": diff --git a/workflow/scripts/multi_pass_resolve.py b/workflow/scripts/multi_pass_resolve.py new file mode 100644 index 0000000..bcc6806 --- /dev/null +++ b/workflow/scripts/multi_pass_resolve.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +""" +Multi-Pass Polytomy Resolution with Multithreading + +Resolves polytomies in multiple passes with optimized parameters using parallel processing. +Includes rate limiting to prevent API throttling. 
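+
+Usage:
+    python multi_pass_resolve.py -t constraint_tree.tre -o resolved_tree.tre -j 4 --slow 1.0
+
+Options: -i/--intermediate saves the tree produced by each pass, --keep-temp retains the
+temporary per-pass files, and -v enables debug logging.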
+""" + +import subprocess +import argparse +import logging +import os +import re +import tempfile +import sys +import multiprocessing +import time +import signal +import random +from functools import partial + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger('multi_pass_resolve') + +def count_polytomies(tree_file): + """Count polytomies in the tree""" + # Create a temporary file to store polytomy info + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp: + temp_file = temp.name + + try: + # Get script path - handles running from any directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + find_polytomies_script = os.path.join(script_dir, "find_polytomies.py") + + # Run find_polytomies.py and redirect output to the temp file + subprocess.run([ + "python", find_polytomies_script, + "-t", tree_file + ], stdout=open(temp_file, 'w'), check=True) + + # Parse the output to get the count + with open(temp_file, 'r') as f: + first_line = f.readline().strip() + match = re.search(r'Found (\d+) polytomies', first_line) + count = int(match.group(1)) if match else 0 + + return count + except Exception as e: + logger.error(f"Error in count_polytomies: {e}") + return 0 + finally: + # Clean up + if os.path.exists(temp_file): + os.remove(temp_file) + +def process_chunk(chunk_data): + """Process a chunk of polytomies in a separate process""" + input_tree, output_tree, min_depth, min_size, max_size, chunk_id, chunk_start, chunk_end, delay = chunk_data + + try: + # Get script paths + script_dir = os.path.dirname(os.path.abspath(__file__)) + find_polytomies_script = os.path.join(script_dir, "find_polytomies.py") + resolve_polytomies_script = os.path.join(script_dir, "resolve_polytomies.py") + + # Create a temporary file for the polytomies + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp: + polytomies_file = temp.name + + # Add initial delay with some jitter to avoid all processes starting at once + if delay > 0: + jitter = random.uniform(0, delay * 0.5) # Add up to 50% jitter + logger.info(f"Thread {chunk_id}: Waiting {delay + jitter:.2f}s before starting") + time.sleep(delay + jitter) + + # Find all polytomies first + subprocess.run([ + "python", find_polytomies_script, + "-t", input_tree, + "--min-depth", str(min_depth), + "--min-size", str(min_size), + "--max-size", str(max_size) + ], stdout=open(polytomies_file, 'w'), check=True) + + # Process only this chunk + logger.info(f"Thread {chunk_id}: Processing polytomies {chunk_start} to {chunk_end}") + + # Add --delay parameter if delay is specified + cmd = [ + "python", resolve_polytomies_script, + "-t", input_tree, + "-p", polytomies_file, + "-o", output_tree, + "--min-depth", str(min_depth), + "--min-size", str(min_size), + "--max-size", str(max_size), + "--chunk-start", str(chunk_start), + "--chunk-end", str(chunk_end) + ] + + # Add delay parameter if needed + if delay > 0: + cmd.extend(["--delay", str(delay)]) + + subprocess.run(cmd, check=True) + + # Clean up + os.remove(polytomies_file) + + logger.info(f"Thread {chunk_id}: Completed processing") + return True + + except Exception as e: + logger.error(f"Thread {chunk_id} error: {e}") + return False + +def run_pass_parallel(input_tree, output_tree, pass_number, min_depth, min_size, max_size, threads, delay=0): + """Run a single pass of polytomy resolution using parallel processing""" + # Get script paths + script_dir = 
os.path.dirname(os.path.abspath(__file__)) + find_polytomies_script = os.path.join(script_dir, "find_polytomies.py") + + # Create a temporary file for the polytomies + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp: + polytomies_file = temp.name + + try: + # Find polytomies for this pass + subprocess.run([ + "python", find_polytomies_script, + "-t", input_tree, + "--min-depth", str(min_depth), + "--min-size", str(min_size), + "--max-size", str(max_size) + ], stdout=open(polytomies_file, 'w'), check=True) + + # Count initial polytomies for this pass + polytomy_count = 0 + with open(polytomies_file, 'r') as f: + first_line = f.readline().strip() + match = re.search(r'Found (\d+) polytomies', first_line) + polytomy_count = int(match.group(1)) if match else 0 + + if polytomy_count == 0: + logger.info(f"Pass {pass_number}: No polytomies found with specified parameters") + # Copy input tree to output + subprocess.run(["cp", input_tree, output_tree]) + return 0, 0, 0 + + # If delay is enabled, apply a more reasonable thread cap + if delay > 0: + # Allow more threads even with longer delays + max_threads = max(1, min(threads, 4)) # Use up to 4 threads with delay + threads_to_use = min(max_threads, polytomy_count) + logger.info(f"Using {threads_to_use} threads with {delay}s delay between requests") + else: + threads_to_use = min(threads, polytomy_count, multiprocessing.cpu_count()) + + if threads_to_use <= 1 or polytomy_count <= 5: + # For small numbers of polytomies, just use single process + logger.info(f"Processing {polytomy_count} polytomies in a single thread") + resolve_script = os.path.join(script_dir, "resolve_polytomies.py") + + cmd = [ + "python", resolve_script, + "-t", input_tree, + "-p", polytomies_file, + "-o", output_tree, + "--min-depth", str(min_depth), + "--min-size", str(min_size), + "--max-size", str(max_size) + ] + + # Add delay parameter if needed + if delay > 0: + cmd.extend(["--delay", str(delay)]) + + subprocess.run(cmd, check=True) + else: + # Use multithreading - divide polytomies into chunks + logger.info(f"Processing {polytomy_count} polytomies with {threads_to_use} threads") + + # Calculate chunk size + chunk_size = max(1, polytomy_count // threads_to_use) + + # Create temp output files for each thread + temp_outputs = [] + for i in range(threads_to_use): + with tempfile.NamedTemporaryFile(delete=False, suffix=f'.thread{i}.tre') as temp_out: + temp_outputs.append(temp_out.name) + + # Prepare arguments for each thread + chunk_args = [] + for i in range(threads_to_use): + chunk_start = i * chunk_size + chunk_end = min((i + 1) * chunk_size, polytomy_count) + + # Skip empty chunks + if chunk_start >= chunk_end: + continue + + chunk_args.append(( + input_tree, + temp_outputs[i], + min_depth, + min_size, + max_size, + i + 1, # chunk id for logging + chunk_start, + chunk_end, + delay # Add delay parameter + )) + + # Process chunks in parallel + start_time = time.time() + with multiprocessing.Pool(processes=threads_to_use) as pool: + results = pool.map(process_chunk, chunk_args) + + logger.info(f"Parallel processing completed in {time.time() - start_time:.2f} seconds") + + # Merge the results by finding a successful output + merged = False + for i, success in enumerate(results): + if success and os.path.exists(temp_outputs[i]): + # Copy this file to the output + subprocess.run(["cp", temp_outputs[i], output_tree]) + merged = True + break + + # If no successful output, fall back to the input tree + if not merged: + logger.warning("No successful thread 
outputs, using input tree") + subprocess.run(["cp", input_tree, output_tree]) + + # Clean up temp files + for tmp_file in temp_outputs: + if os.path.exists(tmp_file): + os.remove(tmp_file) + + # Count remaining polytomies + final_count = count_polytomies(output_tree) + + # Calculate how many were resolved + resolved = polytomy_count - final_count + + return polytomy_count, resolved, final_count + + except Exception as e: + logger.error(f"Error in run_pass_parallel: {e}") + # Ensure we have an output file + subprocess.run(["cp", input_tree, output_tree]) + return 0, 0, 0 + finally: + # Clean up + if os.path.exists(polytomies_file): + os.remove(polytomies_file) + +def multi_pass_resolve(input_tree, output_tree, threads=1, save_intermediate=False, keep_temp=False, delay=0): + """Run multiple passes of polytomy resolution""" + # Define passes: most reliable to least reliable + passes = [ + {"name": "Deep genus/family groups", "min_depth": 5, "min_size": 3, "max_size": 40}, + {"name": "Mid-level groups", "min_depth": 3, "min_size": 3, "max_size": 60}, + {"name": "Shallow order/class groups", "min_depth": 1, "min_size": 3, "max_size": 80}, + {"name": "All remaining groups", "min_depth": 0, "min_size": 2, "max_size": 150} + ] + + # Count initial polytomies + total_initial = count_polytomies(input_tree) + logger.info(f"Starting with {total_initial} polytomies in the tree") + logger.info(f"Using {threads} threads for parallel processing") + if delay > 0: + logger.info(f"Rate limiting: {delay} seconds delay between API requests") + + # Create a temporary file for intermediate results + current_tree = input_tree + intermediate_trees = [] + temp_files = [] # Track all temporary files + + # Run each pass + total_resolved = 0 + + for i, pass_config in enumerate(passes): + pass_num = i + 1 + pass_name = pass_config["name"] + min_depth = pass_config["min_depth"] + min_size = pass_config["min_size"] + max_size = pass_config["max_size"] + + logger.info(f"\n--- PASS {pass_num}: {pass_name} ---") + logger.info(f"Parameters: min_depth={min_depth}, min_size={min_size}, max_size={max_size}") + + # Create output file for this pass + if save_intermediate: + # Create a named intermediate file + output_base = os.path.splitext(output_tree)[0] + temp_output = f"{output_base}_pass{pass_num}.tre" + intermediate_trees.append(temp_output) + else: + # Use a temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix=f'.pass{pass_num}.tre') as temp: + temp_output = temp.name + temp_files.append(temp_output) + logger.info(f"Pass {pass_num} output tree: {temp_output}") + + # Run this pass with parallel processing + initial, resolved, remaining = run_pass_parallel( + current_tree, temp_output, pass_num, min_depth, min_size, max_size, threads, delay + ) + + # Update total resolved + total_resolved += resolved + + logger.info(f"Pass {pass_num} results:") + logger.info(f" - Targeted polytomies: {initial}") + logger.info(f" - Resolved in this pass: {resolved}") + logger.info(f" - Resolution rate for pass: {(resolved/initial*100) if initial > 0 else 0:.1f}%") + logger.info(f" - Total resolved so far: {total_resolved}") + logger.info(f" - Overall resolution: {(total_resolved/total_initial*100) if total_initial > 0 else 0:.1f}%") + + if save_intermediate: + logger.info(f" - Saved intermediate tree to: {temp_output}") + + # Update current tree for next pass + current_tree = temp_output + + # Final pass - ensure tree is fully bifurcating with format_constraint_tree.py + logger.info("\n--- FINAL PASS: Force Bifurcation ---") 
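+    # format_constraint_tree.py ladderizes any polytomies that OpenToL could not resolve
+    # (pairing the smallest child clades first) so the final tree is fully bifurcating.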
+ + # Get script path + script_dir = os.path.dirname(os.path.abspath(__file__)) + format_script = os.path.join(script_dir, "format_constraint_tree.py") + + subprocess.run([ + "python", format_script, + "-i", current_tree, + "-o", output_tree, + "-m", "bifurcate" + ], check=True) + + # Count final polytomies to verify + final_polytomies = count_polytomies(output_tree) + logger.info(f"\n===== FINAL RESULTS =====") + logger.info(f"Initial polytomies: {total_initial}") + logger.info(f"Resolved by OpenToL: {total_resolved}") + logger.info(f"Remaining after forced bifurcation: {final_polytomies}") + logger.info(f"Resolution rate: {(1 - final_polytomies/total_initial)*100 if total_initial > 0 else 100:.1f}%") + + # List all temporary files if requested + if keep_temp and temp_files: + logger.info("\n===== TEMPORARY FILES =====") + for i, file in enumerate(temp_files): + if os.path.exists(file): + logger.info(f"{i+1}. {file}") + logger.info("==========================\n") + + # Clean up temp files if not keeping them + if not keep_temp: + for file in temp_files: + if os.path.exists(file) and file != input_tree: + os.remove(file) + else: + logger.info("Temporary files have been kept") + +def main(): + parser = argparse.ArgumentParser(description="Multi-pass polytomy resolution") + parser.add_argument("-t", "--tree", required=True, help="Input tree file") + parser.add_argument("-o", "--output", required=True, help="Output tree file") + parser.add_argument("-j", "--threads", type=int, default=multiprocessing.cpu_count(), + help=f"Number of threads to use (default: {multiprocessing.cpu_count()})") + parser.add_argument("-i", "--intermediate", action="store_true", + help="Save intermediate trees") + parser.add_argument("--keep-temp", action="store_true", + help="Keep temporary files (don't delete them)") + parser.add_argument("--slow", type=float, default=0, + help="Add delay between API requests in seconds (default: 0, no delay)") + parser.add_argument("-v", "--verbose", action="store_true", + help="Enable verbose output") + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Create log file + log_file = f"multi_pass_resolve_{time.strftime('%Y%m%d_%H%M%S')}.log" + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + logger.addHandler(file_handler) + + logger.info(f"Starting multi-pass resolution with {args.threads} threads") + if args.slow > 0: + logger.info(f"Rate limiting enabled: {args.slow}s delay between API requests") + logger.info(f"Log file: {log_file}") + + start_time = time.time() + + try: + multi_pass_resolve(args.tree, args.output, args.threads, args.intermediate, args.keep_temp, args.slow) + elapsed = time.time() - start_time + logger.info(f"Total execution time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)") + except Exception as e: + logger.error(f"Critical error: {e}") + import traceback + logger.error(traceback.format_exc()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/workflow/scripts/opentol.py b/workflow/scripts/opentol.py index 84bff6a..2b94b99 100644 --- a/workflow/scripts/opentol.py +++ b/workflow/scripts/opentol.py @@ -109,7 +109,9 @@ def _opentol_request(ids): def get_subtree(idmap): + logger.info(f"Going to fetch subtree for {idmap}") json_result = _iterate_requests(idmap) + logger.info(f"Received {json_result}") # Parse the newick string, which may still have mrca nodes tree_obj = dendropy.Tree.get( 
diff --git a/workflow/scripts/prep_raxml.py b/workflow/scripts/prep_raxml.py index c8fc1db..754818d 100644 --- a/workflow/scripts/prep_raxml.py +++ b/workflow/scripts/prep_raxml.py @@ -40,6 +40,8 @@ def make_constraint(intree, outtree, processmap): tree = read_newick(intree, 'newick') except ValueError: logger.info("No trees made for this family.") + with open(outtree, 'a'): + pass return # Map opentol_id to process_id, possibly adding tips if there are multiple process_ids @@ -83,6 +85,7 @@ def make_mapping(aln, conn): map_dict = {} for seq in aln: process_id = seq.id + logger.debug(f"Processing sequence {process_id}") # Because we are querying on the basis of the alignment, we may encounter cases # where there is a process_id without an opentol_id. However, this is not going @@ -99,6 +102,7 @@ def make_mapping(aln, conn): record = cursor.fetchone() # Check if record is not empty + logger.debug(f"Query result for {process_id} is: {record}") if record is not None: opentol_id = f'ott{record[0]}' # tree has ott prefixes if opentol_id not in map_dict: @@ -131,6 +135,7 @@ def make_mapping(aln, conn): infile = os.path.realpath(os.path.abspath(args.inaln)) try: alignment = read_alignment(infile, 'fasta') + logger.info(f"Read {len(alignment)} records from {infile}") except: logger.info("No records in the family.") alignment = [] diff --git a/workflow/scripts/backbone_constraint.py b/workflow/scripts/prep_raxml_backbone.py similarity index 80% rename from workflow/scripts/backbone_constraint.py rename to workflow/scripts/prep_raxml_backbone.py index 5c4118a..5102692 100644 --- a/workflow/scripts/backbone_constraint.py +++ b/workflow/scripts/prep_raxml_backbone.py @@ -1,4 +1,5 @@ import argparse +import subprocess import util import os.path import sqlite3 @@ -9,7 +10,7 @@ """ -This script, `backbone_constraint.py`, is responsible for generating a constraint tree for a given family from a SQLite +This script, `prep_raxml_backbone.py`, is responsible for generating a constraint tree for a given family from a SQLite database and a set of FASTA files. The script performs the following steps: @@ -142,21 +143,27 @@ def remap_tips(tree, pidmap): # Define command line arguments parser = argparse.ArgumentParser(description='Required command line arguments.') parser.add_argument('-d', '--database', required=True, help='SQLite database file') - parser.add_argument('-i', '--inaln', required=True, help='Input exemplar FASTA files') + parser.add_argument('-i', '--input_list', required=True, help='Text file containing list of input exemplar FASTA files') parser.add_argument('-o', '--outtree', required=True, help="Output constraint tree") parser.add_argument('-e', '--extinctpids', required=True, help='Putatively extinct PIDs') parser.add_argument('-v', '--verbosity', required=True, help='Log level (e.g. 
DEBUG)') + parser.add_argument('-hmm', '--hmmfile', required=True, help='HMM file for alignment') + parser.add_argument('-f', '--fasta', required=True, help='Output FASTA file') args = parser.parse_args() # Configure logging logger = util.get_formatted_logger('backbone_constraint', args.verbosity) + # Read input exemplar FASTA files from the provided text file + with open(args.input_list, 'r') as f: + exemplar_files = [line.strip() for line in f.readlines()] + # Configure database connection logger.info(f"Going to connect to database {args.database}") connection = sqlite3.connect(args.database) # Get one-to-many mapping from OTT IDs to process IDs and store extinct PIDs - pidmap, extinctpids = process_exemplars(args.inaln.split(' '), connection) + pidmap, extinctpids = process_exemplars(exemplar_files, connection) connection.close() if len(extinctpids) != 0: with open(args.extinctpids, 'w') as file: @@ -174,3 +181,31 @@ def remap_tips(tree, pidmap): with open(args.outtree, "w") as output_file: output_file.write(ott_tree.as_string(schema="newick")) + # Clean the concatenated FASTA by removing gaps (dashes) + unaligned_fasta = 'results/fasta/unaligned.fa' + with open(unaligned_fasta, 'w') as unaligned_file: + for fasta in exemplar_files: + with open(fasta, 'r') as input_file: + for line in input_file: + if line.startswith('>'): + unaligned_file.write(line) + else: + unaligned_file.write(line.replace('-', '')) + + # Align with hmmalign and output in Stockholm format + aligned_sto = 'results/fasta/aligned.sto' + subprocess.run([ + 'hmmalign', '--trim', '--dna', '--informat', 'FASTA', '--outformat', 'Stockholm', + '-o', aligned_sto, args.hmmfile, unaligned_fasta + ]) + + # Convert the Stockholm alignment to a non-interleaved FASTA format for RAxML + subprocess.run([ + 'seqmagick', 'convert', aligned_sto, args.fasta + ]) + + # Remove any extinct PIDs + if os.path.exists(args.extinctpids): + subprocess.run([ + 'seqmagick', 'mogrify', '--exclude-from-file', args.extinctpids, args.fasta + ]) \ No newline at end of file diff --git a/workflow/scripts/process_failed_file.py b/workflow/scripts/process_failed_file.py new file mode 100644 index 0000000..136b0af --- /dev/null +++ b/workflow/scripts/process_failed_file.py @@ -0,0 +1,30 @@ +import argparse +from Bio import SeqIO + + +def remove_outgroup_records(failed_file, outgroup_file, output_file): + # Read outgroup sequences + outgroup_records = set() + if outgroup_file: + with open(outgroup_file, "r") as outgroup: + for record in SeqIO.parse(outgroup, "fasta"): + outgroup_records.add(record.id) + + # Filter failed file records + with open(failed_file, "r") as failed, open(output_file, "w") as output: + for record in SeqIO.parse(failed, "fasta"): + if record.id not in outgroup_records: + # Ensure a sequence is written in one line + record.seq = record.seq.replace("-", "") + SeqIO.write(record, output, "fasta-2line") # Use "fasta-2line" to ensure one-line sequences + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Process failed file by removing outgroup records and ensuring one-line sequences.") + parser.add_argument("--failed", required=True, help="Path to the failed file (FASTA format).") + parser.add_argument("--outgroup", required=False, help="Path to the outgroup file (FASTA format).") + parser.add_argument("--output", required=True, help="Path to the output file (FASTA format).") + args = parser.parse_args() + + remove_outgroup_records(args.failed, args.outgroup, args.output) \ No newline at end of file diff --git 
a/workflow/scripts/reroot_backbone.py b/workflow/scripts/reroot_backbone.py index f4ac100..d73aaf9 100644 --- a/workflow/scripts/reroot_backbone.py +++ b/workflow/scripts/reroot_backbone.py @@ -128,4 +128,3 @@ def find_set_bipartition(tree, query_set): file.write(newick_string) - diff --git a/workflow/scripts/resolve_polytomies.py b/workflow/scripts/resolve_polytomies.py new file mode 100644 index 0000000..be3f7ab --- /dev/null +++ b/workflow/scripts/resolve_polytomies.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +""" +Resolve Polytomies Using OpenTree of Life API + +This script resolves polytomies (multifurcating nodes) in a phylogenetic tree by +replacing them with better-resolved subtrees from the OpenTree of Life database. +The process involves: +1. Identifying polytomies in the input tree +2. Finding the taxonomic name associated with each polytomy +3. Querying the OpenToL API to get the taxonomic ID +4. Retrieving a resolved subtree for that taxon +5. Replacing the polytomy with the resolved subtree + +Input: + - Constraint tree (Newick format) + - List of polytomies (.txt file, see output from find_polytomies.py) + +Output: + - Resolved constraint tree with improved resolution at polytomies + +Usage: + python resolve_polytomies.py -t tree.tre -p polytomies.txt -o resolved_tree.tre [options] + +Dependencies: + - ETE3 (Python tree manipulation) + - Requests (API communication) + - Internet connection (for OpenToL API access) +""" + +import argparse +import re +import ast +from ete3 import Tree +import sys +import logging +import requests +import json +import tempfile +import os + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger('resolve_polytomies') + +def parse_polytomies_file(filename): + """ + Parse the output file from find_polytomies.py + + Extracts lists of tip labels for each polytomy from the output format + of find_polytomies.py. + + Args: + filename: Path to polytomies file + + Returns: + list: Lists of tip labels for each polytomy + """ + polytomies = [] + + with open(filename, 'r') as f: + lines = f.readlines() + + # Skip the first line (count summary) + for i, line in enumerate(lines[1:], 1): + # Extract the list part using regex with non-greedy matching to avoid capturing extra content + match = re.search(r'Polytomy \d+: (\[.*?\]) \(depth=', line) + if match: + # Convert string representation of list to actual list + try: + tips = ast.literal_eval(match.group(1)) + polytomies.append(tips) + except (SyntaxError, ValueError): + logger.warning(f"Could not parse polytomy on line {i+1}") + + return polytomies + +def find_node_with_tips(tree, tips): + """ + Find the node in the tree that contains exactly these tips + + Args: + tree: ETE3 Tree object + tips: List of tip labels to find + + Returns: + TreeNode: Node containing exactly these tips, or None if not found + """ + for node in tree.traverse(): + if set(node.get_leaf_names()) == set(tips): + return node + return None + +def find_taxonomic_name(node): + """ + Extract taxonomic name from node or its parent nodes + + This function tries to find a taxonomic name (not formatted as a barcode ID) + associated with the node or its ancestors. 
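+    Names matching the barcode process-ID pattern (capital letters, digits, a dash,
+    digits) are ignored so that only taxonomic labels are returned.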
+ + Args: + node: An ETE3 TreeNode object + + Returns: + str: Taxonomic name if found, None otherwise + """ + if node.name and node.name.strip() and not re.match(r'^[A-Z]+\d+-\d+$', node.name): + return node.name + + # If no name at this node, try parent nodes until we find a taxonomic name + parent = node.up + while parent: + if parent.name and parent.name.strip() and not re.match(r'^[A-Z]+\d+-\d+$', parent.name): + return parent.name + parent = parent.up + + return None + +def get_node_depth(node): + """ + Get the depth of a node in the tree (distance from root) + + Args: + node: An ETE3 TreeNode object + + Returns: + int: Depth of the node (0 for root, increases with distance from root) + """ + depth = 0 + current = node + while current.up: + depth += 1 + current = current.up + return depth + +def get_ott_id(taxon_name): + """ + Query OpenTOL TNRS API to get OTT ID for a taxonomic name + + Uses the Taxonomic Name Resolution Service to find the OpenTree Taxonomy ID + for a given taxonomic name. + + Args: + taxon_name: Taxonomic name to look up + + Returns: + int: OpenTree Taxonomy ID if found, None otherwise + """ + url = "https://api.opentreeoflife.org/v3/tnrs/match_names" + payload = { + "names": [taxon_name], + "do_approximate_matching": True + } + + try: + logger.info(f"Looking up OTT ID for '{taxon_name}'") + response = requests.post(url, json=payload) + data = response.json() + + if 'results' in data and len(data['results']) > 0 and len(data['results'][0]['matches']) > 0: + ott_id = data['results'][0]['matches'][0]['taxon']['ott_id'] + logger.info(f"Found OTT ID for '{taxon_name}': {ott_id}") + return ott_id + else: + logger.warning(f"No OTT ID found for '{taxon_name}'") + return None + except Exception as e: + logger.error(f"Error querying OTT ID: {e}") + return None + +def get_resolved_subtree(ott_id): + """ + Get a resolved subtree from OpenTOL using an OTT ID + + Queries the OpenTree of Life API to get a resolved subtree for a given + taxonomic ID. + + Args: + ott_id: OpenTree Taxonomy ID + + Returns: + str: Newick string of resolved subtree if found, None otherwise + """ + url = "https://api.opentreeoflife.org/v3/tree_of_life/subtree" + payload = { + "ott_id": ott_id, + "label_format": "name" + } + + try: + logger.info(f"Requesting subtree for OTT ID {ott_id}") + response = requests.post(url, json=payload) + data = response.json() + + if 'newick' in data: + return data['newick'] + else: + logger.warning(f"No subtree found for OTT ID {ott_id}") + return None + except Exception as e: + logger.error(f"Error getting subtree: {e}") + return None + +def parse_newick_safely(newick_str): + """ + Parse a Newick string safely with multiple methods + + Tries multiple approaches to parse potentially problematic Newick strings, + including handling quoted names, special characters, and various formats. + + Args: + newick_str: Newick format string to parse + + Returns: + Tree: ETE3 Tree object if parsing succeeds, None otherwise + """ + def sanitize_newick(text): + """Clean up problematic characters in newick strings""" + # Replace problematic patterns + text = re.sub(r"'([^']*)'", r"\1", text) # Remove single quotes + text = re.sub(r'"([^"]*)"', r"\1", text) # Remove double quotes + text = re.sub(r'nr\.\s+', r"nr_", text) # Fix "nr. " pattern + text = re.sub(r'\s+sp\.\s+', r"_sp_", text) # Fix " sp. 
" pattern + text = re.sub(r'\s+', r"_", text) # Replace spaces with underscores + + # Add quotes around all taxonomic names for safety + text = re.sub(r'([A-Za-z][A-Za-z0-9_.:-]+)', r"'\1'", text) + return text + + # Try multiple parsing approaches + try: + # First try simple parsing with quoted names + return Tree(newick_str, format=1, quoted_node_names=True) + except Exception: + try: + # Try format 0 (most flexible) + return Tree(newick_str, format=0) + except Exception: + # Last resort: write to temp file with sanitized content + sanitized = sanitize_newick(newick_str) + + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp: + temp.write(sanitized) + temp_file = temp.name + + try: + # Try to parse the sanitized newick from file + tree = Tree(temp_file, format=1, quoted_node_names=True) + os.remove(temp_file) + return tree + except Exception: + if os.path.exists(temp_file): + os.remove(temp_file) + + # If we still can't parse it, log more info and return None + logger.debug(f"Raw newick: {newick_str[:100]}...") + logger.debug(f"Sanitized: {sanitized[:100]}...") + return None + +def resolve_polytomy(constraint_tree, polytomy_tips, min_depth=3): + """ + Resolve a polytomy using the OpenTOL API + + Takes a set of tips defining a polytomy and attempts to replace the + corresponding node with a better-resolved subtree from OpenTree of Life. + + Args: + constraint_tree: ETE3 Tree object containing the polytomy + polytomy_tips: List of tip labels in the polytomy + min_depth: Minimum depth to consider (skip shallower nodes) + + Returns: + tuple: ( + success (bool), + taxon_name (str or None), + ott_id (int or None), + newick_str (str or None) + ) + """ + # Find the node containing these tips + polytomy_node = find_node_with_tips(constraint_tree, polytomy_tips) + if not polytomy_node: + logger.warning(f"Could not find node for tips in tree") + return False, None, None, None + + # Skip polytomies that are too shallow in the tree + node_depth = get_node_depth(polytomy_node) + if node_depth < min_depth: + logger.info(f"Skipping shallow polytomy (depth {node_depth} < {min_depth})") + return False, None, None, None + + # Get taxonomic name for this node + taxon_name = find_taxonomic_name(polytomy_node) + if not taxon_name: + logger.warning(f"Could not find taxonomic name for node") + return False, None, None, None + + # Get OTT ID for this taxon + ott_id = get_ott_id(taxon_name) + if not ott_id: + # Try parent taxon if this one isn't found + if polytomy_node.up: + parent_name = find_taxonomic_name(polytomy_node.up) + if parent_name and parent_name != taxon_name: + logger.info(f"Trying parent taxon '{parent_name}'") + ott_id = get_ott_id(parent_name) + + if not ott_id: + logger.warning(f"Could not find OTT ID for taxon '{taxon_name}'") + return False, taxon_name, None, None + + # Get resolved subtree + newick_str = get_resolved_subtree(ott_id) + if not newick_str: + logger.warning(f"Could not get subtree for OTT ID {ott_id}") + return False, taxon_name, ott_id, None + + # Replace polytomy with resolved subtree + try: + # Parse the resolved subtree with proper handling of quoted names + resolved_subtree = parse_newick_safely(newick_str) + + if not resolved_subtree: + logger.warning(f"Failed to parse resolved subtree") + return False, taxon_name, ott_id, newick_str + + # Check if the resolved subtree has a better structure + if len(resolved_subtree.get_children()) <= len(polytomy_node.get_children()): + logger.info(f"Resolved subtree doesn't improve structure, skipping") + return False, 
taxon_name, ott_id, newick_str + + # Replace the polytomy with the resolved subtree + parent = polytomy_node.up + if parent: + # Remove old node + polytomy_node.detach() + + # Add new subtree (fix: removed 'pos' parameter) + parent.add_child(resolved_subtree) + logger.info(f"Successfully replaced polytomy with resolved subtree") + return True, taxon_name, ott_id, newick_str + else: + # Special handling for root node + logger.warning(f"Cannot replace root node") + return False, taxon_name, ott_id, newick_str + except Exception as e: + logger.error(f"Error replacing polytomy: {e}") + return False, taxon_name, ott_id, newick_str + +def main(): + """ + Main function: Parse arguments and process polytomies + """ + parser = argparse.ArgumentParser(description="Resolve polytomies using OpenTOL API") + parser.add_argument("-t", "--tree", required=True, + help="Path to the original constraint tree") + parser.add_argument("-p", "--polytomies", required=True, + help="Path to polytomies file (output from find_polytomies.py)") + parser.add_argument("-o", "--output", required=True, + help="Path for the resolved constraint tree output") + parser.add_argument("-l", "--limit", type=int, default=None, + help="Limit the number of polytomies to process (for testing)") + parser.add_argument("--min-size", type=int, default=3, + help="Minimum size of polytomy to process (default: 3)") + parser.add_argument("--max-size", type=int, default=50, + help="Maximum size of polytomy to process (default: 50)") + parser.add_argument("--min-depth", type=int, default=3, + help="Minimum depth in tree for polytomy to process (default: 3)") + parser.add_argument("-v", "--verbose", action="store_true", + help="Enable verbose output") + + args = parser.parse_args() + + # Set logging level + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Load the constraint tree + try: + logger.info(f"Loading constraint tree from {args.tree}") + constraint_tree = Tree(args.tree, format=1) + except Exception as e: + logger.error(f"Failed to load tree: {e}") + sys.exit(1) + + # Parse the polytomies file + logger.info(f"Parsing polytomies from {args.polytomies}") + polytomies = parse_polytomies_file(args.polytomies) + logger.info(f"Found {len(polytomies)} polytomies") + + # Process each polytomy + processed = 0 + resolved = 0 + no_ott_id = [] + failed_parsing = [] + other_fails = [] + + for i, tips in enumerate(polytomies): + # Check if we've reached the limit + if args.limit and i >= args.limit: + break + + # Skip polytomies that are too large or too small + if len(tips) < args.min_size or len(tips) > args.max_size: + logger.info(f"Skipping polytomy {i+1} with {len(tips)} tips (size out of range)") + continue + + logger.info(f"Processing polytomy {i+1}/{len(polytomies)} with {len(tips)} tips") + + # Try to resolve this polytomy + success, taxon_name, ott_id, newick_str = resolve_polytomy(constraint_tree, tips, min_depth=args.min_depth) + processed += 1 + + if success: + resolved += 1 + elif taxon_name and not ott_id: + no_ott_id.append(taxon_name) + elif ott_id and newick_str and not success: + failed_parsing.append(taxon_name) + elif taxon_name: + other_fails.append(taxon_name) + + # Save the resolved tree + logger.info(f"Saving resolved tree to {args.output}") + constraint_tree.write(format=1, outfile=args.output) + + # Generate a summary report + logger.info(f"\n===== SUMMARY =====") + logger.info(f"Total polytomies processed: {processed}") + logger.info(f"Successfully resolved: {resolved} ({(resolved/processed*100) if 
processed > 0 else 0:.1f}%)") + logger.info(f"Taxa with no OTT ID: {len(no_ott_id)}") + logger.info(f"Taxa with parsing errors: {len(failed_parsing)}") + logger.info(f"Taxa that failed for other reasons: {len(other_fails)}") + logger.info(f"===================\n") + +if __name__ == "__main__": + main() + +""" +Example usage: + +# Basic usage +python resolve_polytomies.py -t constraint_tree.tre -p polytomies.txt -o resolved_tree.tre + +# Focus on smaller polytomies at deeper taxonomic levels +python resolve_polytomies.py -t constraint_tree.tre -p polytomies.txt -o resolved_tree.tre --min-depth 5 --min-size 3 --max-size 25 + +# Process only a limited number of polytomies (for testing) +python resolve_polytomies.py -t constraint_tree.tre -p polytomies.txt -o resolved_tree.tre -l 10 -v + +Workflow: +1. Generate polytomies file: python find_polytomies.py -t tree.tre --min-depth 4 > polytomies.txt +2. Resolve polytomies: python resolve_polytomies.py -t tree.tre -p polytomies.txt -o resolved_tree.tre +3. Use resolved tree as a constraint tree for phylogenetic analyses (RAxML, etc.) +""" \ No newline at end of file diff --git a/workflow/scripts/run_raxml.py b/workflow/scripts/run_raxml.py new file mode 100644 index 0000000..bda6616 --- /dev/null +++ b/workflow/scripts/run_raxml.py @@ -0,0 +1,65 @@ +import argparse +import subprocess +import os + +def run_raxml(alignment, tree, output, model, num_outgroups, log_file): + # Count the number of taxa in the alignment + with open(alignment, 'r') as f: + taxon_count = sum(1 for line in f if line.startswith('>')) + + # Extract the outgroup names + with open(alignment, 'r') as f: + lines = [line.strip() for line in f if line.startswith('>')] + outgroups = ",".join(line[1:] for line in lines[-num_outgroups:]) + + # Check the constraint tree's properties + constraint_tree_exists = os.path.exists(tree) and os.path.getsize(tree) > 0 + constraint_tree_valid = False + if constraint_tree_exists: + with open(tree, 'r') as f: + content = f.read() + constraint_tree_valid = ('(' in content and ')' in content and + content.count('(') > 2 and + content.count(',') < (taxon_count - 1)) + + # Prepare the raxml-ng command + if constraint_tree_valid: + cmd = [ + "raxml-ng", "--redo", "--msa", alignment, "--outgroup", outgroups, + "--model", model, "--tree-constraint", tree, "--search" + ] + log_message = "Running RAxML-NG with constraint tree" + else: + cmd = [ + "raxml-ng", "--redo", "--msa", alignment, "--model", model, + "--search" + ] + log_message = "Constraint tree fully-resolved, running RAxML-NG without it" + + # Log the process and execute the command + with open(log_file, 'w') as log: + log.write(log_message + "\n") + try: + subprocess.run(cmd, check=True, stdout=log, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + log.write(f"RAxML-NG failed: {e}\n") + raise + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run RAxML-NG with or without a constraint tree.") + parser.add_argument("--alignment", required=True, help="Path to the alignment file.") + parser.add_argument("--tree", required=True, help="Path to the constraint tree file.") + parser.add_argument("--output", required=True, help="Path to the output tree file.") + parser.add_argument("--model", required=True, help="Model to use for RAxML-NG.") + parser.add_argument("--num_outgroups", type=int, required=True, help="Number of outgroups.") + parser.add_argument("--log_file", required=True, help="Path to the log file.") + args = parser.parse_args() + + 
run_raxml( + alignment=args.alignment, + tree=args.tree, + output=args.output, + model=args.model, + num_outgroups=args.num_outgroups, + log_file=args.log_file + ) diff --git a/workflow/scripts/run_raxml_backbone.py b/workflow/scripts/run_raxml_backbone.py new file mode 100644 index 0000000..3880476 --- /dev/null +++ b/workflow/scripts/run_raxml_backbone.py @@ -0,0 +1,66 @@ +import subprocess +import sys +import os + +def run_raxml_backbone(alignment, tree, model, log_file): + # Run the initial raxml-ng command + cmd = [ + "raxml-ng", + "--redo", + "--msa", alignment, + "--model", model, + "--tree-constraint", tree, + "--search", + "--msa-format", "PHYLIP" + ] + with open(log_file, "w") as log: + log.write("Running initial raxml-ng command:\n") + log.write(" ".join(cmd) + "\n") + try: + result = subprocess.run(cmd, stdout=log, stderr=log, check=True) + except subprocess.CalledProcessError as e: + log.write(f"\nCommand failed with return code {e.returncode}\n") + log.write(f"Command output: {e.output}\n") + return 2 # Indicate that an error occurred + + # Check if the log file contains the specific error message + with open(log_file, "r") as log: + log_content = log.read() + if "ERROR: You provided a comprehensive, fully-resolved tree as a topological constraint." in log_content: + # Run the branch length optimization + cmd = [ + "raxml-ng", + "--evaluate", + "--redo", + "--msa", alignment, + "--model", model, + "--tree", tree, + "--brlen", "scaled" + ] + with open(log_file, "a") as log: + log.write("\nRunning branch length optimization:\n") + log.write(" ".join(cmd) + "\n") + try: + result = subprocess.run(cmd, stdout=log, stderr=log, check=True) + except subprocess.CalledProcessError as e: + log.write(f"\nCommand failed with return code {e.returncode}\n") + log.write(f"Command output: {e.output}\n") + return 2 # Indicate that an error occurred + return 1 # Indicate that the branch length optimization was run + + # Check if the expected output file is created + output_file = alignment + ".raxml.bestTree" + if not os.path.exists(output_file): + with open(log_file, "a") as log: + log.write(f"\nExpected output file {output_file} not found.\n") + return 3 # Indicate that the output file is missing + + return 0 # Indicate that the initial command was successful + +if __name__ == "__main__": + alignment = sys.argv[1] + tree = sys.argv[2] + model = sys.argv[3] + log_file = sys.argv[4] + exit_code = run_raxml_backbone(alignment, tree, model, log_file) + sys.exit(exit_code) \ No newline at end of file diff --git a/workflow/scripts/treebuilder.py b/workflow/scripts/treebuilder.py new file mode 100644 index 0000000..5fbf2ef --- /dev/null +++ b/workflow/scripts/treebuilder.py @@ -0,0 +1,433 @@ +import os +import sys +import logging +import argparse +from collections import defaultdict +from typing import Dict, List, Set, Tuple, Optional +import pandas as pd +from Bio import AlignIO, Phylo +from Bio.Phylo.BaseTree import Tree, Clade + +""" +BOLD Taxonomy Tree Builder + +This script builds a taxonomic tree from BOLD process IDs listed in a FASTA alignment file. +It extracts the process IDs from the FASTA file, looks up their taxonomic information in a +BOLD BCDM (Barcode Core Data Model) TSV file, and constructs a hierarchical tree with the +process IDs at the tips and taxonomic information at the internal nodes. + +Inputs: +------- +1. 
FASTA Alignment File (-f, --fasta): + - Standard FASTA format with BOLD process IDs as sequence identifiers + - The process ID should be the first word in the FASTA defline + - Example: >AAASF001-17 + +2. BOLD BCDM TSV File (-b, --bold): + - Standard BOLD BCDM TSV file containing taxonomic information + - Must include 'processid' column and taxonomic level columns (kingdom, phylum, etc.) + - See BOLD documentation for complete format specifications: + https://github.com/boldsystems-central/BCDM/blob/main/field_definitions.tsv + +Output: +------- +- Newick format tree file with process IDs at the tips and optional taxonomic labels at + internal nodes (-o, --output) + +Output Customization: +-------------------- +1. Internal Node Labels (-n, --nodelabels): + - By default, internal node labels are removed from the output + - Use --nodelabels to include taxonomic labels at internal nodes + - Note: Many phylogenetic programs do NOT support internal node labels in Newick files + - However, these labels can be useful when working with the Open Tree of Life web services + to further resolve polytomies based on known phylogenetic relationships + +2. Unbranched Internal Nodes (-c, --collapse): + - Use --collapse to remove unbranching internal nodes (nodes with only one child) + - Many phylogenetic analysis programs do NOT correctly process trees with unbranched + internal nodes and may produce errors or unexpected results + - Collapsing these nodes maintains the same tree topology from a phylogenetic perspective + +Usage Examples: +-------------- +Basic usage: + python treebuilder.py -f sequences.fasta -b bold_data.tsv -o taxonomy_tree.nwk + +With all options: + python treebuilder.py -f sequences.fasta -b bold_data.tsv -o taxonomy_tree.nwk -c -n -v + +Notes: +------ +- Branch lengths are removed from the output tree as they are not meaningful in this context +- Missing taxonomic levels (marked as 'None' or empty in BOLD data) are skipped +- The script handles large BOLD data files fairly efficiently by processing in chunks +""" + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger('bold-tree-builder') + +def extract_process_ids(fasta_file: str) -> Set[str]: + """ + Extract BOLD process IDs from a FASTA alignment file. + + :param fasta_file: Path to the FASTA alignment file + :return: A set of BOLD process IDs + """ + logger.info(f"Extracting process IDs from {fasta_file}") + process_ids = set() + + try: + alignment = AlignIO.read(fasta_file, "fasta") + for record in alignment: + process_id = record.id.strip() + if process_id: + process_ids.add(process_id) + + logger.info(f"Extracted {len(process_ids)} process IDs") + logger.debug(f"Process IDs: {', '.join(list(process_ids)[:5])}...") + return process_ids + + except Exception as e: + logger.error(f"Error reading FASTA file: {e}") + sys.exit(1) + + +def read_bold_taxonomy(bold_file: str, process_ids: Set[str]) -> pd.DataFrame: + """ + Read BOLD taxonomy information for the specified process IDs. 
+ + :param bold_file: Path to the BOLD BCDM TSV file + :param process_ids: Set of BOLD process IDs to filter + :return: DataFrame containing taxonomy information for the specified process IDs + """ + logger.info(f"Reading taxonomy data from {bold_file}") + + try: + # Define taxonomic levels in hierarchical order + taxonomy_levels = [ + 'kingdom', 'phylum', 'class', 'order', + 'family', 'subfamily', 'genus', 'species', 'subspecies' + ] + + # Read the BOLD TSV file, focusing on taxonomic columns and processid + cols_to_read = ['processid'] + taxonomy_levels + + # Read the file in chunks to handle large files efficiently + df_chunks = pd.read_csv( + bold_file, + sep='\t', + usecols=cols_to_read, + dtype=str, + chunksize=100000 + ) + + # Combine chunks, filtering for our process IDs + records = [] + for chunk in df_chunks: + filtered_chunk = chunk[chunk['processid'].isin(process_ids)] + records.append(filtered_chunk) + + # If we've found all our process IDs, we can stop reading + if len(set.union(*[set(df['processid']) for df in records])) == len(process_ids): + logger.debug("Found all requested process IDs, stopping further reading") + break + + if not records: + logger.error("No matching records found in BOLD data") + sys.exit(1) + + taxonomy_df = pd.concat(records, ignore_index=True) + found_ids = set(taxonomy_df['processid']) + missing_ids = process_ids - found_ids + + if missing_ids: + logger.warning(f"Could not find {len(missing_ids)} process IDs in BOLD data") + logger.debug(f"Missing IDs: {', '.join(list(missing_ids)[:5])}...") + + logger.info(f"Found taxonomy data for {len(found_ids)} process IDs") + return taxonomy_df + + except Exception as e: + logger.error(f"Error reading BOLD taxonomy file: {e}") + sys.exit(1) + + +def build_taxonomy_paths(taxonomy_df: pd.DataFrame) -> Dict[str, List[str]]: + """ + Build taxonomy paths for each process ID. + + :param taxonomy_df: DataFrame containing taxonomy information + :return: Dictionary mapping process IDs to their taxonomy paths + """ + logger.info("Building taxonomy paths for each process ID") + + taxonomy_levels = [ + 'kingdom', 'phylum', 'class', 'order', + 'family', 'subfamily', 'genus', 'species', 'subspecies' + ] + + taxonomy_paths = {} + + for _, row in taxonomy_df.iterrows(): + process_id = row['processid'] + path = [] + + # Build path, skipping 'None' or empty values + for level in taxonomy_levels: + taxon = row[level] + if pd.notna(taxon) and taxon != 'None' and taxon.strip(): + path.append(taxon) + + # Add the process ID as the final element + path.append(process_id) + taxonomy_paths[process_id] = path + + logger.debug(f"Sample path: {list(taxonomy_paths.values())[0]}") + return taxonomy_paths + + +def build_tree_from_paths(taxonomy_paths: Dict[str, List[str]]) -> Tree: + """ + Build a BioPython tree from taxonomy paths. 
+ + :param taxonomy_paths: Dictionary mapping process IDs to taxonomy paths + :return: BioPython Tree object + """ + logger.info("Building taxonomic tree") + + # Create the root node + tree = Tree(rooted=True, root=Clade(name="root")) + + # Track nodes by path to handle shared ancestry + path_to_clade = {"": tree.root} + + # Sort paths to ensure parent nodes are created before children + process_ids = sorted(taxonomy_paths.keys()) + + for process_id in process_ids: + path = taxonomy_paths[process_id] + current_path = "" + parent_path = "" + + # Process each level in the path except the last (which is the process ID) + for i, taxon in enumerate(path[:-1]): + current_path = f"{current_path}/{taxon}" if current_path else taxon + + # Create new node if this path doesn't exist yet + if current_path not in path_to_clade: + new_clade = Clade(name=taxon) + path_to_clade[parent_path].clades.append(new_clade) + path_to_clade[current_path] = new_clade + + parent_path = current_path + + # Add the leaf node (process ID) + leaf_clade = Clade(name=process_id) + path_to_clade[current_path].clades.append(leaf_clade) + + logger.info(f"Tree built with {len(list(tree.find_clades()))} nodes") + return tree + + +def fix_root_unifurcation(tree: Tree) -> None: + """ + Fix the case where the root has only one child by making that child the new root. + + :param tree: BioPython Tree object + """ + if len(tree.root.clades) == 1: + logger.info("Root has only one child - fixing root unifurcation") + # The child of the root becomes the new root + old_root = tree.root + new_root = old_root.clades[0] + + # Set the new root's properties + tree.root = new_root + + # If we want to preserve the old root's name in some way + if old_root.name and old_root.name != "root": + # We could add the old root's name to the new root if needed + if new_root.name: + new_root.name = f"{old_root.name}_{new_root.name}" + else: + new_root.name = old_root.name + + logger.info("Root unifurcation fixed") + + +def collapse_unbranching_nodes(tree: Tree) -> None: + """ + Remove internal nodes that have only one child using recursion. + + :param tree: BioPython Tree object + """ + logger.info("Collapsing unbranching internal nodes") + + # Start with the root node + original_node_count = len(list(tree.find_clades())) + _collapse_clade_recursively(tree.root) + + new_node_count = len(list(tree.find_clades())) + logger.info(f"Tree after collapsing: {new_node_count} nodes (removed {original_node_count - new_node_count} nodes)") + + +def _collapse_clade_recursively(clade: Clade) -> bool: + """ + Recursively process a clade and its children to collapse unbranching nodes. + Returns True if this clade should be removed. + + :param clade: BioPython Clade object + :return: Whether this clade should be removed + """ + # Process all children first (bottom-up approach) + i = 0 + while i < len(clade.clades): + child = clade.clades[i] + should_remove = _collapse_clade_recursively(child) + + if should_remove: + # Replace the child with its own children + clade.clades.pop(i) + clade.clades[i:i] = child.clades + # Don't increment i since we need to process the new children + else: + i += 1 + + # A clade should be collapsed if it has exactly one child + # We don't collapse the root, even if it has one child + return len(clade.clades) == 1 and clade.name != "root" + + +def write_tree_to_newick_old(tree: Tree, output_file: str) -> None: + """ + Write the tree to a Newick file with interior labels. 
+ + :param tree: BioPython Tree object + :param output_file: Path to the output Newick file + """ + logger.info(f"Writing tree to {output_file}") + + # Remove branch lengths from the tree: they're all 0.00000 and meaningless + for clade in tree.find_clades(): + clade.branch_length = None + + try: + Phylo.write(tree, output_file, "newick", branch_length_only=False) + logger.info("Tree successfully written to Newick file") + except Exception as e: + logger.error(f"Error writing tree to file: {e}") + sys.exit(1) + + +def write_tree_to_newick(tree: Tree, output_file: str, include_internal_labels: bool = False) -> None: + """ + Write the tree to a Newick file with optional interior labels and no branch lengths. + + :param tree: BioPython Tree object + :param output_file: Path to the output Newick file + :param include_internal_labels: Whether to include labels for internal nodes + """ + logger.info(f"Writing tree to {output_file}") + + # Create a copy of the tree to avoid modifying the original + import copy + tree_copy = copy.deepcopy(tree) + + # Set all branch lengths to None + for clade in tree_copy.find_clades(): + clade.branch_length = None + + # Optionally remove internal node labels + if not include_internal_labels and clade.clades: # If it's an internal node + clade.name = "" + + try: + # First write to a temporary file + temp_file = f"{output_file}.tmp" + Phylo.write(tree_copy, temp_file, "newick") + + # Read the file and remove branch lengths + with open(temp_file, 'r') as f: + newick_str = f.read().strip() + + # Remove all branch lengths (patterns like :0.0 or :0.000) + import re + cleaned_newick = re.sub(r':[0-9.]+', '', newick_str) + + # Write the cleaned Newick string to the output file + with open(output_file, 'w') as f: + f.write(cleaned_newick) + + # Clean up temporary file + os.remove(temp_file) + + logger.info("Tree successfully written to Newick file") + except Exception as e: + logger.error(f"Error writing tree to file: {e}") + sys.exit(1) + +def remove_internal_labels(tree: Tree) -> None: + """ + Remove labels from all internal nodes, keeping only leaf labels. 
+
+    :param tree: BioPython Tree object
+    """
+    logger.info("Removing internal node labels")
+
+    for clade in tree.find_clades():
+        # If this is an internal node (has children), remove its name
+        if clade.clades:
+            clade.name = ""
+
+    logger.info("Internal node labels removed")
+
+
+if __name__ == "__main__":
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="Build a taxonomic tree from BOLD process IDs in a FASTA alignment file.")
+    parser.add_argument("-f", "--fasta", required=True, help="Aligned FASTA file with process IDs")
+    parser.add_argument("-b", "--bold", required=True, help="BOLD BCDM TSV file")
+    parser.add_argument("-o", "--output", required=True, help="Output Newick tree file")
+    parser.add_argument("-c", "--collapse", action="store_true", help="Collapse unbranching internal nodes")
+    parser.add_argument("-n", "--nodelabels", action="store_true", help="Include internal node labels")
+    parser.add_argument("-v", "--verbose", action="store_true", help="Increase output verbosity")
+    args = parser.parse_args()
+
+    # Check if files exist
+    if not os.path.isfile(args.fasta):
+        parser.error(f"FASTA file not found: {args.fasta}")
+    if not os.path.isfile(args.bold):
+        parser.error(f"BOLD TSV file not found: {args.bold}")
+
+    # Set logging level based on verbosity
+    if args.verbose:
+        logger.setLevel(logging.DEBUG)
+
+    # Extract process IDs from FASTA file
+    process_ids = extract_process_ids(args.fasta)
+
+    # Read BOLD taxonomy for these process IDs
+    taxonomy_df = read_bold_taxonomy(args.bold, process_ids)
+
+    # Build taxonomy paths
+    taxonomy_paths = build_taxonomy_paths(taxonomy_df)
+
+    # Build the tree
+    tree = build_tree_from_paths(taxonomy_paths)
+
+    # Optionally collapse unbranching internal nodes
+    if args.collapse:
+        collapse_unbranching_nodes(tree)
+        # After collapsing, fix any root unifurcation
+        fix_root_unifurcation(tree)
+
+    # Write the tree to a Newick file
+    write_tree_to_newick(tree, args.output, args.nodelabels)
+
+    logger.info("Process completed successfully")
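
The polytomy replacement in resolve_polytomies.py (above) rests on ete3's detach()/add_child() pair: the unresolved node is cut out and a resolved subtree over the same tips is grafted onto its parent. A minimal, self-contained sketch of that pattern, with hypothetical tip names and a hard-coded resolution standing in for the OpenTOL-derived subtree:

from ete3 import Tree

# Toy constraint tree with one named polytomy; all names are hypothetical.
tree = Tree("((A,B,C,D)poly,E)root;", format=1)
polytomy_node = tree.search_nodes(name="poly")[0]

# A resolved subtree over the same tips (hard-coded here; the pipeline
# builds this from an OpenTOL induced subtree instead).
resolved_subtree = Tree("((A,B),(C,D));", format=1)

parent = polytomy_node.up
if parent:
    polytomy_node.detach()               # remove the unresolved clade
    parent.add_child(resolved_subtree)   # graft the resolved version back in

print(tree.write(format=9))  # e.g. (E,((A,B),(C,D)));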
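run_raxml_backbone.py only inspects the log for raxml-ng's "comprehensive, fully-resolved tree" error after the first command succeeds; because that command runs with check=True, a failure returns 2 from the except block before the check, so the --evaluate fallback cannot trigger. A hedged sketch of one possible ordering that lets the fallback fire (the function name and return-code meanings are illustrative; the raxml-ng flags are the ones the script already uses):

import subprocess

def run_backbone_with_fallback(alignment, tree, model, log_file):
    """Illustrative sketch only: constrained search first, --evaluate fallback
    if raxml-ng rejects the constraint as comprehensive and fully resolved."""
    search_cmd = ["raxml-ng", "--redo", "--search", "--msa", alignment,
                  "--model", model, "--tree-constraint", tree]
    with open(log_file, "w") as log:
        result = subprocess.run(search_cmd, stdout=log, stderr=subprocess.STDOUT)

    with open(log_file) as log:
        log_text = log.read()

    if result.returncode == 0:
        return 0  # constrained search succeeded

    if "fully-resolved tree as a topological constraint" in log_text:
        # Constraint already fixes the topology, so only optimize branch lengths.
        evaluate_cmd = ["raxml-ng", "--redo", "--evaluate", "--msa", alignment,
                        "--model", model, "--tree", tree]
        with open(log_file, "a") as log:
            subprocess.run(evaluate_cmd, stdout=log, stderr=subprocess.STDOUT, check=True)
        return 1  # branch lengths were optimized on the provided tree

    return 2  # some other failure; inspect the log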
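read_bold_taxonomy() in treebuilder.py streams the BCDM TSV in 100,000-row chunks and stops early once every requested process ID has been found. The same idea in isolation, assuming a hypothetical column subset and keeping a running set of found IDs rather than re-unioning every accumulated chunk on each iteration:

import pandas as pd

def collect_rows(tsv_path: str, wanted_ids: set, chunksize: int = 100_000) -> pd.DataFrame:
    """Sketch: keep only rows whose processid is in wanted_ids, stopping early.
    wanted_ids must be a set so the superset test below works."""
    found: set = set()
    kept = []
    for chunk in pd.read_csv(tsv_path, sep="\t", dtype=str,
                             usecols=["processid", "family", "genus", "species"],  # illustrative subset
                             chunksize=chunksize):
        hits = chunk[chunk["processid"].isin(wanted_ids)]
        if not hits.empty:
            kept.append(hits)
            found.update(hits["processid"])
        if found >= wanted_ids:  # every requested ID located, no need to read on
            break
    return pd.concat(kept, ignore_index=True) if kept else pd.DataFrame(columns=["processid"])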
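build_tree_from_paths() keys each prefix of a taxonomy path so that records sharing higher ranks end up under the same internal clade. A stripped-down sketch of that bookkeeping with two made-up process IDs; the names, paths, and printout are purely illustrative:

from io import StringIO
from Bio import Phylo
from Bio.Phylo.BaseTree import Clade, Tree

# Hypothetical taxonomy paths, each ending in a process ID.
paths = {
    "PROC001-24": ["Animalia", "Arthropoda", "Insecta", "PROC001-24"],
    "PROC002-24": ["Animalia", "Arthropoda", "Collembola", "PROC002-24"],
}

root = Clade(name="root")
tree = Tree(root=root, rooted=True)
clade_by_prefix = {"": root}

for path in paths.values():
    prefix = ""
    for taxon in path:
        key = f"{prefix}/{taxon}" if prefix else taxon
        if key not in clade_by_prefix:
            # First time this prefix is seen: hang a new clade off its parent.
            clade = Clade(name=taxon)
            clade_by_prefix[prefix].clades.append(clade)
            clade_by_prefix[key] = clade
        prefix = key

buffer = StringIO()
Phylo.write(tree, buffer, "newick")
print(buffer.getvalue())  # shared ranks appear once; the tips are the process IDs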