diff --git a/src/deep_taxon/gtdb/make_fof.py b/src/deep_taxon/gtdb/make_fof.py
index d214a7f..031a84d 100644
--- a/src/deep_taxon/gtdb/make_fof.py
+++ b/src/deep_taxon/gtdb/make_fof.py
@@ -1,4 +1,6 @@
 import os
+import deep_taxon.sequence.dna_table
+from hdmf.common import get_hdf5io

 def get_taxa_id(path):
     c, n = os.path.basename(path).split('_')[0:2]
@@ -24,6 +26,8 @@ def make_fof(argv=None):
     parser.add_argument('accessions', type=str, help='A file containing accessions')
     parser.add_argument('-t', '--tree', action='store_true', default=False, help='accessions are from a tree in Newick format')
     parser.add_argument('-f', '--hdmf_file', action='store_true', default=False, help='get accessions from a DeepIndex HDMF file')
+    parser.add_argument('-A', '--append', action='store_true', default=False, help='include the accession with each path')
+    parser.add_argument('-T', '--taxonomy', action='store_true', default=False, help='include the taxonomy with each path (requires -f)')
     grp = parser.add_mutually_exclusive_group()
     grp.add_argument('-P', '--protein', action='store_true', default=False, help='get paths for protein files')
     grp.add_argument('-C', '--cds', action='store_true', default=False, help='get paths for CDS files')
@@ -75,10 +79,23 @@ def make_fof(argv=None):
     else:
         accessions = None

+    taxonomy = False
     if args.hdmf_file:
-        with h5py.File(args.accessions, 'r') as f:
-            accessions = f['genome_table']['taxon_id'][:]
+        with get_hdf5io(args.accessions, 'r') as io:  # h5py.File(args.accessions, 'r') as f:
+            difile = io.read()
+            for c in difile.taxa_table.columns:
+                c.transform(lambda x: x[:])
+            for c in difile.genome_table.columns:
+                c.transform(lambda x: x[:])
+            tdf = difile.taxa_table.to_dataframe()
+            gdf = difile.genome_table.to_dataframe(index=True)
+            tdf = tdf.iloc[gdf['rep_idx']]
+            accessions = gdf['taxon_id']
+    else:
+        if args.taxonomy:
+            print("-T is only valid with -f", file=sys.stderr)
+            exit()
     if os.path.exists(args.accessions):
         with open(args.accessions, 'r') as f:
             accessions = f.readlines()
@@ -90,8 +107,14 @@ def make_fof(argv=None):
         func = get_fna_path
     elif args.protein:
         func = get_faa_path
-    for line in accessions:
-        print(func(line.strip(), args.fadir), file=sys.stdout)
+    for idx, line in enumerate(accessions.astype('U')):
+        path = func(line.strip(), args.fadir)
+        if args.taxonomy:
+            tax = "\t".join(tdf.iloc[idx])
+            path = f'{tax}\t{path}'
+        elif args.append:
+            path = f'{line}\t{path}'
+        print(path, file=sys.stdout)


 if __name__ == '__main__':
diff --git a/src/deep_taxon/nn/infer.py b/src/deep_taxon/nn/infer.py
index 9cd80ed..607e668 100644
--- a/src/deep_taxon/nn/infer.py
+++ b/src/deep_taxon/nn/infer.py
@@ -1,4 +1,5 @@
 import sys
+import warnings
 import numpy as np
 import os
 from time import time
@@ -10,6 +11,7 @@ from .lsf_environment import LSFEnvironment
 from pytorch_lightning.plugins.environments import SLURMEnvironment

+from hdmf.common import get_hdf5io

 import glob
@@ -36,7 +38,7 @@ def parse_args(*addl_args, argv=None):
     parser.add_argument('checkpoint', type=str, help='the checkpoint file to use for running inference')
     parser.add_argument('-o', '--output', type=str, help='the file to save outputs to', default=None)
     parser.add_argument('-f', '--resnet_features', action='store_true', help='drop classifier from ResNet model before inference', default=False)
-    parser.add_argument('-F', '--features', action='store_true', help='outputs are features i.e. do not softmax and compute predictions', default=False)
+    parser.add_argument('-F', '--features', action='store_true', help='do not softmax model output', default=False)
     parser.add_argument('-g', '--gpus', nargs='?', const=True, default=False, help='use GPU')
     parser.add_argument('-d', '--debug', action='store_true', default=False, help='run in debug mode i.e. only run two batches')
@@ -52,6 +54,7 @@ def parse_args(*addl_args, argv=None):
     parser.add_argument('-p', '--maxprob', metavar='TOPN', nargs='?', const=1, default=0, type=int, help='store the top TOPN probablities of each output. By default, TOPN=1')
     parser.add_argument('-c', '--save_chunks', action='store_true', help='do store network outputs for each chunk', default=False)
+    parser.add_argument('-O', '--save_outputs', action='store_true', help='store full network outputs for each sequence. otherwise only store top-1 probability', default=False)

     env_grp = parser.add_argument_group("Resource Manager").add_mutually_exclusive_group()
     env_grp.add_argument("--lsf", default=False, action='store_true', help='running in an LSF environment')
@@ -149,12 +152,19 @@ def process_args(args, comm=None):
             model = ResNetFeatures(model)
             args.features = True

+    io = get_hdf5io(args.input, 'r')
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        difile = io.read()
+
+    difile.set_label_key(args.tgt_tax_lvl)
+
     if size > 1:
-        dataset = LazySeqDataset(path=args.input, hparams=argparse.Namespace(**model.hparams), keep_open=True, comm=comm, size=size, rank=rank)
+        dataset = LazySeqDataset(difile=difile, hparams=argparse.Namespace(**model.hparams), keep_open=True, comm=comm, size=size, rank=rank)
     else:
-        dataset = LazySeqDataset(path=args.input, hparams=argparse.Namespace(**model.hparams), keep_open=True)
+        dataset = LazySeqDataset(difile=difile, hparams=argparse.Namespace(**model.hparams), keep_open=True)

-    tot_bases = dataset.orig_difile.get_seq_lengths().sum()
+    tot_bases = dataset.difile.get_seq_lengths().sum()
     args.logger.info(f'rank {rank} - processing {tot_bases} bases across {len(dataset)} samples')

     tmp_dset = dataset
@@ -163,7 +173,7 @@ def process_args(args, comm=None):
     if args.num_workers > 0:
         kwargs['num_workers'] = args.num_workers
         kwargs['multiprocessing_context'] = 'spawn'
-        kwargs['worker_init_fn'] = dataset.worker_init
+        #kwargs['worker_init_fn'] = dataset.worker_init
         kwargs['persistent_workers'] = True

     loader = get_loader(tmp_dset, inference=True, **kwargs)
@@ -212,11 +222,11 @@ def run_inference(argv=None):
         parallel_chunked_inf_summ(model, dataset, loader, args, f_kwargs)


-def parallel_chunked_inf_summ(model, dataset, loader, args, fkwargs):
+def old_parallel_chunked_inf_summ(model, dataset, loader, args, fkwargs):

-    n_samples = len(dataset.orig_difile.seq_table)
-    all_seq_ids = dataset.orig_difile.get_sequence_subset()
-    seq_lengths = dataset.orig_difile.get_seq_lengths()
+    n_samples = len(dataset.difile.seq_table)
+    all_seq_ids = dataset.difile.get_sequence_subset()
+    seq_lengths = dataset.difile.get_seq_lengths()

     f = h5py.File(args.output, 'w', **fkwargs)
     outputs_dset = None
@@ -394,6 +404,126 @@ def cat(indices, outputs, labels, orig_lens, seq_ids):
     return ret


+def parallel_chunked_inf_summ(model, dataset, loader, args, fkwargs):
+
+
+    from hdmf_ml import ResultsTable
+    from hdmf.backends.hdf5 import H5DataIO
+
+    results = ResultsTable()
+
+    n_samples = dataset.difile.n_seqs
+
+
+    if args.save_chunks:
+        print("I will not save chunks, sorry")
+
+    prob_dset = results.add_topk_probabilities(data=H5DataIO(shape=(n_samples, args.maxprob), dtype=float, fillvalue=0.0))
+    label_dset = results.add_topk_classes(data=H5DataIO(shape=(n_samples, args.maxprob), dtype=int, fillvalue=-1))
+    target_dset = results.add_predicted_class(data=H5DataIO(shape=(n_samples,), dtype=int, fillvalue=-1))
+
+    outputs_dset = None
+    if args.save_outputs:
+        outputs_dset = results.add_predicted_probability(data=H5DataIO(shape=(n_samples, args.maxprob), dtype=float, fillvalue=0.0))
+
+    io = get_hdf5io(args.output, 'w')
+    io.write(results)
+
+    check = lambda dset: dset.data.dataset if dset is not None else None
+
+    prob_dset = check(prob_dset)
+    label_dset = check(label_dset)
+    target_dset = check(target_dset)
+    outputs_dset = check(outputs_dset)
+
+    # to-write queues - we use these so we're not doing I/O at every iteration
+    outputs_q = list()
+    targets_q = list()
+    counts_q = list()
+    seqs_q = list()
+
+    # number of sequences to buffer before flushing the to-write queues
+    if not hasattr(args, 'n_seqs'):
+        args.n_seqs = 500
+
+    # send model to GPU
+    model.to(args.device)
+
+    uniq_labels = set()
+    for idx, _outputs, _labels, _orig_lens, _seq_ids in get_outputs(model, loader, args.device, debug=args.debug, chunks=args.n_batches, prog_bar=dataset.rank==0):
+
+        seqs, counts = np.unique(_seq_ids, return_counts=True)
+        outputs_sum = list()
+        labels = list()
+        for i in seqs:
+            mask = _seq_ids == i
+            outputs_sum.append(_outputs[mask].sum(axis=0))
+            labels.append(_labels[mask][0])
+        uniq_labels.update(_labels)
+
+        # Add the first sum of this iteration to the last sum of the
+        # previous iteration if they belong to the same sequence
+        if len(seqs_q) > 0 and seqs_q[-1] == seqs[0]:
+            outputs_q[-1] += outputs_sum[0]
+            counts_q[-1] += counts[0]
+            # drop the first sum so we don't end up with duplicates
+            seqs = seqs[1:]
+            counts = counts[1:]
+            outputs_sum = outputs_sum[1:]
+            labels = labels[1:]
+
+        outputs_q.extend(outputs_sum)
+        seqs_q.extend(seqs)
+        counts_q.extend(counts)
+        targets_q.extend(labels)
+
+        # write when we get above a certain number of sequences
+        if len(outputs_q) > args.n_seqs:
+            idx = seqs_q[:-1]
+            args.logger.debug(f"rank {dataset.rank} - saving these sequences {idx}")
+            # compute mean from sums
+            for i in range(len(idx)):
+                outputs_q[i] /= counts_q[i]
+
+            if outputs_dset is not None:
+                outputs_dset[idx] = outputs_q[:-1]
+            target_dset[idx] = targets_q[:-1]
+
+            prob_dset[idx], label_dset[idx] = get_topk(args.maxprob, outputs_q[:-1])
+
+            outputs_q = outputs_q[-1:]
+            seqs_q = seqs_q[-1:]
+            counts_q = counts_q[-1:]
+            targets_q = targets_q[-1:]
+
+    args.logger.debug(f"rank {dataset.rank} - saving these sequences {seqs_q}")
+    args.logger.debug(f"rank {dataset.rank} - came across these labels {list(sorted(uniq_labels))}")
+
+    # clean up what's left in the to-write queue
+    for i in range(len(seqs_q)):
+        outputs_q[i] /= counts_q[i]
+
+    if outputs_dset is not None:
+        outputs_dset[seqs_q] = outputs_q
+
+    target_dset[seqs_q] = targets_q
+    prob_dset[seqs_q], label_dset[seqs_q] = get_topk(args.maxprob, outputs_q)
+
+    args.logger.info(f'rank {dataset.rank} - closing {args.output}')
+
+    io.close()
+
+
+def get_topk(topk, outputs):
+    k = outputs[0].shape[0] - topk
+    maxprobs = list()
+    maxclses = list()
+    for i in range(len(outputs)):
+        topk_idx = np.argpartition(outputs[i], k)[k:]
+        topk_idx = topk_idx[np.argsort(outputs[i][topk_idx])][::-1]
+        maxprobs.append(outputs[i][topk_idx])
+        maxclses.append(topk_idx)
+    return maxprobs, maxclses
+
 from . import models # noqa: E402


 if __name__ == '__main__':
diff --git a/src/deep_taxon/nn/loader.py b/src/deep_taxon/nn/loader.py
index a25757d..d4df21b 100644
--- a/src/deep_taxon/nn/loader.py
+++ b/src/deep_taxon/nn/loader.py
@@ -138,8 +138,7 @@ def dataset_stats(argv=None):
     if comm is not None:
         kwargs['comm'] = comm
     before = time()
-    dataset = LazySeqDataset(**kwargs)
-    dataset.load(sequence=False)
+    dataset = LazySeqDataset(load=False, **kwargs)
     after = time()

     io.close()
diff --git a/src/deep_taxon/sequence/dna_table.py b/src/deep_taxon/sequence/dna_table.py
index bcb870d..33d38d8 100644
--- a/src/deep_taxon/sequence/dna_table.py
+++ b/src/deep_taxon/sequence/dna_table.py
@@ -779,7 +779,7 @@ def get_label_classes(self):
         return self._classes

     def get_seq_lengths(self):
-        self.lengths.copy()
+        return self.lengths.copy()

     def get_counts(self, orig=False):
         """