Skip to content

Commit 7527fdb

Browse files
authored
Revert v3 update for netam.multihit (#149)
1 parent 6470095 commit 7527fdb

File tree

2 files changed

+24
-20
lines changed

2 files changed

+24
-20
lines changed

netam/multihit.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
apply the correction to existing codon probability predictions, we multiply the
99
probability of each child codon by the correction factor for its hit class, then
1010
renormalize. The correction factor for hit class 0 is fixed at 1.
11+
12+
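NOTE: Unlike the rest of netam, this module has not been updated to the v3 data format, because Thrifty has not been updated either; it reads the unsuffixed columns (`parent`, `child`, `nt_rates`, `nt_csps`) rather than the `_heavy` variants.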
1113
"""
1214

1315
import torch
@@ -436,10 +438,10 @@ def to_crepe(self):
436438
def hit_class_dataset_from_pcp_df(
437439
pcp_df: pd.DataFrame, branch_length_multiplier: int = 1.0
438440
) -> HitClassDataset:
439-
nt_parents = pcp_df["parent_heavy"].reset_index(drop=True)
440-
nt_children = pcp_df["child_heavy"].reset_index(drop=True)
441-
nt_rates = pcp_df["nt_rates_heavy"].reset_index(drop=True)
442-
nt_csps = pcp_df["nt_csps_heavy"].reset_index(drop=True)
441+
nt_parents = pcp_df["parent"].reset_index(drop=True)
442+
nt_children = pcp_df["child"].reset_index(drop=True)
443+
nt_rates = pcp_df["nt_rates"].reset_index(drop=True)
444+
nt_csps = pcp_df["nt_csps"].reset_index(drop=True)
443445

444446
return HitClassDataset(
445447
nt_parents,
@@ -455,10 +457,10 @@ def train_test_datasets_of_pcp_df(
455457
) -> Tuple[HitClassDataset, HitClassDataset]:
456458
"""Splits a pcp_df prepared by `prepare_pcp_df` into a training and testing
457459
HitClassDataset."""
458-
nt_parents = pcp_df["parent_heavy"].reset_index(drop=True)
459-
nt_children = pcp_df["child_heavy"].reset_index(drop=True)
460-
nt_rates = pcp_df["nt_rates_heavy"].reset_index(drop=True)
461-
nt_csps = pcp_df["nt_csps_heavy"].reset_index(drop=True)
460+
nt_parents = pcp_df["parent"].reset_index(drop=True)
461+
nt_children = pcp_df["child"].reset_index(drop=True)
462+
nt_rates = pcp_df["nt_rates"].reset_index(drop=True)
463+
nt_csps = pcp_df["nt_csps"].reset_index(drop=True)
462464

463465
train_len = int(train_frac * len(nt_parents))
464466
train_parents, val_parents = nt_parents[:train_len], nt_parents[train_len:]
@@ -496,18 +498,12 @@ def prepare_pcp_df(
496498
497499
Returns the modified dataframe, which is the input dataframe modified in-place.
498500
"""
499-
pcp_df["parent_heavy"] = _trim_to_codon_boundary_and_max_len(
500-
pcp_df["parent_heavy"], site_count
501-
)
502-
pcp_df["child_heavy"] = _trim_to_codon_boundary_and_max_len(
503-
pcp_df["child_heavy"], site_count
504-
)
505-
pcp_df = pcp_df[pcp_df["parent_heavy"] != pcp_df["child_heavy"]].reset_index(
506-
drop=True
507-
)
501+
pcp_df["parent"] = _trim_to_codon_boundary_and_max_len(pcp_df["parent"], site_count)
502+
pcp_df["child"] = _trim_to_codon_boundary_and_max_len(pcp_df["child"], site_count)
503+
pcp_df = pcp_df[pcp_df["parent"] != pcp_df["child"]].reset_index(drop=True)
508504
ratess, cspss = framework.trimmed_shm_model_outputs_of_crepe(
509-
crepe, pcp_df["parent_heavy"]
505+
crepe, pcp_df["parent"]
510506
)
511-
pcp_df["nt_rates_heavy"] = ratess
512-
pcp_df["nt_csps_heavy"] = cspss
507+
pcp_df["nt_rates"] = ratess
508+
pcp_df["nt_csps"] = cspss
513509
return pcp_df

tests/test_multihit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@
8787
@pytest.fixture
8888
def mini_multihit_train_val_datasets():
8989
df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
90+
# Strip the `_heavy` suffix from column names and drop every `_light` column
91+
# (multihit training goes with Thrifty, and neither has moved to the v3 format yet)
92+
df = df.rename(
93+
columns={
94+
col: col[: -len("_heavy")] for col in df.columns if col.endswith("_heavy")
95+
}
96+
)
97+
df = df.drop(columns=[col for col in df.columns if col.endswith("_light")])
9098
crepe = pretrained.load("ThriftyHumV0.2-45")
9199
df = multihit.prepare_pcp_df(df, crepe, 500)
92100
return multihit.train_test_datasets_of_pcp_df(df)

0 commit comments

Comments
 (0)