Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions src/deep_taxon/gtdb/make_fof.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import deep_taxon.sequence.dna_table
from hdmf.common import get_hdf5io

def get_taxa_id(path):
c, n = os.path.basename(path).split('_')[0:2]
Expand All @@ -24,6 +26,8 @@ def make_fof(argv=None):
parser.add_argument('accessions', type=str, help='A file containing accessions')
parser.add_argument('-t', '--tree', action='store_true', default=False, help='accessions are from a tree in Newick format')
parser.add_argument('-f', '--hdmf_file', action='store_true', default=False, help='get accessions from a DeepIndex HDMF file')
parser.add_argument('-A', '--append', action='store_true', default=False, help='include accessions')
parser.add_argument('-T', '--taxonomy', action='store_true', default=False, help='include taxonomy')
grp = parser.add_mutually_exclusive_group()
grp.add_argument('-P', '--protein', action='store_true', default=False, help='get paths for protein files')
grp.add_argument('-C', '--cds', action='store_true', default=False, help='get paths for CDS files')
Expand Down Expand Up @@ -75,10 +79,23 @@ def make_fof(argv=None):

else:
accessions = None
taxonomy = False
if args.hdmf_file:
with h5py.File(args.accessions, 'r') as f:
accessions = f['genome_table']['taxon_id'][:]
with get_hdf5io(args.accessions, 'r') as io: #h5py.File(args.accessions, 'r') as f:
difile = io.read()
for c in difile.taxa_table.columns:
c.transform(lambda x: x[:])
for c in difile.genome_table.columns:
c.transform(lambda x: x[:])
tdf = difile.taxa_table.to_dataframe()
gdf = difile.genome_table.to_dataframe(index=True)
tdf.iloc[gdf['rep_idx']]
accessions = gdf['taxon_id']

else:
if args.taxonomy:
print("-T is only valid with -f", file=sys.stderr)
exit()
if os.path.exists(args.accessions):
with open(args.accessions, 'r') as f:
accessions = f.readlines()
Expand All @@ -90,8 +107,14 @@ def make_fof(argv=None):
func = get_fna_path
elif args.protein:
func = get_faa_path
for line in accessions:
print(func(line.strip(), args.fadir), file=sys.stdout)
for idx, line in enumerate(accessions.astype('U')):
path = func(line.strip(), args.fadir)
if args.taxonomy:
tax = "\t".join(tdf.iloc[idx])
path = f'{tax}\t{path}'
elif args.append:
path = f'{line}\t{path}'
print(path, file=sys.stdout)


if __name__ == '__main__':
Expand Down
148 changes: 139 additions & 9 deletions src/deep_taxon/nn/infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import warnings
import numpy as np
import os
from time import time
Expand All @@ -10,6 +11,7 @@

from .lsf_environment import LSFEnvironment
from pytorch_lightning.plugins.environments import SLURMEnvironment
from hdmf.common import get_hdf5io


import glob
Expand All @@ -36,7 +38,7 @@ def parse_args(*addl_args, argv=None):
parser.add_argument('checkpoint', type=str, help='the checkpoint file to use for running inference')
parser.add_argument('-o', '--output', type=str, help='the file to save outputs to', default=None)
parser.add_argument('-f', '--resnet_features', action='store_true', help='drop classifier from ResNet model before inference', default=False)
parser.add_argument('-F', '--features', action='store_true', help='outputs are features i.e. do not softmax and compute predictions', default=False)
parser.add_argument('-F', '--features', action='store_true', help='do not softmax model output', default=False)
parser.add_argument('-g', '--gpus', nargs='?', const=True, default=False, help='use GPU')
parser.add_argument('-d', '--debug', action='store_true', default=False,
help='run in debug mode i.e. only run two batches')
Expand All @@ -52,6 +54,7 @@ def parse_args(*addl_args, argv=None):
parser.add_argument('-p', '--maxprob', metavar='TOPN', nargs='?', const=1, default=0, type=int,
help='store the top TOPN probablities of each output. By default, TOPN=1')
parser.add_argument('-c', '--save_chunks', action='store_true', help='do store network outputs for each chunk', default=False)
parser.add_argument('-O', '--save_outputs', action='store_true', help='store full network outputs for each sequence. otherwise only store top-1 probability', default=False)

env_grp = parser.add_argument_group("Resource Manager").add_mutually_exclusive_group()
env_grp.add_argument("--lsf", default=False, action='store_true', help='running in an LSF environment')
Expand Down Expand Up @@ -149,12 +152,19 @@ def process_args(args, comm=None):
model = ResNetFeatures(model)
args.features = True

io = get_hdf5io(args.input, 'r')
with warnings.catch_warnings():
warnings.simplefilter("ignore")
difile = io.read()

difile.set_label_key(args.tgt_tax_lvl)

if size > 1:
dataset = LazySeqDataset(path=args.input, hparams=argparse.Namespace(**model.hparams), keep_open=True, comm=comm, size=size, rank=rank)
dataset = LazySeqDataset(difile=difile, hparams=argparse.Namespace(**model.hparams), keep_open=True, comm=comm, size=size, rank=rank)
else:
dataset = LazySeqDataset(path=args.input, hparams=argparse.Namespace(**model.hparams), keep_open=True)
dataset = LazySeqDataset(difile=difile, hparams=argparse.Namespace(**model.hparams), keep_open=True)

tot_bases = dataset.orig_difile.get_seq_lengths().sum()
tot_bases = dataset.difile.get_seq_lengths().sum()
args.logger.info(f'rank {rank} - processing {tot_bases} bases across {len(dataset)} samples')

tmp_dset = dataset
Expand All @@ -163,7 +173,7 @@ def process_args(args, comm=None):
if args.num_workers > 0:
kwargs['num_workers'] = args.num_workers
kwargs['multiprocessing_context'] = 'spawn'
kwargs['worker_init_fn'] = dataset.worker_init
#kwargs['worker_init_fn'] = dataset.worker_init
kwargs['persistent_workers'] = True
loader = get_loader(tmp_dset, inference=True, **kwargs)

Expand Down Expand Up @@ -212,11 +222,11 @@ def run_inference(argv=None):
parallel_chunked_inf_summ(model, dataset, loader, args, f_kwargs)


def parallel_chunked_inf_summ(model, dataset, loader, args, fkwargs):
def old_parallel_chunked_inf_summ(model, dataset, loader, args, fkwargs):

n_samples = len(dataset.orig_difile.seq_table)
all_seq_ids = dataset.orig_difile.get_sequence_subset()
seq_lengths = dataset.orig_difile.get_seq_lengths()
n_samples = len(dataset.difile.seq_table)
all_seq_ids = dataset.difile.get_sequence_subset()
seq_lengths = dataset.difile.get_seq_lengths()

f = h5py.File(args.output, 'w', **fkwargs)
outputs_dset = None
Expand Down Expand Up @@ -394,6 +404,126 @@ def cat(indices, outputs, labels, orig_lens, seq_ids):
return ret


def parallel_chunked_inf_summ(model, dataset, loader, args, fkwargs):


from hdmf_ml import ResultsTable
from hdmf.backends.hdf5 import H5DataIO

results = ResultsTable()

n_samples = dataset.difile.n_seqs


if args.save_chunks:
print("I will not save chunks, sorry")

prob_dset = results.add_topk_probabilities(data=H5DataIO(shape=(n_samples, args.maxprob), dtype=float, fillvalue=0.0))
label_dset = results.add_topk_classes(data=H5DataIO(shape=(n_samples, args.maxprob), dtype=int, fillvalue=-1))
target_dset = results.add_predicted_class(data=H5DataIO(shape=(n_samples,), dtype=int, fillvalue=-1))

outputs_dset = None
if args.save_outputs:
outputs_dset = results.add_predicted_probability(data=H5DataIO(shape=(n_samples, args.maxprob), dtype=float, fillvalue=0.0))

io = get_hdf5io(args.output, 'w')
io.write(results)

check = lambda dset: dset.data.dataset if dset is not None else None

prob_dset = check(prob_dset)
label_dset = check(label_dset)
target_dset = check(target_dset)

# to-write queues - we use those so we're not doing I/O at every iteration
outputs_q = list()
targets_q = list()
counts_q = list()
seqs_q = list()

# write what's in the to-write queues
if not hasattr(args, 'n_seqs'):
args.n_seqs = 500

# send model to GPU
model.to(args.device)

uniq_labels = set()
for idx, _outputs, _labels, _orig_lens, _seq_ids in get_outputs(model, loader, args.device, debug=args.debug, chunks=args.n_batches, prog_bar=dataset.rank==0):

seqs, counts = np.unique(_seq_ids, return_counts=True)
outputs_sum = list()
labels = list()
for i in seqs:
mask = _seq_ids == i
outputs_sum.append(_outputs[mask].sum(axis=0))
labels.append(_labels[mask][0])
uniq_labels.update(_labels)

# Add the first sum of this iteration to the last sum of the
# previous iteration if they belong to the same sequence
if len(seqs_q) > 0 and seqs_q[-1] == seqs[0]:
outputs_q[-1] += outputs_sum[0]
counts_q[-1] += counts[0]
# drop the first sum so we don't end up with duplicates
seqs = seqs[1:]
counts = counts[1:]
outputs_sum = outputs_sum[1:]
labels = labels[1:]

outputs_q.extend(outputs_sum)
seqs_q.extend(seqs)
counts_q.extend(counts)
targets_q.extend(labels)

# write when we get above a certain number of sequences
if len(outputs_q) > args.n_seqs:
idx = seqs_q[:-1]
args.logger.debug(f"rank {dataset.rank} - saving these sequences {idx}")
# compute mean from sums
for i in range(len(idx)):
outputs_q[i] /= counts_q[i]

if outputs_dset is not None:
outputs_dset[idx] = outputs_q[:-1]
target_dset[idx] = targets_q[:-1]

prob_dset[idx], label_dset[idx] = get_topk(args.maxprob, outputs_q[:-1])

outputs_q = outputs_q[-1:]
seqs_q = seqs_q[-1:]
counts_q = counts_q[-1:]
targets_q = targets_q[-1:]

args.logger.debug(f"rank {dataset.rank} - saving these sequences {seqs_q}")
args.logger.debug(f"rank {dataset.rank} - came across these labels {list(sorted(uniq_labels))}")

# clean up what's left in the to-write queue
for i in range(len(seqs_q)):
outputs_q[i] /= counts_q[i]

if outputs_dset is not None:
outputs_dset[seqs_q] = outputs_q

target_dset[seqs_q] = targets_q
prob_dset[seqs_q], label_dset[seqs_q] = get_topk(args.maxprob, outputs_q)

args.logger.info(f'rank {dataset.rank} - closing {args.output}')

io.close()


def get_topk(topk, outputs):
k = outputs[0].shape[0] - topk
maxprobs = list()
maxclses = list()
for i in range(len(outputs)):
topk_idx = np.argpartition(outputs[i], k)[k:]
topk_idx = topk_idx[np.argsort(outputs[i][topk_idx])][::-1]
maxprobs.append(outputs[i][topk_idx])
maxclses.append(topk_idx)
return maxprobs, maxclses

from . import models # noqa: E402

if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions src/deep_taxon/nn/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def dataset_stats(argv=None):
if comm is not None:
kwargs['comm'] = comm
before = time()
dataset = LazySeqDataset(**kwargs)
dataset.load(sequence=False)
dataset = LazySeqDataset(load=False, **kwargs)
after = time()

io.close()
Expand Down
2 changes: 1 addition & 1 deletion src/deep_taxon/sequence/dna_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def get_label_classes(self):
return self._classes

def get_seq_lengths(self):
self.lengths.copy()
return self.lengths.copy()

def get_counts(self, orig=False):
"""
Expand Down