Skip to content

Commit e72cde7

Browse files
tomwardio authored and copybara-github committed
Fix tidy_anndata to return empty dataframe when there are no scores.
This should fix #39, where we may have an empty anndata object with gene_id and strand obs, but no rows.

PiperOrigin-RevId: 874524624
Change-Id: I053d90491346cba679cd45069fda31104f575964
1 parent e6901f1 commit e72cde7

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

src/alphagenome/models/variant_scorers.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import dataclasses
1919
import enum
2020
import itertools
21+
import math
2122
from typing import TypeVar
2223

2324
from alphagenome.models import dna_output
@@ -700,16 +701,18 @@ def tidy_anndata(
700701
'junction_Start',
701702
'junction_End',
702703
]
703-
if 'gene_id' in adata.obs and 'strand' in adata.obs:
704+
if math.prod(adata.X.shape) == 0:
705+
# Scores are empty, so we return an empty dataframe.
706+
return pd.DataFrame()
707+
elif 'gene_id' in adata.obs and 'strand' in adata.obs:
704708
# Scores are for a gene-based scorer.
705709
obs = adata.obs.rename({'strand': 'gene_strand'}, axis=1)
706-
obs['gene_id'] = obs['gene_id'].str.split('.', expand=True)[0] # Depatch.
710+
711+
# Remove patch number from gene_id.
712+
obs['gene_id'] = obs['gene_id'].str.split('.', expand=True).get(0)
707713
for col in gene_columns:
708714
if col not in obs:
709715
obs[col] = None
710-
elif adata.X.shape[0] == 0:
711-
# Scores are empty, so we return an empty dataframe.
712-
return pd.DataFrame()
713716
elif adata.X.shape[0] == 1 and adata.obs.empty:
714717
# Scores are for a non-gene-based scorer.
715718
obs = pd.DataFrame([[None] * len(gene_columns)], columns=gene_columns)

src/alphagenome/models/variant_scorers_test.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -877,6 +877,13 @@ def test_tidy_anndata_splicing_sites(self):
877877
],
878878
)
879879

880+
def test_tidy_anndata_empty(self):
881+
adata = anndata.AnnData(
882+
X=np.ones((0, 3)), obs=pd.DataFrame(columns=['gene_id', 'strand'])
883+
)
884+
result = variant_scorers.tidy_anndata(adata)
885+
pd.testing.assert_frame_equal(result, pd.DataFrame())
886+
880887
def test_tidy_lists_of_anndata(self):
881888
# List of AnnDatas.
882889
anndata_list = [self.adata_gene_centric, self.adata_variant_centric]

0 commit comments

Comments
 (0)