Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conda_env_stag.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: stag_0.8.1_tax3
name: stag
channels:
- bioconda
- defaults
Expand All @@ -11,7 +11,7 @@ dependencies:
- easel
- numpy
- pandas
- scikit-learn<0.24
- scikit-learn
- h5py
- seqtk
- regex
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import re
import sys

from stag import __version__ as stag_version

# Directory containing this setup script. Use the real __file__ value: the
# string literal "__file__" has an empty dirname, which silently resolves to
# the current working directory instead of the script's own location.
here = path.abspath(path.dirname(__file__))

with open(path.join(here, "DESCRIPTION.md"), encoding="utf-8") as description:
description = long_description = description.read()

name="stag"
version = [line.strip().split(" ")[-1] for line in open("stag/__init__.py") if line.startswith("__version__")][0]
version = stag_version

if sys.version_info.major != 3:
raise EnvironmentError("""{toolname} is a python module that requires python3, and is not compatible with python2.""".format(toolname=name))
Expand Down
4 changes: 2 additions & 2 deletions stag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__version__ = "0.8.2"
__version__ = "0.9"
__title__ = "stag"
__author__ = "Alessio Milanese"
__license__ = 'GPLv3+'
__copyright__ = 'Copyright 2019-2021 Alessio Milanese' # EMBL?
__copyright__ = 'Copyright 2019-2022 Alessio Milanese'
32 changes: 20 additions & 12 deletions stag/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def print_menu_create_db():
sys.stderr.write(f" {bco.LightBlue}-p{bco.ResetAll} FILE protein sequences, if they were used for the alignment {bco.LightMagenta}[None]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-e{bco.ResetAll} STR penalty for the logistic regression {bco.LightMagenta}[\"l1\"]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-E{bco.ResetAll} STR solver for the logistic regression {bco.LightMagenta}[\"liblinear\"]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-N{bco.ResetAll} INT solver_iterations: increase this parameter if the output displays a ConvergenceWarning. (default=5000)\n")
sys.stderr.write(f" {bco.LightBlue}-t{bco.ResetAll} INT number of threads {bco.LightMagenta}[1]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-v{bco.ResetAll} INT verbose level: 1=error, 2=warning, 3=message, 4+=debugging {bco.LightMagenta}[3]{bco.ResetAll}\n\n")
# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -142,9 +143,10 @@ def print_menu_train():
sys.stderr.write(f" {bco.LightBlue}-C{bco.ResetAll} FILE save intermediate cross validation results {bco.LightMagenta}[None]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-t{bco.ResetAll} INT number of threads {bco.LightMagenta}[1]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-m{bco.ResetAll} INT threshold for the number of features per sequence (percentage) {bco.LightMagenta}[0]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-v{bco.ResetAll} INT verbose level: 1=error, 2=warning, 3=message, 4+=debugging {bco.LightMagenta}[3]{bco.ResetAll}\n\n")
sys.stderr.write(f" {bco.LightBlue}-v{bco.ResetAll} INT verbose level: 1=error, 2=warning, 3=message, 4+=debugging {bco.LightMagenta}[3]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-e{bco.ResetAll} STR penalty for the logistic regression {bco.LightMagenta}[\"l1\"]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-E{bco.ResetAll} STR solver for the logistic regression {bco.LightMagenta}[\"liblinear\"]{bco.ResetAll}\n\n")
sys.stderr.write(f" {bco.LightBlue}-E{bco.ResetAll} STR solver for the logistic regression {bco.LightMagenta}[\"liblinear\"]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-N{bco.ResetAll} INT solver_iterations: increase this parameter if the output displays a ConvergenceWarning. (default=5000)\n\n")
sys.stderr.write(f"{bco.Cyan}Note:{bco.ResetAll} if -p is provided, then the alignment will be done at the level\nof the proteins and then converted to gene alignment (from -i input).\nThe order of the sequences in -i and -p should be the same.\n\n")
# ------------------------------------------------------------------------------
def print_menu_correct_seq():
Expand All @@ -166,7 +168,8 @@ def print_menu_train_genome():
sys.stderr.write(f" {bco.LightBlue}-C{bco.ResetAll} FILE stag database for the concatenated genes{bco.LightMagenta}[required]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-o{bco.ResetAll} FILE output file name (HDF5 format) {bco.LightMagenta}[required]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-t{bco.ResetAll} INT number of threads {bco.LightMagenta}[1]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-v{bco.ResetAll} INT verbose level: 1=error, 2=warning, 3=message, 4+=debugging {bco.LightMagenta}[3]{bco.ResetAll}\n\n")
# Each option is printed on its own line; the last one ends the menu with a
# blank line. The -v entry must end in "\n", otherwise the -N help text is
# glued onto the same output line.
sys.stderr.write(f" {bco.LightBlue}-v{bco.ResetAll} INT verbose level: 1=error, 2=warning, 3=message, 4+=debugging {bco.LightMagenta}[3]{bco.ResetAll}\n")
sys.stderr.write(f" {bco.LightBlue}-N{bco.ResetAll} INT solver_iterations: increase this parameter if the output displays a ConvergenceWarning. (default=5000)\n\n")
# ------------------------------------------------------------------------------
def print_menu_classify_genome():
sys.stderr.write("\n")
Expand Down Expand Up @@ -230,6 +233,7 @@ def main(argv=None):
parser.add_argument('-e', action="store", default="l1", dest='penalty_logistic', help='penalty for the logistic regression',choices=['l1','l2','none'])
parser.add_argument('-E', action="store", default="liblinear", dest='solver_logistic', help='solver for the logistic regression',choices=['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'])
parser.add_argument('-G', action="store", dest="marker_genes", default=None, help="Set of identified marker genes in lieu of a genomic sequence")
parser.add_argument('-N', action="store", dest="solver_iterations", default=5000, type=int, help="Increase this parameter if the output displays a ConvergenceWarning.")

parser.add_argument('--version', action='version', version='%(prog)s {0} on python {1}'.format(tool_version, sys.version.split()[0]))

Expand Down Expand Up @@ -323,7 +327,8 @@ def main(argv=None):
# call the function to create the database
create_db.create_db(args.aligned_sequences, args.taxonomy, args.verbose, args.output, args.use_cm_align,
args.template_al, args.intermediate_cross_val, args.protein_fasta_input,
args.penalty_logistic, args.solver_logistic, procs=args.threads)
args.penalty_logistic, args.solver_logistic, max_iter=args.solver_iterations,
procs=args.threads)

# --------------------------------------------------------------------------
# TRAIN routine
Expand Down Expand Up @@ -369,7 +374,8 @@ def main(argv=None):
# call the function to create the database
create_db.create_db(al_file.name, args.taxonomy, args.verbose, args.output, args.use_cm_align,
args.template_al, args.intermediate_cross_val, args.protein_fasta_input,
args.penalty_logistic, args.solver_logistic, procs=args.threads)
args.penalty_logistic, args.solver_logistic, max_iter=args.solver_iterations,
procs=args.threads)

# what to do with intermediate alignment -------------------------------
if not args.intermediate_al:
Expand Down Expand Up @@ -550,13 +556,15 @@ def main(argv=None):
else:
for f in os.listdir(args.dir_input):
f = os.path.join(args.dir_input, f)
try:
if os.path.isfile(f) and open(f).read(1)[0] == ">":
list_files.append(f)
except Exception as e:
if args.verbose > 1:
sys.stderr.write("[W::main] Warning: ")
sys.stderr.write("Cannot open file: {}\n".format(f))
# Keep only regular files whose first character is ">" (fasta header marker).
if os.path.isfile(f):
    try:
        with open(f) as _in:
            # read(1) returns "" for an empty file; compare the string
            # directly instead of indexing it, so an empty file is simply
            # skipped rather than being reported as unreadable.
            if _in.read(1) == ">":
                list_files.append(f)
    except (OSError, UnicodeError):
        # Best effort: a file that cannot be opened or decoded as text is
        # reported (at warning verbosity) and skipped.
        if args.verbose > 1:
            sys.stderr.write("[W::main] Warning: ")
            sys.stderr.write("Cannot open file: {}\n".format(f))

if not list_files:
handle_error("no fasta files found in the directory.", None)
Expand Down
2 changes: 1 addition & 1 deletion stag/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def align_generator(seq_file, protein_file, hmm_file, use_cmalign, n_threads, ve

if protein_file:
seq_stream = zip(read_fasta(parse_cmd.stdout, head_start=1),
read_fasta(open(seq_file), is_binary=False, head_start=1))
read_fasta(seq_file, is_binary=False, head_start=1))
else:
seq_stream = read_fasta(parse_cmd.stdout, head_start=1)

Expand Down
8 changes: 6 additions & 2 deletions stag/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,12 @@ def classify_seq(gene_id, test_seq, taxonomy, tax_function, classifiers, threads

prob_per_level = "/".join(prob_per_level)
perc_text = "/".join([str(p) for p in perc])
assigned_tax_text = ";".join(tax[0:(int(selected_level) + 1)])
tax_text = "/".join(tax)

# Normalise every taxonomy label to str. Labels that are bytes are decoded;
# anything else is kept (as text) instead of being filtered out, so the
# taxonomy is never silently truncated when the labels are already str.
# NOTE(review): presumably the bytes come from the HDF5 database reader --
# confirm against the loader.
tax_str = [t.decode() if isinstance(t, bytes) else str(t) for t in tax]

# Assigned taxonomy: levels 0..selected_level joined with ";".
assigned_tax_text = ";".join(tax_str[0:(int(selected_level) + 1)])
# Full predicted path: all levels joined with "/".
tax_text = "/".join(tax_str)

result = [gene_id, assigned_tax_text, tax_text, selected_level,
perc_text, prob_per_level, str(n_aligned_characters)]
Expand Down
152 changes: 80 additions & 72 deletions stag/classify_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,33 +58,37 @@ def run_prodigal(genome):
# we need two files, one for the proteins and one for the genes
genes = tempfile.NamedTemporaryFile(delete=False, mode="w")
proteins = tempfile.NamedTemporaryFile(delete=False, mode="w")
# prodigal command
prodigal_cmd = "prodigal -i {genome} -d {gene_file} -a {protein_file}".format(
genome=genome, gene_file=genes.name, protein_file=proteins.name
)
cmd = shlex.split(prodigal_cmd)
parse_cmd = subprocess.Popen(cmd, stdout=DEVNULL,stderr=subprocess.PIPE)
# we save stderr if necessary
all_stderr = ""
for line in parse_cmd.stderr:
line = line.decode('ascii')
all_stderr = all_stderr + line
return_code = parse_cmd.wait()
if return_code:
raise ValueError(f"[E::align] Error. prodigal failed\n\n{all_stderr}")

# we re-name the header of the fasta files ---------------------------------
# we expect to have the same number of genes and proteins, and also that the
def copy_fasta(fasta_file, seqid, is_binary=True, head_start=0):
with tempfile.NamedTemporaryFile(delete=False, mode="w") as fasta_out, open(fasta_file) as fasta_in:
for index, (sid, seq) in enumerate(read_fasta(fasta_in, is_binary=is_binary), start=1):
print(">{seqid}_{index}".format(**locals()), seq, sep="\n", file=fasta_out)
fasta_out.flush()
os.fsync(fasta_out.fileno())
return fasta_out.name, index

parsed_genes, gene_count = copy_fasta(genes.name, genome, is_binary=False)
parsed_proteins, protein_count = copy_fasta(proteins.name, genome, is_binary=False)

with genes, proteins:
# prodigal command
prodigal_cmd = "prodigal -i {genome} -d {gene_file} -a {protein_file}".format(
genome=genome, gene_file=genes.name, protein_file=proteins.name
)
cmd = shlex.split(prodigal_cmd)
parse_cmd = subprocess.Popen(cmd, stdout=DEVNULL,stderr=subprocess.PIPE)

with parse_cmd:
# we save stderr if necessary
all_stderr = ""
for line in parse_cmd.stderr:
line = line.decode('ascii')
all_stderr = all_stderr + line
return_code = parse_cmd.wait()
if return_code:
raise ValueError(f"[E::align] Error. prodigal failed\n\n{all_stderr}")

# we re-name the header of the fasta files ---------------------------------
# we expect to have the same number of genes and proteins, and also that the
def copy_fasta(fasta_file, seqid, is_binary=True, head_start=0):
    """Copy a fasta file to a new temporary file with renamed headers.

    Every record read from ``fasta_file`` is written out with the header
    ``>{seqid}_{n}``, where ``n`` is the 1-based position of the record.

    Args:
        fasta_file: path of the fasta file to read.
        seqid: prefix used for the new sequence identifiers.
        is_binary: forwarded to ``read_fasta``.
        head_start: accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(path, count)``: the name of the temporary file that was
        written (not deleted on close) and the number of records copied.
        ``count`` is 0 when the input contains no records.
    """
    # Initialise the counter so an input without records returns 0 instead of
    # failing with an unbound loop variable.
    count = 0
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as fasta_out, open(fasta_file) as fasta_in:
        for count, (_, seq) in enumerate(read_fasta(fasta_in, is_binary=is_binary), start=1):
            print(f">{seqid}_{count}", seq, sep="\n", file=fasta_out)
        # Force the data to disk before the file name is handed on.
        fasta_out.flush()
        os.fsync(fasta_out.fileno())
    return fasta_out.name, count

parsed_genes, gene_count = copy_fasta(genes.name, genome, is_binary=False)
parsed_proteins, protein_count = copy_fasta(proteins.name, genome, is_binary=False)

os.remove(genes.name)
os.remove(proteins.name)
Expand All @@ -103,36 +107,39 @@ def extract_gene_from_one_genome(file_to_align, hmm_file, gene_threshold,mg_name
# INFO: genes_path, proteins_path [where to save the result]
# we run hmmsearch
temp_hmm = tempfile.NamedTemporaryFile(delete=False, mode="w")
hmm_cmd = "hmmsearch --tblout "+temp_hmm.name+" "+hmm_file+" "+file_to_align

CMD = shlex.split(hmm_cmd)
hmm_CMD = subprocess.Popen(CMD, stdout=DEVNULL,stderr=subprocess.PIPE)
# we save stderr if necessary
all_stderr = ""
for line in hmm_CMD.stderr:
line = line.decode('ascii')
all_stderr = all_stderr + line
return_code = hmm_CMD.wait()
if return_code:
raise ValueError(f"[E::align] Error. hmmsearch failed\n\nMG: {mg_name}\nCALL: {hmm_cmd}\n\n{all_stderr}")

# in temp_hmm.name there is the result from hmm ----------------------------
# we select which genes/proteins we need to extract from the fasta files
# produced by prodigal
sel_genes = dict()
o = open(temp_hmm.name,"r")
for line in o:
if not line.startswith("#"):
vals = re.sub(" +"," ",line.rstrip()).split(" ")
gene_id = vals[0]
e_val = vals[4]
score = vals[5]
if float(score) > float(gene_threshold):
sel_genes[gene_id] = score
o.close()
with temp_hmm:

hmm_cmd = f"hmmsearch --tblout {temp_hmm.name} {hmm_file} {file_to_align}"
CMD = shlex.split(hmm_cmd)
hmm_CMD = subprocess.Popen(CMD, stdout=DEVNULL,stderr=subprocess.PIPE)

with hmm_CMD:
# we save stderr if necessary
all_stderr = ""
for line in hmm_CMD.stderr:
line = line.decode('ascii')
all_stderr = all_stderr + line
return_code = hmm_CMD.wait()
if return_code:
raise ValueError(f"[E::align] Error. hmmsearch failed\n\nMG: {mg_name}\nCALL: {hmm_cmd}\n\n{all_stderr}")

# in temp_hmm.name there is the result from hmm ----------------------------
# we select which genes/proteins we need to extract from the fasta files
# produced by prodigal
sel_genes = {}
with open(temp_hmm.name, "r") as _in:
for line in _in:
if not line.startswith("#"):
vals = re.sub(" +"," ",line.rstrip()).split(" ")
gene_id = vals[0]
e_val = vals[4]
score = vals[5]
if float(score) > float(gene_threshold):
sel_genes[gene_id] = score

# remove file with the result from the hmm
if os.path.isfile(temp_hmm.name): os.remove(temp_hmm.name)
if os.path.isfile(temp_hmm.name):
os.remove(temp_hmm.name)

return sel_genes

Expand All @@ -153,7 +160,7 @@ def extract_genes(mg_name, hmm_file, use_protein_file, genomes_pred, gene_thresh
else:
file_to_align = genomes_pred[g][0]
# call function that uses hmmsearch
all_genes_raw[g][mg_name] = extract_gene_from_one_genome(file_to_align, hmm_file, gene_threshold,mg_name)
all_genes_raw[g][mg_name] = extract_gene_from_one_genome(file_to_align, hmm_file, gene_threshold, mg_name)

def select_genes(all_genes_raw, keep_all_genes):
return_dict = dict()
Expand Down Expand Up @@ -261,7 +268,7 @@ def fetch_MGs(database_files, database_path, genomes_pred, keep_all_genes, gene_
# for each MG, we extract the hmm and if using proteins or not ---------
path_mg = os.path.join(database_path, mg)

with h5py.File(path_mg, 'r') as db_in, tempfile.NamedTemporaryFile(delete=False, mode="w") as hmm_file:
with h5py.File(path_mg, 'r') as db_in, tempfile.NamedTemporaryFile(delete=False, mode="wb") as hmm_file:
os.chmod(hmm_file.name, 0o644)
hmm_file.write(db_in['hmm_file'][0])
hmm_file.flush()
Expand Down Expand Up @@ -320,23 +327,24 @@ def annotate_MGs(MGS, database_files, database_base_path, dir_ali, procs=2):
if not os.path.isfile(db):
raise ValueError(f"Error: file for gene database {db} is missing")

pool = mp.Pool(processes=procs)

results = (
pool.apply_async(
classify,
args=(os.path.join(database_base_path,mg),),
kwds={"fasta_input": fna, "protein_fasta_input": faa,
"save_ali_to_file": os.path.join(dir_ali, mg),
"internal_call": True}
d = {}
with mp.Pool(processes=procs) as pool:

results = (
pool.apply_async(
classify,
args=(os.path.join(database_base_path,mg),),
kwds={"fasta_input": fna, "protein_fasta_input": faa,
"save_ali_to_file": os.path.join(dir_ali, mg),
"internal_call": True}
)
for mg, (fna, faa) in found_marker_genes.items()
)
for mg, (fna, faa) in found_marker_genes.items()
)

d = dict()
for p in results:
_, predictions = p.get()
d.update(predictions)
for p in results:
_, predictions = p.get()
d.update(predictions)

return d
#return dict(p.get()[1] for p in results)

Expand Down
Loading