Binary file modified data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz
30 changes: 30 additions & 0 deletions netam/common.py
@@ -10,6 +10,7 @@
 import torch
 import torch.optim as optim
 from torch import Tensor
+from torch.utils.data import DataLoader
 import multiprocessing as mp
@@ -482,3 +483,32 @@ def wrapper(*args, **kwargs):
         return torch.cat(results)
 
     return wrapper
+
+
+def create_optimized_dataloader(
+    dataset,
+    batch_size: int,
+    shuffle: bool = True,
+    collate_fn=None,
+    num_workers: int = 2,
+) -> DataLoader:
+    """Create a DataLoader with optimizations for GPU training.
+
+    Args:
+        dataset: PyTorch dataset.
+        batch_size: Batch size for the dataloader.
+        shuffle: Whether to shuffle the data.
+        collate_fn: Optional collate function.
+        num_workers: Number of worker processes for data loading.
+
+    Returns:
+        DataLoader with GPU optimization settings.
+    """
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        collate_fn=collate_fn,
+        num_workers=num_workers,
+        persistent_workers=True,  # requires num_workers > 0
+        pin_memory=torch.cuda.is_available(),
+    )
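For reviewers, a minimal usage sketch of the new helper; the toy `TensorDataset` below is hypothetical, not part of the repository:

```python
import torch
from torch.utils.data import TensorDataset

from netam.common import create_optimized_dataloader

# Hypothetical toy data: 128 examples of 10-dim features with scalar targets.
features = torch.randn(128, 10)
targets = torch.randn(128, 1)
dataset = TensorDataset(features, targets)

loader = create_optimized_dataloader(dataset, batch_size=32)
for batch_features, batch_targets in loader:
    pass  # a training step would consume the batch here
```

Note the design choice: `persistent_workers=True` keeps the worker processes alive between epochs (it requires `num_workers > 0`, satisfied by the default of 2), and `pin_memory` is enabled only when CUDA is available, so the helper degrades gracefully on CPU-only machines.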
8 changes: 5 additions & 3 deletions netam/dasm.py
@@ -57,7 +57,7 @@ def update_neutral_probs(self):
 
         In this case it's the neutral codon probabilities.
         """
-        neutral_codon_probs_l = []
+        neutral_codon_probs_light = []
 
         for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
             self.nt_parents,
@@ -80,11 +80,13 @@
                 neutral_codon_probs, (0, 0, 0, pad_len), value=SMALL_PROB
             )
 
-            neutral_codon_probs_l.append(neutral_codon_probs)
+            neutral_codon_probs_light.append(neutral_codon_probs)
 
         # Note that our masked out positions will have a nan log probability,
         # which will require us to handle them correctly downstream.
-        self.log_neutral_codon_probss = torch.log(torch.stack(neutral_codon_probs_l))
+        self.log_neutral_codon_probss = torch.log(
+            torch.stack(neutral_codon_probs_light)
+        )
 
     def __getitem__(self, idx):
         return {
41 changes: 41 additions & 0 deletions netam/data_format.py
@@ -0,0 +1,41 @@
+"""Constants defining data format for pcp dataframes and other data."""
+
+from collections import defaultdict
+
+
+_pcp_df_differentiated_columns = {
+    "parent": str,
+    "child": str,
+    "v_gene": str,
+    # These should be nullable, because they may be missing in combined
+    # heavy/light bulk dataframes.
+    "cdr1_codon_start": "Int64",
+    "cdr1_codon_end": "Int64",
+    "cdr2_codon_start": "Int64",
+    "cdr2_codon_end": "Int64",
+    "cdr3_codon_start": "Int64",
+    "cdr3_codon_end": "Int64",
+    "j_gene": str,
+    "v_family": str,
+}
+
+_pcp_df_undifferentiated_columns = {
+    "sample_id": str,
+    "family": str,
+    "parent_name": str,
+    "child_name": str,
+    "branch_length": float,
+    "depth": int,
+    "distance": float,
+    "parent_is_naive": bool,
+    "child_is_leaf": bool,
+}
+
+_all_pcp_df_columns = (
+    defaultdict(lambda: "object")
+    | {col + "_heavy": dtype for col, dtype in _pcp_df_differentiated_columns.items()}
+    | {col + "_light": dtype for col, dtype in _pcp_df_differentiated_columns.items()}
+    | _pcp_df_undifferentiated_columns
+)
+
+_required_pcp_df_columns = ("parent", "child", "v_gene")
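A sketch of how this dtype mapping could be consumed when reading a pcp dataframe; the file name and `astype` wiring are illustrative assumptions, not the repository's actual loader:

```python
import pandas as pd

from netam.data_format import _all_pcp_df_columns

# Illustrative: _all_pcp_df_columns is a defaultdict that falls back to
# "object", so looking up any column name is safe.
pcp_df = pd.read_csv("example_pcp.csv.gz")  # hypothetical input file
dtypes = {col: _all_pcp_df_columns[col] for col in pcp_df.columns}
pcp_df = pcp_df.astype(dtypes)
```

The nullable `"Int64"` dtype is what lets the CDR coordinate columns carry missing values in combined heavy/light bulk dataframes, per the comment above.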
6 changes: 3 additions & 3 deletions netam/ddsm.py
@@ -14,7 +14,7 @@
 class DDSMDataset(DXSMDataset):
 
     def update_neutral_probs(self):
-        neutral_aa_probs_l = []
+        neutral_aa_probs_light = []
 
         for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
             self.nt_parents,
@@ -60,11 +60,11 @@ def update_neutral_probs(self):
             # Here we zero out masked positions.
             neutral_aa_probs *= mask[:, None]
 
-            neutral_aa_probs_l.append(neutral_aa_probs)
+            neutral_aa_probs_light.append(neutral_aa_probs)
 
         # Note that our masked out positions will have a nan log probability,
         # which will require us to handle them correctly downstream.
-        self.log_neutral_aa_probss = torch.log(torch.stack(neutral_aa_probs_l))
+        self.log_neutral_aa_probss = torch.log(torch.stack(neutral_aa_probs_light))
 
     def __getitem__(self, idx):
         return {
8 changes: 5 additions & 3 deletions netam/dnsm.py
@@ -26,7 +26,7 @@ def update_neutral_probs(self):
 
         This is the case of the DNSM, but the DDSM will override this method.
         """
-        neutral_aa_mut_prob_l = []
+        neutral_aa_mut_prob_light = []
 
         for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
             self.nt_parents,
@@ -76,11 +76,13 @@
             # Here we zero out masked positions.
             neutral_aa_mut_probs *= mask
 
-            neutral_aa_mut_prob_l.append(neutral_aa_mut_probs)
+            neutral_aa_mut_prob_light.append(neutral_aa_mut_probs)
 
         # Note that our masked out positions will have a nan log probability,
         # which will require us to handle them correctly downstream.
-        self.log_neutral_aa_mut_probss = torch.log(torch.stack(neutral_aa_mut_prob_l))
+        self.log_neutral_aa_mut_probss = torch.log(
+            torch.stack(neutral_aa_mut_prob_light)
+        )
 
     def __getitem__(self, idx):
         return {
16 changes: 9 additions & 7 deletions netam/dxsm.py
@@ -1,13 +1,7 @@
 from warnings import warn
 from abc import ABC, abstractmethod
 import copy
 
 import torch
-
-# Amazingly, using one thread makes things 50x faster for branch length
-# optimization on our server.
-torch.set_num_threads(1)
-
 import numpy as np
 import pandas as pd
 
@@ -35,6 +29,14 @@
     AA_PADDING_TOKEN,
 )
 
+# Amazingly, using one thread makes things 50x faster for branch length
+# optimization on our server.
+torch.set_num_threads(1)
+
+# Enable cuDNN autotuner for better GPU performance.
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.enabled = True
+
 
 class DXSMDataset(framework.BranchLengthDataset, ABC):
     # Not defining model_type here; instead defining it in subclasses.
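One behavioral note on the new flags, with an illustrative sketch (the small convolution below is a stand-in, not a netam model): with `cudnn.benchmark` enabled, cuDNN times candidate algorithms the first time it sees each input shape and caches the winner, so the setting pays off when batch shapes stay stable across steps and can slow things down when shapes vary widely.

```python
import torch
import torch.nn as nn

torch.backends.cudnn.benchmark = True  # autotune once per input shape

if torch.cuda.is_available():
    conv = nn.Conv1d(64, 64, kernel_size=3, padding=1).cuda()
    x = torch.randn(32, 64, 128, device="cuda")
    for _ in range(3):
        _ = conv(x)  # same shape every step: the cached algorithm is reused
```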
@@ -185,7 +187,7 @@ def of_pcp_df(
         """Alternative constructor that takes in a pcp_df and calculates the initial
         branch lengths."""
         assert (
-            "nt_rates_l" in pcp_df.columns
+            "nt_rates_light" in pcp_df.columns
         ), "pcp_df must have a neutral nt_rates column"
         # use sequences.prepare_heavy_light_pair and the resulting
         # added_indices to get the parent and child sequences and neutral model