Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/exabiome/nn/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from ..utils import check_argv, parse_logger
from .utils import process_gpus, process_model, process_output
from .loader import process_dataset, get_loader
from .loader import process_dataset, get_loader, DeepIndexDataModule
import glob
import argparse
import torch
Expand Down Expand Up @@ -124,8 +124,15 @@ def process_args(argv=None):
args.classify = True

# load the model and override batch size
model = process_model(args, inference=True)
hparams = torch.load(args.checkpoint, map_location=torch.device('cpu'))['hyper_parameters']
hparams = argparse.Namespace(**hparams)
hparams.checkpoint = args.checkpoint
model = process_model(hparams, inference=True)
model.set_inference(True)
model.freeze()

data_mod = DeepIndexDataModule(hparams)

if args.batch_size is not None:
model.hparams.batch_size = args.batch_size

Expand All @@ -142,16 +149,16 @@ def process_args(argv=None):
args.label_map = np.array([tid2idx[tid] for tid in rep_tid])
train_io.close()
elif args.loaders is None: # if an input file is not passed in, do all TVT data
args.loaders = {'train': model.train_dataloader(),
'validate': model.val_dataloader(),
'test': model.test_dataloader()}
args.difile = model.dataset.difile
args.loaders = {'train': data_mod.train_dataloader(),
'validate': data_mod.val_dataloader(),
'test': data_mod.test_dataloader()}
args.difile = data_mod.dataset.difile
args.label_map = None

# return the model, any arguments, and Lighting Trainer args just in case
# we want to use them down the line when we figure out how to use Lightning for
# inference
ret = [model, args, targs]
ret = [model, args, targs, data_mod]

return tuple(ret)

Expand All @@ -162,7 +169,7 @@ def run_inference(argv=None):
argv: a command-line string or argparse.Namespace object to use for running inference
If none are given, read from command-line i.e. like running argparse.ArgumentParser.parse_args
"""
model, args, addl_targs = process_args(argv=argv)
model, args, addl_targs, data_mod = process_args(argv=argv)
import h5py
import numpy as np
import os
Expand Down
3 changes: 2 additions & 1 deletion src/exabiome/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def process_model(args, inference=False):

if args.checkpoint is not None:
try:
model = model_cls.load_from_checkpoint(args.checkpoint, hparams=args)
hparams = torch.load(args.checkpoint, map_location=torch.device('cpu'))['hyper_parameters']
model = model_cls.load_from_checkpoint(args.checkpoint, hparams=hparams)
except RuntimeError as e:
if 'Missing key(s)' in e.args[0]:
raise RuntimeError(f'Unable to load checkpoint. Make sure {args.checkpoint} is a checkpoint for {args.model}') from e
Expand Down