Skip to content

Commit da29956

Browse files
committed
IO functions now replace NaNs with 0
1 parent c10bb34 commit da29956

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

grassp/io/read.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pandas as pd
99
import protdata
10+
import scipy.sparse
1011

1112
# def read_alphastats(
1213
# loader: alphastats.BaseLoader,
@@ -92,6 +93,26 @@
9293
# return adata
9394

9495

96+
def _preprocess_adata(adata: anndata.AnnData) -> anndata.AnnData:
97+
"""Preprocess an AnnData object."""
98+
99+
# Replace NaNs with 0 in .X
100+
if isinstance(adata.X, np.ndarray):
101+
adata.X = np.nan_to_num(adata.X, nan=0)
102+
elif isinstance(adata.X, scipy.sparse.spmatrix):
103+
adata.X.data = np.nan_to_num(adata.X.data, nan=0, copy=False)
104+
105+
# Replace NaNs with 0 in all layers
106+
for layer in list(adata.layers.keys()):
107+
arr = adata.layers[layer]
108+
if isinstance(arr, np.ndarray):
109+
adata.layers[layer] = np.nan_to_num(arr, nan=0, copy=False)
110+
elif isinstance(arr, scipy.sparse.spmatrix):
111+
adata.layers[layer].data = np.nan_to_num(arr.data, nan=0, copy=False)
112+
113+
return adata
114+
115+
95116
def read_prolocdata(file_name: str, allow_nullable_strings: bool = False) -> anndata.AnnData:
96117
"""Read a prolocdata file and return an AnnData object.
97118
@@ -180,6 +201,7 @@ def read_prolocdata(file_name: str, allow_nullable_strings: bool = False) -> ann
180201
# Remove class version key if present
181202
metadata.pop(".__classVersion__", None)
182203
adata.uns["MIAPE_metadata"] = metadata
204+
_preprocess_adata(adata)
183205

184206
return adata
185207

@@ -214,6 +236,7 @@ def read_maxquant(*args, **kwargs) -> anndata.AnnData:
214236
>>> adata = gr.io.read_maxquant('proteinGroups.txt') # doctest: +SKIP
215237
"""
216238
adata = protdata.io.read_maxquant(*args, **kwargs)
239+
_preprocess_adata(adata)
217240
return adata.T
218241

219242

@@ -247,6 +270,7 @@ def read_fragpipe(*args, **kwargs) -> anndata.AnnData:
247270
>>> adata = gr.io.read_fragpipe('combined_protein.tsv') # doctest: +SKIP
248271
"""
249272
adata = protdata.io.read_fragpipe(*args, **kwargs)
273+
_preprocess_adata(adata)
250274
return adata.T
251275

252276

@@ -280,4 +304,5 @@ def read_diann(*args, **kwargs) -> anndata.AnnData:
280304
>>> adata = gr.io.read_diann('report.tsv') # doctest: +SKIP
281305
"""
282306
adata = protdata.io.read_diann(*args, **kwargs)
307+
_preprocess_adata(adata)
283308
return adata.T

grassp/tests/test_io.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,133 @@ def test_classversion_removed_from_metadata(self, mock_convert, mock_parse):
650650
# __classVersion__ should not be in MIAPE_metadata
651651
miape = result.uns["MIAPE_metadata"]
652652
assert ".__classVersion__" not in miape
653+
654+
655+
# ==============================================================================
656+
# Tests for _preprocess_adata (NaN handling)
657+
# ==============================================================================
658+
659+
660+
class TestPreprocessAdata:
661+
"""Test _preprocess_adata function for NaN replacement."""
662+
663+
def test_nan_replacement_in_X_dense(self):
664+
"""Test that NaNs in .X (dense) are replaced with 0."""
665+
# Create AnnData with NaNs in .X
666+
X = np.array([[1.0, np.nan, 3.0], [np.nan, 5.0, 6.0], [7.0, 8.0, np.nan]])
667+
adata = AnnData(X=X)
668+
669+
result = read._preprocess_adata(adata)
670+
671+
# Check NaNs are replaced with 0
672+
assert not np.any(np.isnan(result.X))
673+
expected = np.array([[1.0, 0.0, 3.0], [0.0, 5.0, 6.0], [7.0, 8.0, 0.0]])
674+
assert np.allclose(result.X, expected)
675+
676+
def test_nan_replacement_in_X_sparse(self):
677+
"""Test that NaNs in .X (sparse) are replaced with 0."""
678+
from scipy.sparse import csr_matrix
679+
680+
# Create sparse matrix with NaNs
681+
dense = np.array([[1.0, np.nan, 0.0], [0.0, 5.0, np.nan], [7.0, 0.0, 9.0]])
682+
X = csr_matrix(dense)
683+
adata = AnnData(X=X)
684+
685+
result = read._preprocess_adata(adata)
686+
687+
# Check NaNs are replaced with 0
688+
assert not np.any(np.isnan(result.X.data))
689+
690+
def test_nan_replacement_in_layers_dense(self):
691+
"""Test that NaNs in layers (dense) are replaced with 0."""
692+
X = np.array([[1.0, 2.0], [3.0, 4.0]])
693+
layer1 = np.array([[np.nan, 2.0], [3.0, np.nan]])
694+
layer2 = np.array([[1.0, np.nan], [np.nan, 4.0]])
695+
696+
adata = AnnData(X=X, layers={"layer1": layer1, "layer2": layer2})
697+
698+
result = read._preprocess_adata(adata)
699+
700+
# Check NaNs are replaced with 0 in all layers
701+
assert not np.any(np.isnan(result.layers["layer1"]))
702+
assert not np.any(np.isnan(result.layers["layer2"]))
703+
704+
def test_nan_replacement_in_layers_sparse(self):
705+
"""Test that NaNs in layers (sparse) are replaced with 0."""
706+
from scipy.sparse import csr_matrix
707+
708+
X = np.array([[1.0, 2.0], [3.0, 4.0]])
709+
layer_dense = np.array([[np.nan, 0.0], [3.0, np.nan]])
710+
layer_sparse = csr_matrix(layer_dense)
711+
712+
adata = AnnData(X=X, layers={"sparse_layer": layer_sparse})
713+
714+
result = read._preprocess_adata(adata)
715+
716+
# Check NaNs are replaced with 0
717+
assert not np.any(np.isnan(result.layers["sparse_layer"].data))
718+
719+
def test_return_value(self):
720+
"""Test that _preprocess_adata returns the modified AnnData."""
721+
X = np.array([[1.0, np.nan], [np.nan, 4.0]])
722+
adata = AnnData(X=X)
723+
724+
result = read._preprocess_adata(adata)
725+
726+
# Check that it returns an AnnData object
727+
assert isinstance(result, AnnData)
728+
# Check that it's the same object (modified in-place)
729+
assert result is adata
730+
731+
@patch("protdata.io.read_maxquant")
732+
def test_nan_handling_integration_maxquant(self, mock_read, tmp_path):
733+
"""Integration test: verify NaN handling with mock MaxQuant data."""
734+
# Create mock AnnData with NaNs (as protdata would return it)
735+
X_with_nans = np.array(
736+
[
737+
[100.0, np.nan, 150.0], # Sample 1
738+
[np.nan, 190.0, 160.0], # Sample 2
739+
[105.0, 210.0, np.nan], # Sample 3
740+
],
741+
dtype=float,
742+
)
743+
744+
obs = pd.DataFrame(
745+
{"sample_name": ["Sample_A", "Sample_B", "Sample_C"]},
746+
index=["Sample_A", "Sample_B", "Sample_C"],
747+
)
748+
749+
var = pd.DataFrame(
750+
{"Protein IDs": ["P00001", "P00002", "P00003"]},
751+
index=["P00001", "P00002", "P00003"],
752+
)
753+
754+
# Add a layer with NaNs too
755+
layer_with_nans = np.array(
756+
[[np.nan, 2.0, 3.0], [4.0, np.nan, 6.0], [7.0, 8.0, np.nan]], dtype=float
757+
)
758+
759+
mock_adata = AnnData(
760+
X=X_with_nans, obs=obs, var=var, layers={"pvals": layer_with_nans}
761+
)
762+
mock_read.return_value = mock_adata
763+
764+
# Create a dummy file (content doesn't matter since we're mocking)
765+
test_file = tmp_path / "proteinGroups.txt"
766+
test_file.write_text("dummy content")
767+
768+
# Call read_maxquant
769+
result = read.read_maxquant(str(test_file))
770+
771+
# Verify NaNs in .X are replaced with 0 (after transpose)
772+
assert not np.any(np.isnan(result.X)), "NaNs found in .X after preprocessing"
773+
774+
# Verify NaNs in layers are replaced with 0
775+
if "pvals" in result.layers:
776+
assert not np.any(
777+
np.isnan(result.layers["pvals"])
778+
), "NaNs found in layers after preprocessing"
779+
780+
# Verify dimensions after transpose (proteins in obs, samples in var)
781+
assert result.n_obs == 3 # 3 proteins
782+
assert result.shape[1] == 3 # 3 samples

0 commit comments

Comments
 (0)