Skip to content

Commit 22d7959

Browse files
committed
fix: fix black and mypy issues
1 parent 4a531a9 commit 22d7959

File tree

3 files changed

+70
-84
lines changed

3 files changed

+70
-84
lines changed

gen_surv/__init__.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from importlib.metadata import PackageNotFoundError, version
77

88
from .aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull
9-
10-
# Helper functions
119
from .bivariate import sample_bivariate_distribution
1210
from .censoring import (
1311
CensoringModel,
@@ -22,54 +20,43 @@
2220
)
2321
from .cmm import gen_cmm
2422
from .competing_risks import gen_competing_risks, gen_competing_risks_weibull
25-
26-
# Individual generators
2723
from .cphm import gen_cphm
2824
from .export import export_dataset
29-
from .integration import to_sksurv
30-
31-
# Main interface
3225
from .interface import generate
3326
from .mixture import cure_fraction_estimate, gen_mixture_cure
3427
from .piecewise import gen_piecewise_exponential
3528
from .sklearn_adapter import GenSurvDataGenerator
3629
from .tdcm import gen_tdcm
3730
from .thmm import gen_thmm
3831

32+
# Get package version
33+
try:
34+
__version__ = version("gen_surv")
35+
except PackageNotFoundError: # pragma: no cover - fallback when package not installed
36+
__version__ = "0.0.0"
37+
3938
# Visualization tools (requires matplotlib and lifelines)
4039
try:
41-
from .visualization import describe_survival # noqa: F401
42-
from .visualization import plot_covariate_effect # noqa: F401
43-
from .visualization import plot_hazard_comparison # noqa: F401
44-
from .visualization import plot_survival_curve # noqa: F401
40+
from .visualization import ( # noqa: F401
41+
describe_survival,
42+
plot_covariate_effect,
43+
plot_hazard_comparison,
44+
plot_survival_curve,
45+
)
4546

4647
_has_visualization = True
4748
except ImportError:
4849
_has_visualization = False
4950

50-
"""Top-level package for ``gen_surv``.
51-
52-
This module exposes the :func:`generate` function and provides access to the
53-
package version via ``__version__``.
54-
"""
55-
56-
from importlib.metadata import PackageNotFoundError, version
57-
58-
from .interface import generate
59-
51+
# Optional scikit-survival integration
6052
try:
61-
__version__ = version("gen_surv")
62-
except PackageNotFoundError: # pragma: no cover - fallback when package not installed
63-
__version__ = "0.0.0"
53+
from .integration import from_sksurv, to_sksurv # noqa: F401
6454

65-
# Optional imports - only available if dependencies are installed
66-
try:
67-
from .integration import to_sksurv, from_sksurv
68-
__all__ = ["generate", "__version__", "to_sksurv", "from_sksurv"]
55+
_has_sksurv = True
6956
except ImportError:
70-
# scikit-survival not available
71-
__all__ = ["generate", "__version__"]
57+
_has_sksurv = False
7258

59+
# Define exports
7360
__all__ = [
7461
# Main interface
7562
"generate",
@@ -87,7 +74,7 @@
8774
"gen_mixture_cure",
8875
"cure_fraction_estimate",
8976
"gen_piecewise_exponential",
90-
# Helpers
77+
# Helper functions
9178
"sample_bivariate_distribution",
9279
"runifcens",
9380
"rexpocens",
@@ -99,11 +86,13 @@
9986
"GammaCensoring",
10087
"CensoringModel",
10188
"export_dataset",
102-
"to_sksurv",
10389
"GenSurvDataGenerator",
10490
]
10591

106-
# Add visualization tools to __all__ if available
92+
# Add optional exports if available
93+
if _has_sksurv:
94+
__all__.extend(["to_sksurv", "from_sksurv"])
95+
10796
if _has_visualization:
10897
__all__.extend(
10998
[

gen_surv/integration.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55

66
try:
77
from sksurv.util import Surv
8+
89
SKSURV_AVAILABLE = True
910
except ImportError:
1011
SKSURV_AVAILABLE = False
1112

1213

13-
def to_sksurv(df: pd.DataFrame, time_col="time", event_col="status"):
14+
def to_sksurv(
15+
df: pd.DataFrame, time_col: str = "time", event_col: str = "status"
16+
) -> np.ndarray:
1417
"""
1518
Convert a pandas DataFrame to a scikit-survival structured array.
16-
19+
1720
Parameters
1821
----------
1922
df : pd.DataFrame
@@ -22,12 +25,12 @@ def to_sksurv(df: pd.DataFrame, time_col="time", event_col="status"):
2225
Name of the column containing survival times
2326
event_col : str, default "status"
2427
Name of the column containing event indicators (0/1 or boolean)
25-
28+
2629
Returns
2730
-------
2831
y : structured array
2932
Structured array suitable for scikit-survival functions
30-
33+
3134
Raises
3235
------
3336
ImportError
@@ -37,49 +40,50 @@ def to_sksurv(df: pd.DataFrame, time_col="time", event_col="status"):
3740
"""
3841
if not SKSURV_AVAILABLE:
3942
raise ImportError("scikit-survival is required but not installed")
40-
43+
4144
if df.empty:
4245
# Handle empty DataFrame case by creating a minimal valid structured array
4346
# This avoids the "event indicator must be binary" error for empty arrays
4447
return np.array([], dtype=[(event_col, bool), (time_col, float)])
45-
48+
4649
if time_col not in df.columns:
4750
raise ValueError(f"Column '{time_col}' not found in DataFrame")
4851
if event_col not in df.columns:
4952
raise ValueError(f"Column '{event_col}' not found in DataFrame")
50-
53+
5154
return Surv.from_dataframe(event_col, time_col, df)
5255

5356

54-
def from_sksurv(y: np.ndarray, time_col="time", event_col="status"):
57+
def from_sksurv(
58+
y: np.ndarray, time_col: str = "time", event_col: str = "status"
59+
) -> pd.DataFrame:
5560
"""
5661
Convert a scikit-survival structured array to a pandas DataFrame.
57-
62+
5863
Parameters
5964
----------
6065
y : structured array
6166
Structured array from scikit-survival
6267
time_col : str, default "time"
6368
Name for the time column in the resulting DataFrame
64-
event_col : str, default "status"
69+
event_col : str, default "status"
6570
Name for the event column in the resulting DataFrame
66-
71+
6772
Returns
6873
-------
6974
df : pd.DataFrame
7075
DataFrame with time and event columns
7176
"""
7277
if not SKSURV_AVAILABLE:
7378
raise ImportError("scikit-survival is required but not installed")
74-
79+
7580
if len(y) == 0:
7681
return pd.DataFrame({time_col: [], event_col: []})
77-
82+
7883
# Extract field names from structured array
7984
event_field = y.dtype.names[0]
8085
time_field = y.dtype.names[1]
81-
82-
return pd.DataFrame({
83-
time_col: y[time_field],
84-
event_col: y[event_field].astype(int)
85-
})
86+
87+
return pd.DataFrame(
88+
{time_col: y[time_field], event_col: y[event_field].astype(int)}
89+
)

tests/test_integration_sksurv.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,17 @@
44
import pandas as pd
55
import pytest
66

7-
from gen_surv.integration import to_sksurv, from_sksurv
7+
from gen_surv.integration import from_sksurv, to_sksurv
88

99

1010
def test_to_sksurv_basic():
1111
"""Test basic conversion from DataFrame to sksurv format."""
1212
pytest.importorskip("sksurv.util")
13-
14-
df = pd.DataFrame({
15-
"time": [1.0, 2.0, 3.0],
16-
"status": [1, 0, 1]
17-
})
18-
13+
14+
df = pd.DataFrame({"time": [1.0, 2.0, 3.0], "status": [1, 0, 1]})
15+
1916
arr = to_sksurv(df)
20-
17+
2118
assert len(arr) == 3
2219
assert arr.dtype.names == ("status", "time")
2320
assert list(arr["time"]) == [1.0, 2.0, 3.0]
@@ -27,49 +24,48 @@ def test_to_sksurv_basic():
2724
def test_to_sksurv_custom_columns():
2825
"""Test conversion with custom column names."""
2926
pytest.importorskip("sksurv.util")
30-
31-
df = pd.DataFrame({
32-
"survival_time": [1.0, 2.0],
33-
"event": [1, 0]
34-
})
35-
27+
28+
df = pd.DataFrame({"survival_time": [1.0, 2.0], "event": [1, 0]})
29+
3630
arr = to_sksurv(df, time_col="survival_time", event_col="event")
37-
31+
3832
assert len(arr) == 2
3933
assert arr.dtype.names == ("event", "survival_time")
4034

4135

4236
def test_to_sksurv_empty_dataframe():
4337
"""Test conversion of empty DataFrame."""
4438
pytest.importorskip("sksurv.util")
45-
39+
4640
df = pd.DataFrame({"time": [], "status": []})
4741
arr = to_sksurv(df)
48-
42+
4943
assert len(arr) == 0
5044
assert arr.dtype.names == ("status", "time")
5145

5246

5347
def test_to_sksurv_missing_columns():
5448
"""Test error handling for missing columns."""
5549
pytest.importorskip("sksurv.util")
56-
50+
5751
df = pd.DataFrame({"time": [1.0, 2.0]})
58-
52+
5953
with pytest.raises(ValueError, match="Column 'status' not found"):
6054
to_sksurv(df)
6155

6256

6357
def test_from_sksurv_basic():
6458
"""Test conversion from sksurv format to DataFrame."""
6559
pytest.importorskip("sksurv.util")
66-
60+
6761
# Create a structured array manually
68-
arr = np.array([(True, 1.0), (False, 2.0), (True, 3.0)],
69-
dtype=[("status", bool), ("time", float)])
70-
62+
arr = np.array(
63+
[(True, 1.0), (False, 2.0), (True, 3.0)],
64+
dtype=[("status", bool), ("time", float)],
65+
)
66+
7167
df = from_sksurv(arr)
72-
68+
7369
assert len(df) == 3
7470
assert list(df.columns) == ["time", "status"]
7571
assert list(df["time"]) == [1.0, 2.0, 3.0]
@@ -79,27 +75,24 @@ def test_from_sksurv_basic():
7975
def test_from_sksurv_empty():
8076
"""Test conversion of empty structured array."""
8177
pytest.importorskip("sksurv.util")
82-
78+
8379
arr = np.array([], dtype=[("status", bool), ("time", float)])
8480
df = from_sksurv(arr)
85-
81+
8682
assert len(df) == 0
8783
assert list(df.columns) == ["time", "status"]
8884

8985

9086
def test_roundtrip_conversion():
9187
"""Test that conversion is bidirectional."""
9288
pytest.importorskip("sksurv.util")
93-
94-
original_df = pd.DataFrame({
95-
"time": [1.0, 2.5, 4.0],
96-
"status": [1, 0, 1]
97-
})
98-
89+
90+
original_df = pd.DataFrame({"time": [1.0, 2.5, 4.0], "status": [1, 0, 1]})
91+
9992
# Convert to sksurv and back
10093
arr = to_sksurv(original_df)
10194
result_df = from_sksurv(arr)
102-
95+
10396
pd.testing.assert_frame_equal(original_df, result_df)
10497

10598

0 commit comments

Comments
 (0)