Skip to content
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
git+https://github.com/hdmf-dev/hdmf.git@0efea00ff1d10c4021c8145e917d92566f7edb4c
seaborn==0.11.0
git+https://github.com/ajtritt/pytorch-lightning.git@99d2503373fe1b966cf7014c4ce7e7183766d48a
torch==1.6.0
git+https://github.com/ajtritt/pytorch-lightning.git@fb30942d2c47a95531e063ed35a22f8fba25be12
48 changes: 48 additions & 0 deletions src/exabiome/nn/loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pytorch_lightning as pl
import torch.nn.functional as F
import torch
Expand Down Expand Up @@ -94,6 +95,9 @@ def dataset_stats(argv=None):


def read_dataset(path):
for root, dirs, files in os.walk("/mnt/bb/ajtritt/"):
for filename in files:
print(rank, '-', filename)
hdmfio = get_hdf5io(path, 'r')
difile = hdmfio.read()
dataset = SeqDataset(difile)
Expand Down Expand Up @@ -392,6 +396,49 @@ def get_loader(dataset, distances=False, **kwargs):
return DataLoader(dataset, collate_fn=collater, **kwargs)


<<<<<<< HEAD
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This leftover merge-conflict block (the `<<<<<<< HEAD` marker and the duplicated `DIDataModule` class under it) needs to be removed before merging.

class DIDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping a DeepIndex dataset.

    Lazily loads the dataset on the first dataloader request and serves
    the train/validation/test DataLoaders built by ``train_test_loaders``.
    """

    def __init__(self, hparams, inference=False):
        """
        Args:
            hparams:   run hyperparameters; must provide ``load``, ``seed``,
                       ``batch_size``, ``manifold``, ``downsample`` and
                       ``loader_kwargs`` (presumably an argparse.Namespace —
                       TODO confirm against caller)
            inference: if True, build loaders for inference — no distance
                       targets and no worker processes
        """
        super().__init__()  # LightningDataModule requires base-class init
        self.hparams = hparams
        self.inference = inference
        # populated lazily by _check_loaders()
        self.loaders = None
        self.dataset = None

    def train_dataloader(self):
        self._check_loaders()
        return self.loaders['train']

    def val_dataloader(self):
        self._check_loaders()
        return self.loaders['validate']

    def test_dataloader(self):
        self._check_loaders()
        return self.loaders['test']

    def _check_loaders(self):
        """Load the dataset and build loaders if not already done.

        Fixes vs. the original: (1) it re-read the dataset on EVERY call,
        despite the docstring promising load-once; (2) it read
        ``self._inference``, which was never set (``__init__`` stores
        ``self.inference``), so every call raised AttributeError.
        """
        if self.loaders is not None:
            return
        dataset, io = process_dataset(self.hparams, inference=self.inference)
        self.io = io  # keep the handle so the backing file can be closed later
        if self.hparams.load:
            dataset.load()

        kwargs = dict(random_state=self.hparams.seed,
                      batch_size=self.hparams.batch_size,
                      distances=self.hparams.manifold,
                      downsample=self.hparams.downsample)
        kwargs.update(self.hparams.loader_kwargs)
        if self.inference:
            # inference never needs distance targets or worker processes
            kwargs['distances'] = False
            kwargs.pop('num_workers', None)
            kwargs.pop('multiprocessing_context', None)
        tr, te, va = train_test_loaders(dataset, **kwargs)
        self.loaders = {'train': tr, 'test': te, 'validate': va}
        self.dataset = dataset

=======
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This merge-conflict marker line (`=======`, along with the matching `>>>>>>> master` below) needs to be removed before merging.

class DeepIndexDataModule(pl.LightningDataModule):

def __init__(self, hparams, inference=False):
Expand Down Expand Up @@ -419,3 +466,4 @@ def val_dataloader(self):

def test_dataloader(self):
return self.loaders['test']
>>>>>>> master
190 changes: 116 additions & 74 deletions src/exabiome/nn/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ def read_outputs(path):
if 'viz_emb' in f:
ret['viz_emb'] = f['viz_emb'][:]
ret['labels'] = f['labels'][:]

# we won't have these three if we are looking
# at non-representatives
if 'train' in f:
ret['train_mask'] = f['train'][:]
if 'test' in f:
ret['test_mask'] = f['test'][:]
ret['outputs'] = f['outputs'][:]
if 'validate' in f:
ret['validate_mask'] = f['validate'][:]

ret['orig_lens'] = f['orig_lens'][:]
if 'seq_ids' in f:
ret['seq_ids'] = f['seq_ids'][:]
Expand Down Expand Up @@ -66,73 +70,78 @@ def plot_results(path, tvt=True, pred=True, fig_height=7, logger=None, name=None
labels = path['labels']
outputs = path['outputs']

viz_emb = None
if 'viz_emb' in path:
logger.info('found viz_emb')
viz_emb = path['viz_emb']
# else:
# logger.info('calculating UMAP embeddings for visualization')
# from umap import UMAP
# umap = UMAP(n_components=2)
# viz_emb = umap.fit_transform(outputs)
else:
logger.info('calculating UMAP embeddings for visualization')
from umap import UMAP
umap = UMAP(n_components=2)
viz_emb = umap.fit_transform(outputs)
n_plots = 1

color_labels = getattr(pred, 'classes_', None)
if color_labels is None:
color_labels = labels
class_pal = get_color_markers(len(np.unique(color_labels)))
colors = np.array([class_pal[i] for i in color_labels])

# set up figure
fig_height = 7
plt.figure(figsize=(n_plots*fig_height, fig_height))

logger.info('plotting embeddings with species labels')
# plot embeddings
ax = plt.subplot(1, n_plots, plot_count)
plot_seq_emb(viz_emb, labels, ax, pal=class_pal)
if name is not None:
plt.title(name)
plot_count += 1

# plot train/validation/testing data
train_mask = None
test_mask = None
validate_mask = None
if tvt:
logger.info('plotting embeddings train/validation/test labels')
train_mask = path['train_mask']
test_mask = path['test_mask']
validate_mask = path['validate_mask']
pal = ['gray', 'red', 'yellow']
plt.subplot(1, n_plots, plot_count)
dsubs = ['train', 'validation', 'test'] # data subsets
dsub_handles = list()
for (mask, dsub, col) in zip([train_mask, validate_mask, test_mask], dsubs, pal):
plt.scatter(viz_emb[mask, 0], viz_emb[mask, 1], s=0.1, c=[col], label=dsub)
dsub_handles.append(Circle(0, 0, color=col))
plt.legend(dsub_handles, dsubs)
if viz_emb:
logger.info('plotting embeddings with species labels')
# plot embeddings
ax = plt.subplot(1, n_plots, plot_count)
plot_seq_emb(viz_emb, labels, ax, pal=class_pal)
if name is not None:
plt.title(name)
plot_count += 1

# plot train/validation/testing data
train_mask = None
test_mask = None
validate_mask = None
if tvt:
logger.info('plotting embeddings train/validation/test labels')
train_mask = path['train_mask']
test_mask = path['test_mask']
validate_mask = path['validate_mask']
pal = ['gray', 'red', 'yellow']
plt.subplot(1, n_plots, plot_count)
dsubs = ['train', 'validation', 'test'] # data subsets
dsub_handles = list()
for (mask, dsub, col) in zip([train_mask, validate_mask, test_mask], dsubs, pal):
plt.scatter(viz_emb[mask, 0], viz_emb[mask, 1], s=0.1, c=[col], label=dsub)
dsub_handles.append(Circle(0, 0, color=col))
plt.legend(dsub_handles, dsubs)
plot_count += 1

# run some predictions and plot report
if pred is not False:
if pred is None or pred is True:
logger.info('No classifier given, using RandomForestClassifier(n_estimators=30)')
pred = RandomForestClassifier(n_estimators=30)
elif not (hasattr(pred, 'fit') and hasattr(pred, 'predict')):
raise ValueError("argument 'pred' must be a classifier with an SKLearn interface")

X_test = outputs
y_pred = pred
y_test = labels
if not hasattr(pred, 'classes_'):
train_mask = path['train_mask']
test_mask = path['test_mask']
X_train = outputs[train_mask]
y_train = labels[train_mask]
logger.info(f'training classifier {pred}')
pred.fit(X_train, y_train)
X_test = outputs[test_mask]
y_test = labels[test_mask]
logger.info(f'getting predictions')
y_pred = pred.predict(X_test)
if not isinstance(pred, (np.ndarray, list)):
if pred is None or pred is True:
logger.info('No classifier given, using RandomForestClassifier(n_estimators=30)')
pred = RandomForestClassifier(n_estimators=30)
elif not (hasattr(pred, 'fit') and hasattr(pred, 'predict')):
raise ValueError("argument 'pred' must be a classifier with an SKLearn interface")

X_test = outputs
if not hasattr(pred, 'classes_'):
train_mask = path['train_mask']
test_mask = path['test_mask']
X_train = outputs[train_mask]
y_train = labels[train_mask]
logger.info(f'training classifier {pred}')
pred.fit(X_train, y_train)
X_test = outputs[test_mask]
y_test = labels[test_mask]
logger.info(f'getting predictions')
y_pred = pred.predict(X_test)

logger.info(f'plotting classification report')
# plot classification report
Expand All @@ -156,15 +165,15 @@ def aggregated_chunk_analysis(path, clf, fig_height=7):
viz_emb = None
if 'viz_emb' in path:
viz_emb = path['viz_emb']
else:
viz_emb = UMAP(n_components=2).fit_transform(X)

uniq_seqs = np.unique(seq_ids)
X_mean = np.zeros((uniq_seqs.shape[0], outputs.shape[1]))
X_median = np.zeros((uniq_seqs.shape[0], outputs.shape[1]))
y = np.zeros(uniq_seqs.shape[0], dtype=int)
seq_len = np.zeros(uniq_seqs.shape[0], dtype=int)
seq_viz = np.zeros((uniq_seqs.shape[0], 2))
seq_viz = None
if viz_emb is not None:
seq_viz = np.zeros((uniq_seqs.shape[0], 2))

for seq_i, seq in enumerate(uniq_seqs):
seq_mask = seq_ids == seq
Expand All @@ -174,17 +183,31 @@ def aggregated_chunk_analysis(path, clf, fig_height=7):
y[seq_i] = uniq_labels[0]
X_mean[seq_i] = outputs[seq_mask].mean(axis=0)
X_median[seq_i] = np.median(outputs[seq_mask], axis=0)
seq_viz[seq_i] = viz_emb[seq_mask].mean(axis=0)
if seq_viz is not None:
seq_viz[seq_i] = viz_emb[seq_mask].mean(axis=0)
seq_len[seq_i] = olens[seq_mask].sum()

seq_len = np.log10(seq_len)

color_labels = getattr(clf, 'classes_', None)
if color_labels is None:
color_labels = labels
class_pal = get_color_markers(len(np.unique(color_labels)))
fig, axes = None, None
figsize_factor = 7
class_pal = None
if isinstance(clf, (list, np.ndarray)):
nrows = 2
ncols = 1
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(nrows*figsize_factor, ncols*figsize_factor))
axes = np.expand_dims(axes, axis=1)
all_preds = np.argmax(outputs, axis=1)
class_pal = get_color_markers(outputs.shape[1])
else:
color_labels = getattr(clf, 'classes_', None)
if color_labels is None:
color_labels = labels
class_pal = get_color_markers(len(np.unique(color_labels)))

fig, axes = plt.subplots(nrows=3, ncols=3, sharey='row', figsize=(21, 21))
nrows = 3 if seq_viz is not None else 2
ncols = 3
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharey='row', figsize=(nrows*figsize_factor, ncols*figsize_factor))

# classifier from MEAN of outputs
output_mean_preds = clf.predict(X_mean)
Expand All @@ -194,9 +217,10 @@ def aggregated_chunk_analysis(path, clf, fig_height=7):
output_median_preds = clf.predict(X_median)
make_plots(y, output_median_preds, axes[:,1], class_pal, seq_len, 'Median classification', seq_viz)

# classifier from voting with chunk predictions
all_preds = clf.predict(outputs)
vote_preds = np.zeros_like(output_mean_preds)
# classifier from voting with chunk predictions
all_preds = clf.predict(outputs)

vote_preds = np.zeros(X_mean.shape[0], dtype=int)
for seq_i, seq in enumerate(uniq_seqs):
seq_mask = seq_ids == seq
vote_preds[seq_i] = stats.mode(all_preds[seq_mask])[0][0]
Expand Down Expand Up @@ -387,6 +411,9 @@ def summarize(argv=None):
parser.add_argument('-A', '--aggregate_chunks', action='store_true', default=False,
help='aggregate chunks within sequences and perform analysis')
parser.add_argument('-o', '--outdir', type=str, default=None, help='the output directory for figures')
type_group = parser.add_argument_group('Problem type').add_mutually_exclusive_group()
type_group.add_argument('-C', '--classify', action='store_true', help='run a classification problem', default=False)
type_group.add_argument('-M', '--manifold', action='store_true', help='run a manifold learning problem', default=False)

args = parser.parse_args(args=argv)
if os.path.isdir(args.input):
Expand All @@ -405,23 +432,38 @@ def summarize(argv=None):
fig_path = os.path.join(outdir, 'summary.png')
logger = parse_logger('')

plt.figure(figsize=(21, 7))
pretrained = False
if args.classifier is not None:
with open(args.classifier, 'rb') as f:
pred = pickle.load(f)
pretrained = True
else:
pred = RandomForestClassifier(n_estimators=30)
outputs = read_outputs(args.input)
pred = plot_results(outputs, pred=pred, name='/'.join(args.input.split('/',)[-2:]), logger=logger)
if args.classify:
plt.figure(figsize=(7, 7))
labels = outputs['labels']
model_outputs = outputs['outputs']
if 'test_mask' in outputs:
mask = outputs['test_mask']
labels = labels[mask]
model_outputs = model_outputs[mask]

pred = np.argmax(model_outputs, axis=1)
class_pal = get_color_markers(model_outputs.shape[1])
colors = np.array([class_pal[i] for i in labels])
ax = plt.gca()
plot_clf_report(labels, pred, ax=ax, pal=class_pal)
else:
plt.figure(figsize=(21, 7))
pretrained = False
if args.classifier is not None:
with open(args.classifier, 'rb') as f:
pred = pickle.load(f)
pretrained = True
else:
pred = RandomForestClassifier(n_estimators=30)
pred = plot_results(outputs, pred=pred, name='/'.join(args.input.split('/',)[-2:]), logger=logger)
if not pretrained:
clf_path = os.path.join(outdir, 'summary.rf.pkl')
logger.info(f'saving classifier to {clf_path}')
with open(clf_path, 'wb') as f:
pickle.dump(pred, f)
logger.info(f'saving figure to {fig_path}')
plt.savefig(fig_path, dpi=100)
if not pretrained:
clf_path = os.path.join(outdir, 'summary.rf.pkl')
logger.info(f'saving classifier to {clf_path}')
with open(clf_path, 'wb') as f:
pickle.dump(pred, f)

if args.aggregate_chunks:
logger.info(f'running summary by aggregating chunks within sequences')
Expand Down
Loading