Skip to content

Commit 9961ad2

Browse files
feat: expose dev extras (#65)
1 parent 717cee8 commit 9961ad2

24 files changed

+80
-70
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ repos:
1515
rev: v1.15.0
1616
hooks:
1717
- id: mypy
18+
pass_filenames: false
19+
args: [--config-file=pyproject.toml, gen_surv]

examples/run_aft_weibull.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import sys
77

88
import matplotlib.pyplot as plt
9-
import numpy as np
109
import pandas as pd
1110

1211
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

examples/run_competing_risks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import matplotlib.pyplot as plt
99
import numpy as np
10-
import pandas as pd
1110

1211
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
1312

gen_surv/__init__.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from .bivariate import sample_bivariate_distribution
1212
from .censoring import (
1313
CensoringModel,
14+
GammaCensoring,
15+
LogNormalCensoring,
16+
WeibullCensoring,
1417
rexpocens,
18+
rgammacens,
19+
rlognormcens,
1520
runifcens,
1621
rweibcens,
17-
rlognormcens,
18-
rgammacens,
19-
WeibullCensoring,
20-
LogNormalCensoring,
21-
GammaCensoring,
2222
)
2323
from .cmm import gen_cmm
2424
from .competing_risks import gen_competing_risks, gen_competing_risks_weibull
@@ -38,12 +38,10 @@
3838

3939
# Visualization tools (requires matplotlib and lifelines)
4040
try:
41-
from .visualization import (
42-
describe_survival,
43-
plot_covariate_effect,
44-
plot_hazard_comparison,
45-
plot_survival_curve,
46-
)
41+
from .visualization import describe_survival # noqa: F401
42+
from .visualization import plot_covariate_effect # noqa: F401
43+
from .visualization import plot_hazard_comparison # noqa: F401
44+
from .visualization import plot_survival_curve # noqa: F401
4745

4846
_has_visualization = True
4947
except ImportError:

gen_surv/_validation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class PositiveSequenceError(ValidationError):
5555
"""Raised when a sequence contains non-positive elements."""
5656

5757
def __init__(self, name: str, seq: Sequence[Any]) -> None:
58-
super().__init__(f"All elements in '{name}' must be greater than 0; got {seq!r}")
58+
super().__init__(
59+
f"All elements in '{name}' must be greater than 0; got {seq!r}"
60+
)
5961

6062

6163
class ListOfListsError(ValidationError):
@@ -69,9 +71,7 @@ class ParameterError(ValidationError):
6971
"""Raised when a parameter falls outside its allowed range."""
7072

7173
def __init__(self, name: str, value: Any, constraint: str) -> None:
72-
super().__init__(
73-
f"Invalid value for '{name}': {value!r}. {constraint}"
74-
)
74+
super().__init__(f"Invalid value for '{name}': {value!r}. {constraint}")
7575

7676

7777
_ALLOWED_CENSORING = {"uniform", "exponential"}
@@ -108,16 +108,19 @@ def _to_float_array(seq: Sequence[Any], name: str) -> NDArray[np.float64]:
108108
except (TypeError, ValueError) as exc:
109109
raise NumericSequenceError(name, seq) from exc
110110

111+
111112
def ensure_numeric_sequence(seq: Sequence[Any], name: str) -> None:
112113
"""Ensure all elements of ``seq`` are numeric."""
113114
_to_float_array(seq, name)
114115

116+
115117
def ensure_positive_sequence(seq: Sequence[float], name: str) -> None:
116118
"""Ensure all elements of ``seq`` are positive."""
117119
arr = _to_float_array(seq, name)
118120
if np.any(arr <= 0):
119121
raise PositiveSequenceError(name, seq)
120122

123+
121124
def ensure_censoring_model(model_cens: str) -> None:
122125
"""Validate that the censoring model is supported."""
123126
ensure_in_choices(model_cens, "model_cens", _ALLOWED_CENSORING)

gen_surv/bivariate.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from typing import Sequence
2+
13
import numpy as np
24
from numpy.typing import NDArray
3-
from typing import Sequence
45

56
from .validate import validate_dg_biv_inputs
67

7-
88
_CHI2_SCALE = 0.5
99
_CLIP_EPS = 1e-10
1010

@@ -43,7 +43,9 @@ def sample_bivariate_distribution(
4343
mean = [0, 0]
4444
cov = [[1, corr], [corr, 1]]
4545
z = np.random.multivariate_normal(mean, cov, size=n)
46-
u = 1 - np.exp(-_CHI2_SCALE * z**2) # transform normals to uniform via chi-squared approx
46+
u = 1 - np.exp(
47+
-_CHI2_SCALE * z**2
48+
) # transform normals to uniform via chi-squared approx
4749
u = np.clip(u, _CLIP_EPS, 1 - _CLIP_EPS) # avoid infs in tails
4850

4951
# Step 2: Transform to marginals

gen_surv/censoring.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import Protocol
2+
13
import numpy as np
24
from numpy.typing import NDArray
3-
from typing import Protocol
45

56

67
class CensoringFunc(Protocol):

gen_surv/cmm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import Sequence, TypedDict
2+
13
import numpy as np
24
import pandas as pd
3-
from typing import Sequence, TypedDict
45

56
from gen_surv.censoring import CensoringFunc, rexpocens, runifcens
67
from gen_surv.validate import validate_gen_cmm_inputs

gen_surv/competing_risks.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import pandas as pd
1212

1313
from ._validation import (
14+
ParameterError,
1415
ensure_censoring_model,
1516
ensure_in_choices,
1617
ensure_positive_sequence,
1718
ensure_sequence_length,
18-
ParameterError,
1919
)
2020
from .censoring import rexpocens, runifcens
2121

@@ -109,9 +109,7 @@ def gen_competing_risks(
109109
n_covariates = 2 # Default number of covariates
110110

111111
# Set default covariate parameters if not provided
112-
ensure_in_choices(
113-
covariate_dist, "covariate_dist", {"normal", "uniform", "binary"}
114-
)
112+
ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"})
115113
if covariate_params is None:
116114
if covariate_dist == "normal":
117115
covariate_params = {"mean": 0.0, "std": 1.0}
@@ -309,9 +307,7 @@ def gen_competing_risks_weibull(
309307
n_covariates = 2 # Default number of covariates
310308

311309
# Set default covariate parameters if not provided
312-
ensure_in_choices(
313-
covariate_dist, "covariate_dist", {"normal", "uniform", "binary"}
314-
)
310+
ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"})
315311
if covariate_params is None:
316312
if covariate_dist == "normal":
317313
covariate_params = {"mean": 0.0, "std": 1.0}

gen_surv/interface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
>>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0)
77
"""
88

9-
from typing import Any, Literal, Protocol, Dict
9+
from typing import Any, Dict, Literal, Protocol
1010

1111
import pandas as pd
1212

@@ -18,6 +18,7 @@
1818
from gen_surv.piecewise import gen_piecewise_exponential
1919
from gen_surv.tdcm import gen_tdcm
2020
from gen_surv.thmm import gen_thmm
21+
2122
from ._validation import ensure_in_choices
2223

2324
# Type definitions for model names
@@ -35,6 +36,7 @@
3536
"piecewise_exponential",
3637
]
3738

39+
3840
# Interface for generator callables
3941
class DataGenerator(Protocol):
4042
def __call__(self, **kwargs: Any) -> pd.DataFrame: ...

0 commit comments

Comments
 (0)