diff --git a/src/exabiome/gtdb/prepare_data.py b/src/exabiome/gtdb/prepare_data.py index d3db8b2..bc993b8 100644 --- a/src/exabiome/gtdb/prepare_data.py +++ b/src/exabiome/gtdb/prepare_data.py @@ -87,11 +87,13 @@ def prepare_data(argv=None): parser = argparse.ArgumentParser() parser.add_argument('accessions', type=str, help='file of the NCBI accessions of the genomes to convert') + parser.add_argument('fadir', type=str, help='directory with NCBI sequence files') parser.add_argument('metadata', type=str, help='metadata file from GTDB') parser.add_argument('tree', type=str, help='the distances file') parser.add_argument('out', type=str, help='output HDF5') grp = parser.add_mutually_exclusive_group() + parser.add_argument('--locus_tags', type=str, help='file of the NCBI locus tags of the genes to convert', default=None) parser.add_argument('-e', '--emb', type=str, help='embedding file', default=None) grp.add_argument('-P', '--protein', action='store_true', default=False, help='get paths for protein files') grp.add_argument('-C', '--cds', action='store_true', default=False, help='get paths for CDS files') @@ -118,6 +120,11 @@ def prepare_data(argv=None): with open(args.accessions, 'r') as f: taxa_ids = [l[:-1] for l in f.readlines()] + # read locus tags + logger.info('reading locus tags %s' % args.locus_tags) + with open(args.locus_tags, 'r') as f: + locus_ids = set([l[:-1] for l in f.readlines()]) + # get paths to Fasta Files fa_path_func = get_genomic_path if args.cds: @@ -214,7 +221,7 @@ def func(row): SeqTable = DNATable if args.cds: logger.info("reading and writing CDS sequences") - seqit = DNAVocabGeneIterator(fapaths, logger=logger, min_seq_len=args.min_len) + seqit = DNAVocabGeneIterator(fapaths, locus_ids=locus_ids, logger=logger, min_seq_len=args.min_len) else: seqit = DNAVocabIterator(fapaths, logger=logger, min_seq_len=args.min_len) else: diff --git a/src/exabiome/sequence/convert.py b/src/exabiome/sequence/convert.py index 56463df..09c936d 100644 --- a/src/exabiome/sequence/convert.py +++ b/src/exabiome/sequence/convert.py @@ -362,7 +362,9 @@ class DNAVocabIterator(VocabIterator): chars, basemap = _get_DNA_map() class DNAVocabGeneIterator(VocabIterator): - + def __init__(self, paths, locus_ids=None, logger=None, min_seq_len=None): + super().__init__(paths, logger=logger, min_seq_len=min_seq_len) + self.locus_ids = locus_ids chars, basemap = _get_DNA_map() def _read_seq(self, path): @@ -370,7 +372,8 @@ def _read_seq(self, path): kwargs = {'format': 'fasta', 'constructor': self.skbio_cls, 'validate': False} for seq in skbio.io.read(path, **kwargs): ltag = re.findall("\[locus_tag=([^[\]]*)\]", seq.metadata['description'])[0] - yield seq, ltag + if ltag in self.locus_ids: + yield seq, ltag def _get_AA_map(): chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'