From 47e4d7fbf82788367a5e67803dc93efab7d62caf Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Thu, 10 Jul 2025 23:16:19 +0100 Subject: [PATCH 01/19] feat: Add competing risks models and enhance data visualization --- .github/workflows/bump-version.yml | 38 +- .github/workflows/ci.yml | 74 ++++ examples/run_aft_weibull.py | 97 +++++ examples/run_competing_risks.py | 138 ++++++ gen_surv/__init__.py | 55 ++- gen_surv/aft.py | 212 ++++++++- gen_surv/cli.py | 150 ++++++- gen_surv/competing_risks.py | 673 +++++++++++++++++++++++++++++ gen_surv/cphm.py | 99 ++++- gen_surv/interface.py | 42 +- gen_surv/summary.py | 496 +++++++++++++++++++++ gen_surv/visualization.py | 373 ++++++++++++++++ pyproject.toml | 41 +- tests/test_aft.py | 244 ++++++++++- tests/test_competing_risks.py | 212 +++++++++ tests/test_cphm.py | 42 +- 16 files changed, 2896 insertions(+), 90 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 examples/run_aft_weibull.py create mode 100644 examples/run_competing_risks.py create mode 100644 gen_surv/competing_risks.py create mode 100644 gen_surv/summary.py create mode 100644 gen_surv/visualization.py create mode 100644 tests/test_competing_risks.py diff --git a/.github/workflows/bump-version.yml b/.github/workflows/bump-version.yml index 5b41bd7..841b320 100644 --- a/.github/workflows/bump-version.yml +++ b/.github/workflows/bump-version.yml @@ -22,7 +22,11 @@ jobs: with: python-version: "3.11" - - run: pip install python-semantic-release + - name: Install Poetry + run: pip install poetry + + - name: Install python-semantic-release + run: pip install python-semantic-release - name: Configure Git run: | @@ -32,36 +36,10 @@ jobs: - name: Run Semantic Release env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: semantic-release version + run: | + # Run semantic-release to get the next version + semantic-release version - name: Push changes run: | git push --follow-tags - - name: "Install Poetry" - run: pip install poetry - - name: "Determine version bump type" - run: | - git fetch --tags - # This defaults to a patch type, unless a feature commit was pushed, then set type to minor - LAST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1)) - LAST_COMMIT=$(git log -1 --format='%H') - echo "Last git tag: $LAST_TAG" - echo "Last git commit: $LAST_COMMIT" - echo "Commits:" - git log --no-merges --pretty=oneline $LAST_TAG...$LAST_COMMIT - git log --no-merges --pretty=format:"%s" $LAST_TAG...$LAST_COMMIT | grep -q ^feat: && BUMP_TYPE="minor" || BUMP_TYPE="patch" - echo "Version bump type: $BUMP_TYPE" - echo "BUMP_TYPE=$BUMP_TYPE" >> $GITHUB_ENV - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: "Version bump" - run: | - poetry version $BUMP_TYPE - - name: "Push new version" - run: | - git add pyproject.toml - git commit -m "Update version to $(poetry version -s)" - git pull --ff-only origin main - git push origin main --follow-tags - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..996e268 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,74 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + name: Test with Python ${{ matrix.python-version }} + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: poetry install + + - name: Run tests with coverage + run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + files: coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} # optional if public repo + + lint: + name: Code Quality + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: poetry install + + - name: Run black + run: poetry run black --check gen_surv tests examples + + - name: Run isort + run: poetry run isort --check gen_surv tests examples + + - name: Run flake8 + run: poetry run flake8 gen_surv tests examples + + - name: Run mypy + run: poetry run mypy gen_surv diff --git a/examples/run_aft_weibull.py b/examples/run_aft_weibull.py new file mode 100644 index 0000000..7bb8abf --- /dev/null +++ b/examples/run_aft_weibull.py @@ -0,0 +1,97 @@ +""" +Example demonstrating Weibull AFT model and visualization capabilities. +""" + +import sys +import os +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from gen_surv import generate +from gen_surv.visualization import ( + plot_survival_curve, + plot_hazard_comparison, + plot_covariate_effect, + describe_survival +) + +# 1. Generate data from different models for comparison +models = { + "Weibull AFT (shape=0.5)": generate( + model="aft_weibull", + n=200, + beta=[0.5, -0.3], + shape=0.5, # Decreasing hazard + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ), + "Weibull AFT (shape=1.0)": generate( + model="aft_weibull", + n=200, + beta=[0.5, -0.3], + shape=1.0, # Constant hazard + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ), + "Weibull AFT (shape=2.0)": generate( + model="aft_weibull", + n=200, + beta=[0.5, -0.3], + shape=2.0, # Increasing hazard + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) +} + +# Print sample data +print("Sample data from Weibull AFT model (shape=2.0):") +print(models["Weibull AFT (shape=2.0)"].head()) +print("\n") + +# 2. Compare survival curves from different models +fig1, ax1 = plot_survival_curve( + data=pd.concat( + [df.assign(_model=name) for name, df in models.items()] + ), + group_col="_model", + title="Comparing Survival Curves with Different Weibull Shapes" +) +plt.savefig("survival_curve_comparison.png", dpi=300, bbox_inches="tight") + +# 3. Compare hazard functions +fig2, ax2 = plot_hazard_comparison( + models=models, + title="Comparing Hazard Functions with Different Weibull Shapes" +) +plt.savefig("hazard_comparison.png", dpi=300, bbox_inches="tight") + +# 4. Visualize covariate effect on survival +fig3, ax3 = plot_covariate_effect( + data=models["Weibull AFT (shape=2.0)"], + covariate_col="X0", + n_groups=3, + title="Effect of X0 Covariate on Survival" +) +plt.savefig("covariate_effect.png", dpi=300, bbox_inches="tight") + +# 5. Summary statistics +for name, df in models.items(): + print(f"Summary for {name}:") + summary = describe_survival(df) + print(summary) + print("\n") + +print("Plots saved to current directory.") + +# Show plots if running interactively +if __name__ == "__main__": + plt.show() diff --git a/examples/run_competing_risks.py b/examples/run_competing_risks.py new file mode 100644 index 0000000..81c163c --- /dev/null +++ b/examples/run_competing_risks.py @@ -0,0 +1,138 @@ +""" +Example demonstrating the Competing Risks models and visualization. +""" + +import sys +import os +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from gen_surv import generate +from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull, cause_specific_cumulative_incidence +from gen_surv.summary import summarize_survival_dataset, compare_survival_datasets + + +def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, 6)): + """Plot the cause-specific cumulative incidence functions.""" + if time_points is None: + max_time = df["time"].max() + time_points = np.linspace(0, max_time, 100) + + # Get unique causes (excluding censoring) + causes = sorted([c for c in df["status"].unique() if c > 0]) + + # Create the plot + fig, ax = plt.subplots(figsize=figsize) + + for cause in causes: + cif = cause_specific_cumulative_incidence(df, time_points, cause=cause) + ax.plot(cif["time"], cif["incidence"], label=f"Cause {cause}") + + # Add overlay showing number of subjects at each time + time_bins = np.linspace(0, df["time"].max(), 10) + event_counts = np.histogram(df.loc[df["status"] > 0, "time"], bins=time_bins)[0] + + # Add a secondary y-axis for event counts + ax2 = ax.twinx() + ax2.bar(time_bins[:-1], event_counts, width=time_bins[1]-time_bins[0], + alpha=0.2, color='gray', align='edge') + ax2.set_ylabel('Number of events') + ax2.grid(False) + + # Format the main plot + ax.set_xlabel("Time") + ax.set_ylabel("Cumulative Incidence") + ax.set_title("Cause-Specific Cumulative Incidence Functions") + ax.legend() + ax.grid(alpha=0.3) + + return fig, ax + + +# 1. Generate data with 2 competing risks +print("Generating data with exponential hazards...") +data_exponential = gen_competing_risks( + n=500, + n_risks=2, + baseline_hazards=[0.5, 0.3], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42 +) + +# 2. Generate data with Weibull hazards (different shapes) +print("Generating data with Weibull hazards...") +data_weibull = gen_competing_risks_weibull( + n=500, + n_risks=2, + shape_params=[0.8, 1.5], # Decreasing vs increasing hazard + scale_params=[2.0, 3.0], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42 +) + +# 3. Print summary statistics for both datasets +print("\nSummary of Exponential Hazards dataset:") +summarize_survival_dataset(data_exponential) + +print("\nSummary of Weibull Hazards dataset:") +summarize_survival_dataset(data_weibull) + +# 4. Compare event distributions +print("\nEvent distribution (Exponential Hazards):") +print(data_exponential["status"].value_counts()) + +print("\nEvent distribution (Weibull Hazards):") +print(data_weibull["status"].value_counts()) + +# 5. Plot cause-specific cumulative incidence functions +print("\nPlotting cumulative incidence functions...") +time_points = np.linspace(0, 5, 100) + +fig1, ax1 = plot_cause_specific_cumulative_incidence( + data_exponential, + time_points=time_points, + figsize=(10, 6) +) +plt.title("Cumulative Incidence Functions (Exponential Hazards)") +plt.savefig("cr_exponential_cif.png", dpi=300, bbox_inches="tight") + +fig2, ax2 = plot_cause_specific_cumulative_incidence( + data_weibull, + time_points=time_points, + figsize=(10, 6) +) +plt.title("Cumulative Incidence Functions (Weibull Hazards)") +plt.savefig("cr_weibull_cif.png", dpi=300, bbox_inches="tight") + +# 6. Demonstrate using the unified generate() interface +print("\nUsing the unified generate() interface:") +data_unified = generate( + model="competing_risks", + n=100, + n_risks=2, + baseline_hazards=[0.5, 0.3], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42 +) +print(data_unified.head()) + +# 7. Compare datasets +print("\nComparing datasets:") +comparison = compare_survival_datasets({ + "Exponential": data_exponential, + "Weibull": data_weibull +}) +print(comparison) + +# Show plots if running interactively +if __name__ == "__main__": + plt.show() diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 8939886..d300653 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -1,17 +1,66 @@ """Top-level package for ``gen_surv``. -This module exposes the :func:`generate` function and provides access to the -package version via ``__version__``. +This module exposes the main functions and provides access to the package version. """ from importlib.metadata import PackageNotFoundError, version +# Main interface from .interface import generate +# Individual generators +from .cphm import gen_cphm +from .cmm import gen_cmm +from .tdcm import gen_tdcm +from .thmm import gen_thmm +from .aft import gen_aft_log_normal, gen_aft_weibull + +# Helper functions +from .bivariate import sample_bivariate_distribution +from .censoring import runifcens, rexpocens + +# Visualization tools (requires matplotlib and lifelines) +try: + from .visualization import ( + plot_survival_curve, + plot_hazard_comparison, + plot_covariate_effect, + describe_survival + ) + _has_visualization = True +except ImportError: + _has_visualization = False + try: __version__ = version("gen_surv") except PackageNotFoundError: # pragma: no cover - fallback when package not installed __version__ = "0.0.0" -__all__ = ["generate", "__version__"] +__all__ = [ + # Main interface + "generate", + "__version__", + + # Individual generators + "gen_cphm", + "gen_cmm", + "gen_tdcm", + "gen_thmm", + "gen_aft_log_normal", + "gen_aft_weibull", + + # Helpers + "sample_bivariate_distribution", + "runifcens", + "rexpocens", +] +# Add visualization tools to __all__ if available +if _has_visualization: + __all__.extend([ + "plot_survival_curve", + "plot_hazard_comparison", + "plot_covariate_effect", + "describe_survival" + ]) + \ No newline at end of file diff --git a/gen_surv/aft.py b/gen_surv/aft.py index 5c85fb9..b3dedb9 100644 --- a/gen_surv/aft.py +++ b/gen_surv/aft.py @@ -1,21 +1,42 @@ +""" +Accelerated Failure Time (AFT) models including Weibull, Log-Normal, and Log-Logistic distributions. +""" + import numpy as np import pandas as pd +from typing import List, Optional, Literal -def gen_aft_log_normal(n, beta, sigma, model_cens, cens_par, seed=None): +def gen_aft_log_normal( + n: int, + beta: List[float], + sigma: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: Optional[int] = None +) -> pd.DataFrame: """ Simulate survival data under a Log-Normal Accelerated Failure Time (AFT) model. - Parameters: - - n (int): Number of individuals - - beta (list of float): Coefficients for covariates - - sigma (float): Standard deviation of the log-error term - - model_cens (str): 'uniform' or 'exponential' - - cens_par (float): Parameter for censoring distribution - - seed (int, optional): Random seed + Parameters + ---------- + n : int + Number of individuals + beta : list of float + Coefficients for covariates + sigma : float + Standard deviation of the log-error term + model_cens : {"uniform", "exponential"} + Censoring mechanism + cens_par : float + Parameter for censoring distribution + seed : int, optional + Random seed for reproducibility - Returns: - - pd.DataFrame: DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] + Returns + ------- + pd.DataFrame + DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] """ if seed is not None: np.random.seed(seed) @@ -46,3 +67,174 @@ def gen_aft_log_normal(n, beta, sigma, model_cens, cens_par, seed=None): data[f"X{j}"] = X[:, j] return data + + +def gen_aft_weibull( + n: int, + beta: List[float], + shape: float, + scale: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: Optional[int] = None +) -> pd.DataFrame: + """ + Simulate survival data under a Weibull Accelerated Failure Time (AFT) model. + + The Weibull AFT model has survival function: + S(t|X) = exp(-(t/scale)^shape * exp(-X*beta)) + + Parameters + ---------- + n : int + Number of individuals + beta : list of float + Coefficients for covariates + shape : float + Weibull shape parameter (k > 0) + scale : float + Weibull scale parameter (λ > 0) + model_cens : {"uniform", "exponential"} + Censoring mechanism + cens_par : float + Parameter for censoring distribution + seed : int, optional + Random seed for reproducibility + + Returns + ------- + pd.DataFrame + DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] + """ + if seed is not None: + np.random.seed(seed) + + if shape <= 0: + raise ValueError("shape parameter must be positive") + + if scale <= 0: + raise ValueError("scale parameter must be positive") + + p = len(beta) + X = np.random.normal(size=(n, p)) + + # Linear predictor + eta = X @ np.array(beta) + + # Generate Weibull survival times + U = np.random.uniform(size=n) + T = scale * (-np.log(U) * np.exp(-eta))**(1/shape) + + # Generate censoring times + if model_cens == "uniform": + C = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + C = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + # Observed time is the minimum of event time and censoring time + observed_time = np.minimum(T, C) + status = (T <= C).astype(int) + + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_time, + "status": status + }) + + for j in range(p): + data[f"X{j}"] = X[:, j] + + return data + + +def gen_aft_log_logistic( + n: int, + beta: List[float], + shape: float, + scale: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: Optional[int] = None +) -> pd.DataFrame: + """ + Simulate survival data under a Log-Logistic Accelerated Failure Time (AFT) model. + + The Log-Logistic AFT model has survival function: + S(t|X) = 1 / (1 + (t/scale)^shape * exp(X*beta)) + + Log-logistic distribution is useful when the hazard rate first increases and then decreases. + + Parameters + ---------- + n : int + Number of individuals + beta : list of float + Coefficients for covariates + shape : float + Log-logistic shape parameter (α > 0) + scale : float + Log-logistic scale parameter (β > 0) + model_cens : {"uniform", "exponential"} + Censoring mechanism + cens_par : float + Parameter for censoring distribution + seed : int, optional + Random seed for reproducibility + + Returns + ------- + pd.DataFrame + DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] + """ + if seed is not None: + np.random.seed(seed) + + if shape <= 0: + raise ValueError("shape parameter must be positive") + + if scale <= 0: + raise ValueError("scale parameter must be positive") + + p = len(beta) + X = np.random.normal(size=(n, p)) + + # Linear predictor + eta = X @ np.array(beta) + + # Generate Log-Logistic survival times + U = np.random.uniform(size=n) + + # Inverse CDF method: S(t) = 1/(1 + (t/scale)^shape) + # so t = scale * (1/S - 1)^(1/shape) + # For random U ~ Uniform(0,1), we can use U as 1-S + # t = scale * (1/(1-U) - 1)^(1/shape) * exp(-eta/shape) + # simplifies to: t = scale * (U/(1-U))^(1/shape) * exp(-eta/shape) + + # Avoid numerical issues near 1 + U = np.clip(U, 0.001, 0.999) + T = scale * (U / (1 - U))**(1/shape) * np.exp(-eta/shape) + + # Generate censoring times + if model_cens == "uniform": + C = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + C = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + # Observed time is the minimum of event time and censoring time + observed_time = np.minimum(T, C) + status = (T <= C).astype(int) + + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_time, + "status": status + }) + + for j in range(p): + data[f"X{j}"] = X[:, j] + + return data diff --git a/gen_surv/cli.py b/gen_surv/cli.py index 542ea51..d64b00a 100644 --- a/gen_surv/cli.py +++ b/gen_surv/cli.py @@ -1,36 +1,166 @@ -import csv -from typing import Optional +""" +Command-line interface for gen_surv. + +This module provides a command-line interface for generating survival data +using the gen_surv package. +""" + +from typing import Optional, List, Tuple import typer from gen_surv.interface import generate app = typer.Typer(help="Generate synthetic survival datasets.") + @app.command() def dataset( model: str = typer.Argument( - ..., help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln]" + ..., + help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull]" ), n: int = typer.Option(100, help="Number of samples"), + model_cens: str = typer.Option( + "uniform", help="Censoring model: 'uniform' or 'exponential'" + ), + cens_par: float = typer.Option(1.0, help="Censoring parameter"), + beta: List[float] = typer.Option( + [0.5], help="Regression coefficient(s). Provide multiple values for multi-parameter models." + ), + covar: Optional[float] = typer.Option( + 2.0, help="Covariate range (for CPHM, CMM, THMM)" + ), + sigma: Optional[float] = typer.Option( + 1.0, help="Standard deviation parameter (for log-normal AFT)" + ), + shape: Optional[float] = typer.Option( + 1.5, help="Shape parameter (for Weibull AFT)" + ), + scale: Optional[float] = typer.Option( + 2.0, help="Scale parameter (for Weibull AFT)" + ), + seed: Optional[int] = typer.Option( + None, help="Random seed for reproducibility" + ), output: Optional[str] = typer.Option( None, "-o", help="Output CSV file. Prints to stdout if omitted." ), ) -> None: """Generate survival data and optionally save to CSV. - Args: - model: Identifier of the generator to use. - n: Number of samples to create. - output: Optional path to save the CSV file. + Examples: + # Generate data from CPHM model + $ gen_surv dataset cphm --n 100 --beta 0.5 --covar 2.0 -o cphm_data.csv - Returns: - None + # Generate data from Weibull AFT model + $ gen_surv dataset aft_weibull --n 200 --beta 0.5 --beta -0.3 --shape 1.5 --scale 2.0 -o aft_data.csv """ - df = generate(model=model, n=n) + # Prepare arguments based on the selected model + kwargs = { + "model": model, + "n": n, + "model_cens": model_cens, + "cens_par": cens_par, + "seed": seed + } + + # Add model-specific parameters + if model in ["cphm", "cmm", "thmm"]: + # These models use a single beta and covar + kwargs["beta"] = beta[0] if len(beta) > 0 else 0.5 + kwargs["covar"] = covar + + elif model == "aft_ln": + # Log-normal AFT model uses beta list and sigma + kwargs["beta"] = beta + kwargs["sigma"] = sigma + + elif model == "aft_weibull": + # Weibull AFT model uses beta list, shape, and scale + kwargs["beta"] = beta + kwargs["shape"] = shape + kwargs["scale"] = scale + + # Generate the data + df = generate(**kwargs) + + # Output the data if output: df.to_csv(output, index=False) typer.echo(f"Saved dataset to {output}") else: typer.echo(df.to_csv(index=False)) + +@app.command() +def visualize( + input_file: str = typer.Argument( + ..., help="Input CSV file containing survival data" + ), + time_col: str = typer.Option( + "time", help="Column containing time/duration values" + ), + status_col: str = typer.Option( + "status", help="Column containing event indicator (1=event, 0=censored)" + ), + group_col: Optional[str] = typer.Option( + None, help="Column to use for stratification" + ), + output: str = typer.Option( + "survival_plot.png", help="Output image file" + ), +) -> None: + """Visualize survival data from a CSV file. + + Examples: + # Generate a Kaplan-Meier plot from a CSV file + $ gen_surv visualize data.csv --time-col time --status-col status -o km_plot.png + + # Generate a stratified plot using a grouping variable + $ gen_surv visualize data.csv --group-col X0 -o stratified_plot.png + """ + try: + import pandas as pd + from gen_surv.visualization import plot_survival_curve + import matplotlib.pyplot as plt + except ImportError: + typer.echo( + "Error: Visualization requires matplotlib and lifelines. " + "Install them with: pip install matplotlib lifelines" + ) + raise typer.Exit(1) + + # Load the data + try: + data = pd.read_csv(input_file) + except Exception as e: + typer.echo(f"Error loading CSV file: {str(e)}") + raise typer.Exit(1) + + # Check required columns + if time_col not in data.columns: + typer.echo(f"Error: Time column '{time_col}' not found in data") + raise typer.Exit(1) + + if status_col not in data.columns: + typer.echo(f"Error: Status column '{status_col}' not found in data") + raise typer.Exit(1) + + if group_col is not None and group_col not in data.columns: + typer.echo(f"Error: Group column '{group_col}' not found in data") + raise typer.Exit(1) + + # Create the plot + fig, ax = plot_survival_curve( + data=data, + time_col=time_col, + status_col=status_col, + group_col=group_col + ) + + # Save the plot + plt.savefig(output, dpi=300, bbox_inches="tight") + typer.echo(f"Plot saved to {output}") + + if __name__ == "__main__": app() diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py new file mode 100644 index 0000000..239d327 --- /dev/null +++ b/gen_surv/competing_risks.py @@ -0,0 +1,673 @@ +""" +Competing Risks models for survival data simulation. + +This module provides functions to generate survival data with +competing risks under different hazard specifications. +""" + +import numpy as np +import pandas as pd +from typing import Dict, List, Optional, Tuple, Union, Literal + + +def gen_competing_risks( + n: int, + n_risks: int = 2, + baseline_hazards: Optional[Union[List[float], np.ndarray]] = None, + betas: Optional[Union[List[List[float]], np.ndarray]] = None, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None, + max_time: Optional[float] = 10.0, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + seed: Optional[int] = None +) -> pd.DataFrame: + """ + Generate survival data with competing risks. + + Parameters + ---------- + n : int + Number of subjects. + n_risks : int, default=2 + Number of competing risks. + baseline_hazards : list of float or array, optional + Baseline hazard rates for each risk. If None, uses [0.5, 0.3, ...] + with decreasing values for subsequent risks. + betas : list of list of float or array, optional + Coefficients for covariates, one list per risk. + Shape should be (n_risks, n_covariates). + If None, generates random coefficients. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + max_time : float, optional, default=10.0 + Maximum simulation time. Set to None for no limit. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (0=censored, 1,2,...=competing events) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.competing_risks import gen_competing_risks + >>> + >>> # Simple example with 2 competing risks + >>> df = gen_competing_risks( + ... n=100, + ... n_risks=2, + ... baseline_hazards=[0.5, 0.3], + ... betas=[[0.8, -0.5], [0.2, 0.7]], + ... seed=42 + ... ) + >>> + >>> # Distribution of event types + >>> df["status"].value_counts() + """ + if seed is not None: + np.random.seed(seed) + + # Set default baseline hazards if not provided + if baseline_hazards is None: + baseline_hazards = np.array([0.5 / (i + 1) for i in range(n_risks)]) + else: + baseline_hazards = np.array(baseline_hazards) + if len(baseline_hazards) != n_risks: + raise ValueError(f"Expected {n_risks} baseline hazards, got {len(baseline_hazards)}") + + # Set default number of covariates and their parameters + n_covariates = 2 # Default number of covariates + + # Set default covariate parameters if not provided + if covariate_params is None: + if covariate_dist == "normal": + covariate_params = {"mean": 0.0, "std": 1.0} + elif covariate_dist == "uniform": + covariate_params = {"low": 0.0, "high": 1.0} + elif covariate_dist == "binary": + covariate_params = {"p": 0.5} + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Set default betas if not provided + if betas is None: + betas = np.random.normal(0, 0.5, size=(n_risks, n_covariates)) + else: + betas = np.array(betas) + if betas.shape[0] != n_risks: + raise ValueError(f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}") + n_covariates = betas.shape[1] + + # Generate covariates + if covariate_dist == "normal": + X = np.random.normal( + covariate_params.get("mean", 0.0), + covariate_params.get("std", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "uniform": + X = np.random.uniform( + covariate_params.get("low", 0.0), + covariate_params.get("high", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "binary": + X = np.random.binomial( + 1, + covariate_params.get("p", 0.5), + size=(n, n_covariates) + ) + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Calculate linear predictors for each risk + linear_predictors = np.zeros((n, n_risks)) + for j in range(n_risks): + linear_predictors[:, j] = X @ betas[j] + + # Calculate hazard rates + hazard_rates = np.zeros_like(linear_predictors) + for j in range(n_risks): + hazard_rates[:, j] = baseline_hazards[j] * np.exp(linear_predictors[:, j]) + + # Generate event times for each risk + event_times = np.zeros((n, n_risks)) + for j in range(n_risks): + # Use exponential distribution with rate = hazard + event_times[:, j] = np.random.exponential(1 / hazard_rates[:, j]) + + # Generate censoring times + if model_cens == "uniform": + cens_times = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + cens_times = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + # Find the minimum time for each subject (first event or censoring) + min_event_times = np.min(event_times, axis=1) + observed_times = np.minimum(min_event_times, cens_times) + + # Determine event type (0 = censored, 1...n_risks = event type) + status = np.zeros(n, dtype=int) + for i in range(n): + if min_event_times[i] <= cens_times[i]: + # Find which risk occurred first + risk_index = np.argmin(event_times[i]) + status[i] = risk_index + 1 # 1-based indexing for event types + + # Cap times at max_time if specified + if max_time is not None: + over_max = observed_times > max_time + observed_times[over_max] = max_time + status[over_max] = 0 # Censored if beyond max_time + + # Create DataFrame + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_times, + "status": status + }) + + # Add covariates + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def gen_competing_risks_weibull( + n: int, + n_risks: int = 2, + shape_params: Optional[Union[List[float], np.ndarray]] = None, + scale_params: Optional[Union[List[float], np.ndarray]] = None, + betas: Optional[Union[List[List[float]], np.ndarray]] = None, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None, + max_time: Optional[float] = 10.0, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + seed: Optional[int] = None +) -> pd.DataFrame: + """ + Generate survival data with competing risks using Weibull hazards. + + Parameters + ---------- + n : int + Number of subjects. + n_risks : int, default=2 + Number of competing risks. + shape_params : list of float or array, optional + Shape parameters for Weibull distribution, one per risk. + If None, uses [1.2, 0.8, ...] alternating values. + scale_params : list of float or array, optional + Scale parameters for Weibull distribution, one per risk. + If None, uses [2.0, 3.0, ...] increasing values. + betas : list of list of float or array, optional + Coefficients for covariates, one list per risk. + Shape should be (n_risks, n_covariates). + If None, generates random coefficients. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + max_time : float, optional, default=10.0 + Maximum simulation time. Set to None for no limit. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (0=censored, 1,2,...=competing events) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.competing_risks import gen_competing_risks_weibull + >>> + >>> # Example with 2 competing risks with different shapes + >>> df = gen_competing_risks_weibull( + ... n=100, + ... n_risks=2, + ... shape_params=[0.8, 1.5], # Decreasing vs increasing hazard + ... scale_params=[2.0, 3.0], + ... betas=[[0.8, -0.5], [0.2, 0.7]], + ... seed=42 + ... ) + """ + if seed is not None: + np.random.seed(seed) + + # Set default shape and scale parameters if not provided + if shape_params is None: + shape_params = np.array([1.2 if i % 2 == 0 else 0.8 for i in range(n_risks)]) + else: + shape_params = np.array(shape_params) + if len(shape_params) != n_risks: + raise ValueError(f"Expected {n_risks} shape parameters, got {len(shape_params)}") + + if scale_params is None: + scale_params = np.array([2.0 + i for i in range(n_risks)]) + else: + scale_params = np.array(scale_params) + if len(scale_params) != n_risks: + raise ValueError(f"Expected {n_risks} scale parameters, got {len(scale_params)}") + + # Set default number of covariates and their parameters + n_covariates = 2 # Default number of covariates + + # Set default covariate parameters if not provided + if covariate_params is None: + if covariate_dist == "normal": + covariate_params = {"mean": 0.0, "std": 1.0} + elif covariate_dist == "uniform": + covariate_params = {"low": 0.0, "high": 1.0} + elif covariate_dist == "binary": + covariate_params = {"p": 0.5} + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Set default betas if not provided + if betas is None: + betas = np.random.normal(0, 0.5, size=(n_risks, n_covariates)) + else: + betas = np.array(betas) + if betas.shape[0] != n_risks: + raise ValueError(f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}") + n_covariates = betas.shape[1] + + # Generate covariates + if covariate_dist == "normal": + X = np.random.normal( + covariate_params.get("mean", 0.0), + covariate_params.get("std", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "uniform": + X = np.random.uniform( + covariate_params.get("low", 0.0), + covariate_params.get("high", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "binary": + X = np.random.binomial( + 1, + covariate_params.get("p", 0.5), + size=(n, n_covariates) + ) + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Calculate linear predictors for each risk + linear_predictors = np.zeros((n, n_risks)) + for j in range(n_risks): + linear_predictors[:, j] = X @ betas[j] + + # Generate event times for each risk using Weibull distribution + event_times = np.zeros((n, n_risks)) + for j in range(n_risks): + # Adjust the scale parameter using the linear predictor + adjusted_scale = scale_params[j] * np.exp(-linear_predictors[:, j] / shape_params[j]) + + # Generate random uniform between 0 and 1 + u = np.random.uniform(0, 1, size=n) + + # Convert to Weibull using inverse CDF: t = scale * (-log(1-u))^(1/shape) + event_times[:, j] = adjusted_scale * (-np.log(1 - u)) ** (1 / shape_params[j]) + + # Generate censoring times + if model_cens == "uniform": + cens_times = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + cens_times = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + # Find the minimum time for each subject (first event or censoring) + min_event_times = np.min(event_times, axis=1) + observed_times = np.minimum(min_event_times, cens_times) + + # Determine event type (0 = censored, 1...n_risks = event type) + status = np.zeros(n, dtype=int) + for i in range(n): + if min_event_times[i] <= cens_times[i]: + # Find which risk occurred first + risk_index = np.argmin(event_times[i]) + status[i] = risk_index + 1 # 1-based indexing for event types + + # Cap times at max_time if specified + if max_time is not None: + over_max = observed_times > max_time + observed_times[over_max] = max_time + status[over_max] = 0 # Censored if beyond max_time + + # Create DataFrame + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_times, + "status": status + }) + + # Add covariates + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def cause_specific_cumulative_incidence( + data: pd.DataFrame, + time_points: Union[List[float], np.ndarray], + time_col: str = "time", + status_col: str = "status", + cause: int = 1 +) -> pd.DataFrame: + """ + Calculate the cause-specific cumulative incidence function at specified time points. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with competing risks data. + time_points : list of float or array + Time points at which to calculate the cumulative incidence. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (0=censored, 1,2,...=competing events). + cause : int, default=1 + The cause/event type for which to calculate the incidence. + + Returns + ------- + pd.DataFrame + DataFrame with time points and corresponding cumulative incidence values. + + Notes + ----- + The cumulative incidence function for cause j is defined as: + F_j(t) = P(T <= t, cause = j) + + This is the probability of experiencing the event of type j before time t. + """ + # Validate the cause value + unique_causes = set(data[status_col].unique()) - {0} # Exclude censoring + if cause not in unique_causes: + raise ValueError(f"Cause {cause} not found in the data. Available causes: {unique_causes}") + + # Sort data by time + sorted_data = data.sort_values(by=time_col).copy() + + # Initialize arrays for calculations + times = sorted_data[time_col].values + status = sorted_data[status_col].values + n = len(times) + + # Calculate the survival function (probability of no event of any type) + survival = np.ones(n) + cumulative_incidence = np.zeros(n) + + for i in range(n): + if i > 0: + survival[i] = survival[i-1] + cumulative_incidence[i] = cumulative_incidence[i-1] + + # Count subjects at risk at this time + at_risk = n - i + + if status[i] > 0: # Any event + # Update overall survival + survival[i] *= (1 - 1/at_risk) + + # Update cause-specific cumulative incidence + if status[i] == cause: + prev_survival = survival[i-1] if i > 0 else 1.0 + cumulative_incidence[i] += prev_survival * (1/at_risk) + + # Interpolate values at the requested time points + result = [] + for t in time_points: + if t <= 0: + result.append({"time": t, "incidence": 0.0}) + elif t >= max(times): + result.append({"time": t, "incidence": cumulative_incidence[-1]}) + else: + # Find the index where time >= t + idx = np.searchsorted(times, t) + result.append({"time": t, "incidence": cumulative_incidence[idx-1]}) + + return pd.DataFrame(result) + + +def competing_risks_summary( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + covariate_cols: Optional[List[str]] = None +) -> Dict[str, Any]: + """ + Provide a summary of a competing risks dataset. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with competing risks data. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (0=censored, 1,2,...=competing events). + covariate_cols : list of str, optional + List of covariate columns to include in the summary. + If None, all columns except time_col and status_col are considered. + + Returns + ------- + Dict[str, Any] + Dictionary with summary statistics. + + Examples + -------- + >>> from gen_surv.competing_risks import gen_competing_risks, competing_risks_summary + >>> + >>> # Generate data + >>> df = gen_competing_risks(n=100, n_risks=3, seed=42) + >>> + >>> # Get summary + >>> summary = competing_risks_summary(df) + >>> print(f"Number of events by cause: {summary['events_by_cause']}") + >>> print(f"Median time to first event: {summary['median_time']}") + """ + # Determine covariate columns if not provided + if covariate_cols is None: + covariate_cols = [col for col in data.columns + if col not in [time_col, status_col, "id"]] + + # Basic counts + n_subjects = len(data) + n_events = (data[status_col] > 0).sum() + n_censored = n_subjects - n_events + censoring_rate = n_censored / n_subjects + + # Events by cause + causes = sorted(data[data[status_col] > 0][status_col].unique()) + events_by_cause = {} + for cause in causes: + n_cause = (data[status_col] == cause).sum() + events_by_cause[int(cause)] = { + "count": int(n_cause), + "proportion": float(n_cause / n_subjects), + "proportion_of_events": float(n_cause / n_events) if n_events > 0 else 0 + } + + # Time statistics + time_stats = { + "min": float(data[time_col].min()), + "max": float(data[time_col].max()), + "median": float(data[time_col].median()), + "mean": float(data[time_col].mean()) + } + + # Median time to each type of event + median_time_by_cause = {} + for cause in causes: + cause_times = data[data[status_col] == cause][time_col] + if not cause_times.empty: + median_time_by_cause[int(cause)] = float(cause_times.median()) + + # Covariate statistics + covariate_stats = {} + for col in covariate_cols: + col_data = data[col] + + # Check if numeric + if pd.api.types.is_numeric_dtype(col_data): + covariate_stats[col] = { + "mean": float(col_data.mean()), + "median": float(col_data.median()), + "std": float(col_data.std()), + "min": float(col_data.min()), + "max": float(col_data.max()) + } + else: + # Categorical statistics + value_counts = col_data.value_counts(normalize=True).to_dict() + covariate_stats[col] = { + "categories": len(value_counts), + "distribution": {str(k): float(v) for k, v in value_counts.items()} + } + + # Compile final summary + summary = { + "n_subjects": n_subjects, + "n_events": n_events, + "n_censored": n_censored, + "censoring_rate": censoring_rate, + "n_causes": len(causes), + "causes": list(map(int, causes)), + "events_by_cause": events_by_cause, + "time_stats": time_stats, + "median_time_by_cause": median_time_by_cause, + "covariate_stats": covariate_stats + } + + return summary + + +def plot_cause_specific_hazards( + data: pd.DataFrame, + time_points: Optional[np.ndarray] = None, + time_col: str = "time", + status_col: str = "status", + bandwidth: float = 0.5, + figsize: Tuple[float, float] = (10, 6) +) -> Tuple[plt.Figure, plt.Axes]: + """ + Plot cause-specific hazard functions. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with competing risks data. + time_points : array, optional + Time points at which to estimate hazards. + If None, uses 100 equally spaced points from 0 to max time. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (0=censored, 1,2,...=competing events). + bandwidth : float, default=0.5 + Bandwidth for kernel density estimation. + figsize : tuple, default=(10, 6) + Figure size (width, height) in inches. + + Returns + ------- + tuple + Figure and axes objects. + + Notes + ----- + This function requires matplotlib and scipy. + """ + try: + import matplotlib.pyplot as plt + from scipy.stats import gaussian_kde + except ImportError: + raise ImportError( + "This function requires matplotlib and scipy. " + "Install them with: pip install matplotlib scipy" + ) + + # Determine time points if not provided + if time_points is None: + max_time = data[time_col].max() + time_points = np.linspace(0, max_time, 100) + + # Get unique causes (excluding censoring) + causes = sorted([c for c in data[status_col].unique() if c > 0]) + + # Create plot + fig, ax = plt.subplots(figsize=figsize) + + # Plot hazard for each cause + for cause in causes: + # Filter data for this cause + cause_data = data[data[status_col] == cause] + + if len(cause_data) < 5: # Skip if too few events + continue + + # Estimate hazard using kernel density + kde = gaussian_kde(cause_data[time_col], bw_method=bandwidth) + + # Calculate hazard rate + at_risk = np.array([ + len(data[data[time_col] >= t]) for t in time_points + ]) + + # Avoid division by zero + at_risk = np.maximum(at_risk, 1) + + # Hazard = density / survival + hazard = kde(time_points) * len(data) / at_risk + + # Plot + ax.plot(time_points, hazard, label=f"Cause {cause}") + + # Format plot + ax.set_xlabel("Time") + ax.set_ylabel("Hazard Rate") + ax.set_title("Cause-Specific Hazard Functions") + ax.legend() + ax.grid(alpha=0.3) + + return fig, ax diff --git a/gen_surv/cphm.py b/gen_surv/cphm.py index 97f52a6..9605474 100644 --- a/gen_surv/cphm.py +++ b/gen_surv/cphm.py @@ -1,22 +1,51 @@ +""" +Cox Proportional Hazards Model (CPHM) data generation. + +This module provides functions to generate survival data following the +Cox Proportional Hazards Model with various censoring mechanisms. +""" + import numpy as np import pandas as pd +from typing import Callable, Literal, Optional + from gen_surv.validate import validate_gen_cphm_inputs from gen_surv.censoring import runifcens, rexpocens -def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): +def generate_cphm_data( + n: int, + rfunc: Callable[[int, float], np.ndarray], + cens_par: float, + beta: float, + covariate_range: float, + seed: Optional[int] = None +) -> np.ndarray: """ Generate data from a Cox Proportional Hazards Model (CPHM). - Parameters: - - n (int): Number of samples to generate. - - rfunc (callable): Function to generate censoring times, must accept (size, cens_par). - - cens_par (float): Parameter passed to the censoring function. - - beta (float): Coefficient for the covariate. - - covar (float): Range for the covariate (uniformly sampled from [0, covar]). + Parameters + ---------- + n : int + Number of samples to generate. + rfunc : callable + Function to generate censoring times, must accept (size, cens_par). + cens_par : float + Parameter passed to the censoring function. + beta : float + Coefficient for the covariate. + covariate_range : float + Range for the covariate (uniformly sampled from [0, covar]). + seed : int, optional + Random seed for reproducibility. - Returns: - - np.ndarray: Array with shape (n, 3): [time, status, covariate] + Returns + ------- + np.ndarray + Array with shape (n, 3): [time, status, covariate] """ + if seed is not None: + np.random.seed(seed) + data = np.zeros((n, 3)) for k in range(n): @@ -32,19 +61,49 @@ def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): return data -def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: +def gen_cphm( + n: int, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + beta: float, + covar: float, + seed: Optional[int] = None +) -> pd.DataFrame: """ - Convenience wrapper to generate CPHM survival data. + Generate survival data following a Cox Proportional Hazards Model. + + Parameters + ---------- + n : int + Number of observations. + model_cens : {"uniform", "exponential"} + Type of censoring mechanism. + cens_par : float + Parameter for the censoring model. + beta : float + Coefficient for the covariate. + covar : float + Covariate range (uniform between 0 and covar). + seed : int, optional + Random seed for reproducibility. - Parameters: - - n (int): Number of observations. - - model_cens (str): "uniform" or "exponential". - - cens_par (float): Parameter for the censoring model. - - beta (float): Coefficient for the covariate. - - covar (float): Covariate range (uniform between 0 and covar). + Returns + ------- + pd.DataFrame + DataFrame with columns ["time", "status", "covariate"] + - time: observed event or censoring time + - status: event indicator (1=event, 0=censored) + - covariate: predictor variable - Returns: - - pd.DataFrame: Columns are ["time", "status", "covariate"] + Examples + -------- + >>> from gen_surv.cphm import gen_cphm + >>> df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df.head() + time status covariate + 0 0.23 1.0 1.42 + 1 0.78 0.0 0.89 + ... """ validate_gen_cphm_inputs(n, model_cens, cens_par, covar) @@ -53,6 +112,6 @@ def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float "exponential": rexpocens }[model_cens] - data = generate_cphm_data(n, rfunc, cens_par, beta, covar) + data = generate_cphm_data(n, rfunc, cens_par, beta, covar, seed) return pd.DataFrame(data, columns=["time", "status", "covariate"]) diff --git a/gen_surv/interface.py b/gen_surv/interface.py index 549ad45..f3935d0 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -6,22 +6,30 @@ >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) """ -from typing import Any +from typing import Any, Dict, Literal, Optional, Union, List, Tuple, cast import pandas as pd from gen_surv.cphm import gen_cphm from gen_surv.cmm import gen_cmm from gen_surv.tdcm import gen_tdcm from gen_surv.thmm import gen_thmm -from gen_surv.aft import gen_aft_log_normal +from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic +from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull +# Type definitions for model names +ModelType = Literal["cphm", "cmm", "tdcm", "thmm", "aft_ln", "aft_weibull", "aft_log_logistic", "competing_risks", "competing_risks_weibull"] +# Map model names to their generator functions _model_map = { "cphm": gen_cphm, "cmm": gen_cmm, "tdcm": gen_tdcm, "thmm": gen_thmm, "aft_ln": gen_aft_log_normal, + "aft_weibull": gen_aft_weibull, + "aft_log_logistic": gen_aft_log_logistic, + "competing_risks": gen_competing_risks, + "competing_risks_weibull": gen_competing_risks_weibull, } @@ -30,14 +38,30 @@ def generate(model: str, **kwargs: Any) -> pd.DataFrame: Args: model: Name of the generator to run. Must be one of ``cphm``, ``cmm``, - ``tdcm``, ``thmm`` or ``aft_ln``. - **kwargs: Arguments forwarded to the chosen generator. + ``tdcm``, ``thmm``, ``aft_ln``, ``aft_weibull``, ``aft_log_logistic``, + ``competing_risks``, or ``competing_risks_weibull``. + **kwargs: Arguments forwarded to the chosen generator. These vary by model: + - cphm: n, model_cens, cens_par, beta, covar + - cmm: n, model_cens, cens_par, beta, covar, rate + - tdcm: n, dist, corr, dist_par, model_cens, cens_par, beta, lam + - thmm: n, model_cens, cens_par, beta, covar, rate + - aft_ln: n, beta, sigma, model_cens, cens_par, seed + - aft_weibull: n, beta, shape, scale, model_cens, cens_par, seed + - aft_log_logistic: n, beta, shape, scale, model_cens, cens_par, seed + - competing_risks: n, n_risks, baseline_hazards, betas, covariate_dist, etc. + - competing_risks_weibull: n, n_risks, shape_params, scale_params, betas, etc. Returns: - pd.DataFrame: Simulated survival data. + pd.DataFrame: Simulated survival data with columns specific to the chosen model. + All models include time/duration and status columns. + + Raises: + ValueError: If an unknown model name is provided. """ - model = model.lower() - if model not in _model_map: - raise ValueError(f"Unknown model '{model}'. Choose from {list(_model_map.keys())}.") + model_lower = model.lower() + if model_lower not in _model_map: + valid_models = list(_model_map.keys()) + raise ValueError(f"Unknown model '{model}'. Choose from {valid_models}.") - return _model_map[model](**kwargs) + # Call the appropriate generator function with the provided kwargs + return _model_map[model_lower](**kwargs) diff --git a/gen_surv/summary.py b/gen_surv/summary.py new file mode 100644 index 0000000..8c2a92e --- /dev/null +++ b/gen_surv/summary.py @@ -0,0 +1,496 @@ +""" +Utilities for summarizing and validating survival datasets. + +This module provides functions to summarize survival data, +check data quality, and identify potential issues. +""" + +from typing import Dict, List, Optional, Tuple, Any +import pandas as pd + + +def summarize_survival_dataset( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + id_col: Optional[str] = None, + covariate_cols: Optional[List[str]] = None, + verbose: bool = True +) -> Dict[str, Any]: + """ + Generate a comprehensive summary of a survival dataset. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing survival data. + time_col : str, default="time" + Name of the column containing time-to-event values. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + id_col : str, optional + Name of the column containing subject identifiers. + covariate_cols : list of str, optional + List of column names to include as covariates in the summary. + If None, all columns except time_col, status_col, and id_col are considered. + verbose : bool, default=True + Whether to print the summary to console. + + Returns + ------- + Dict[str, Any] + Dictionary containing all summary statistics. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.summary import summarize_survival_dataset + >>> + >>> # Generate example data + >>> df = generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covar=2.0) + >>> + >>> # Summarize the dataset + >>> summary = summarize_survival_dataset(df) + """ + # Validate input columns + for col in [time_col, status_col]: + if col not in data.columns: + raise ValueError(f"Column '{col}' not found in data") + + if id_col is not None and id_col not in data.columns: + raise ValueError(f"ID column '{id_col}' not found in data") + + # Determine covariate columns + if covariate_cols is None: + exclude_cols = {time_col, status_col} + if id_col is not None: + exclude_cols.add(id_col) + covariate_cols = [col for col in data.columns if col not in exclude_cols] + else: + missing_cols = [col for col in covariate_cols if col not in data.columns] + if missing_cols: + raise ValueError(f"Covariate columns not found in data: {missing_cols}") + + # Basic dataset information + n_subjects = len(data) + if id_col is not None: + n_unique_ids = data[id_col].nunique() + else: + n_unique_ids = n_subjects + + # Event information + n_events = data[status_col].sum() + n_censored = n_subjects - n_events + event_rate = n_events / n_subjects + + # Time statistics + time_min = data[time_col].min() + time_max = data[time_col].max() + time_mean = data[time_col].mean() + time_median = data[time_col].median() + + # Data quality checks + n_missing_time = data[time_col].isna().sum() + n_missing_status = data[status_col].isna().sum() + n_negative_time = (data[time_col] < 0).sum() + n_invalid_status = data[~data[status_col].isin([0, 1])].shape[0] + + # Covariate summaries + covariate_stats = {} + for col in covariate_cols: + col_data = data[col] + is_numeric = pd.api.types.is_numeric_dtype(col_data) + + if is_numeric: + covariate_stats[col] = { + "type": "numeric", + "min": col_data.min(), + "max": col_data.max(), + "mean": col_data.mean(), + "median": col_data.median(), + "std": col_data.std(), + "missing": col_data.isna().sum(), + "unique_values": col_data.nunique() + } + else: + # Categorical/string + covariate_stats[col] = { + "type": "categorical", + "n_categories": col_data.nunique(), + "top_categories": col_data.value_counts().head(5).to_dict(), + "missing": col_data.isna().sum() + } + + # Compile the summary + summary = { + "dataset_info": { + "n_subjects": n_subjects, + "n_unique_ids": n_unique_ids, + "n_covariates": len(covariate_cols) + }, + "event_info": { + "n_events": n_events, + "n_censored": n_censored, + "event_rate": event_rate + }, + "time_info": { + "min": time_min, + "max": time_max, + "mean": time_mean, + "median": time_median + }, + "data_quality": { + "missing_time": n_missing_time, + "missing_status": n_missing_status, + "negative_time": n_negative_time, + "invalid_status": n_invalid_status, + "overall_quality": "good" if (n_missing_time + n_missing_status + n_negative_time + n_invalid_status) == 0 else "issues_detected" + }, + "covariates": covariate_stats + } + + # Print summary if requested + if verbose: + _print_summary(summary, time_col, status_col, id_col, covariate_cols) + + return summary + + +def check_survival_data_quality( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + id_col: Optional[str] = None, + min_time: float = 0.0, + max_time: Optional[float] = None, + status_values: Optional[List[int]] = None, + fix_issues: bool = False +) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """ + Check for common issues in survival data and optionally fix them. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing survival data. + time_col : str, default="time" + Name of the column containing time-to-event values. + status_col : str, default="status" + Name of the column containing event indicators. + id_col : str, optional + Name of the column containing subject identifiers. + min_time : float, default=0.0 + Minimum acceptable value for time column. + max_time : float, optional + Maximum acceptable value for time column. + status_values : list of int, optional + List of valid status values. Default is [0, 1]. + fix_issues : bool, default=False + Whether to attempt fixing issues (returns a modified DataFrame). + + Returns + ------- + Tuple[pd.DataFrame, Dict[str, Any]] + Tuple containing (possibly fixed) DataFrame and issues report. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.summary import check_survival_data_quality + >>> + >>> # Generate example data with some issues + >>> df = generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covar=2.0) + >>> # Introduce some issues + >>> df.loc[0, "time"] = np.nan + >>> df.loc[1, "status"] = 2 # Invalid status + >>> + >>> # Check and fix issues + >>> fixed_df, issues = check_survival_data_quality(df, fix_issues=True) + >>> print(issues) + """ + if status_values is None: + status_values = [0, 1] + + # Make a copy to avoid modifying the original + if fix_issues: + data = data.copy() + + # Initialize issues report + issues = { + "missing_data": { + "time": 0, + "status": 0, + "id": 0 if id_col else None + }, + "invalid_values": { + "negative_time": 0, + "excessive_time": 0, + "invalid_status": 0 + }, + "duplicates": { + "duplicate_rows": 0, + "duplicate_ids": 0 if id_col else None + }, + "modifications": { + "rows_dropped": 0, + "values_fixed": 0 + } + } + + # Check for missing values + issues["missing_data"]["time"] = data[time_col].isna().sum() + issues["missing_data"]["status"] = data[status_col].isna().sum() + if id_col: + issues["missing_data"]["id"] = data[id_col].isna().sum() + + # Check for invalid values + issues["invalid_values"]["negative_time"] = (data[time_col] < min_time).sum() + if max_time is not None: + issues["invalid_values"]["excessive_time"] = (data[time_col] > max_time).sum() + issues["invalid_values"]["invalid_status"] = data[~data[status_col].isin(status_values)].shape[0] + + # Check for duplicates + issues["duplicates"]["duplicate_rows"] = data.duplicated().sum() + if id_col: + issues["duplicates"]["duplicate_ids"] = data[id_col].duplicated().sum() + + # Fix issues if requested + if fix_issues: + original_rows = len(data) + modified_values = 0 + + # Handle missing values + data = data.dropna(subset=[time_col, status_col]) + + # Handle invalid values + if min_time > 0: + # Set negative or too small times to min_time + mask = data[time_col] < min_time + if mask.any(): + data.loc[mask, time_col] = min_time + modified_values += mask.sum() + + if max_time is not None: + # Cap excessively large times + mask = data[time_col] > max_time + if mask.any(): + data.loc[mask, time_col] = max_time + modified_values += mask.sum() + + # Fix invalid status values + mask = ~data[status_col].isin(status_values) + if mask.any(): + # Default to censored (0) for invalid status + data.loc[mask, status_col] = 0 + modified_values += mask.sum() + + # Remove duplicates + data = data.drop_duplicates() + + # Update modification counts + issues["modifications"]["rows_dropped"] = original_rows - len(data) + issues["modifications"]["values_fixed"] = modified_values + + return data, issues + + +def _print_summary( + summary: Dict[str, Any], + time_col: str, + status_col: str, + id_col: Optional[str], + covariate_cols: List[str] +) -> None: + """ + Print a formatted summary of survival data. + + Parameters + ---------- + summary : Dict[str, Any] + Summary dictionary from summarize_survival_dataset. + time_col : str + Name of the time column. + status_col : str + Name of the status column. + id_col : str, optional + Name of the ID column. + covariate_cols : List[str] + List of covariate column names. + """ + print("=" * 60) + print(f"SURVIVAL DATASET SUMMARY") + print("=" * 60) + + # Dataset info + print("\nDATASET INFORMATION:") + print(f" Subjects: {summary['dataset_info']['n_subjects']}") + if id_col: + print(f" Unique IDs: {summary['dataset_info']['n_unique_ids']}") + print(f" Covariates: {summary['dataset_info']['n_covariates']}") + + # Event info + print("\nEVENT INFORMATION:") + print(f" Events: {summary['event_info']['n_events']} " + + f"({summary['event_info']['event_rate']:.1%})") + print(f" Censored: {summary['event_info']['n_censored']} " + + f"({1 - summary['event_info']['event_rate']:.1%})") + + # Time info + print(f"\nTIME VARIABLE ({time_col}):") + print(f" Range: {summary['time_info']['min']:.2f} to {summary['time_info']['max']:.2f}") + print(f" Mean: {summary['time_info']['mean']:.2f}") + print(f" Median: {summary['time_info']['median']:.2f}") + + # Data quality + print("\nDATA QUALITY:") + quality_issues = ( + summary['data_quality']['missing_time'] + + summary['data_quality']['missing_status'] + + summary['data_quality']['negative_time'] + + summary['data_quality']['invalid_status'] + ) + + if quality_issues == 0: + print(" ✓ No issues detected") + else: + print(" ✗ Issues detected:") + if summary['data_quality']['missing_time'] > 0: + print(f" - Missing time values: {summary['data_quality']['missing_time']}") + if summary['data_quality']['missing_status'] > 0: + print(f" - Missing status values: {summary['data_quality']['missing_status']}") + if summary['data_quality']['negative_time'] > 0: + print(f" - Negative time values: {summary['data_quality']['negative_time']}") + if summary['data_quality']['invalid_status'] > 0: + print(f" - Invalid status values: {summary['data_quality']['invalid_status']}") + + # Covariates + print("\nCOVARIATES:") + if not covariate_cols: + print(" No covariates found") + else: + for col, stats in summary['covariates'].items(): + print(f" {col}:") + if stats['type'] == 'numeric': + print(f" Type: Numeric") + print(f" Range: {stats['min']:.2f} to {stats['max']:.2f}") + print(f" Mean: {stats['mean']:.2f}") + print(f" Missing: {stats['missing']}") + else: + print(f" Type: Categorical") + print(f" Categories: {stats['n_categories']}") + print(f" Missing: {stats['missing']}") + + print("\n" + "=" * 60) + + +def compare_survival_datasets( + datasets: Dict[str, pd.DataFrame], + time_col: str = "time", + status_col: str = "status", + covariate_cols: Optional[List[str]] = None +) -> pd.DataFrame: + """ + Compare multiple survival datasets and summarize their differences. + + Parameters + ---------- + datasets : Dict[str, pd.DataFrame] + Dictionary mapping dataset names to DataFrames. + time_col : str, default="time" + Name of the time column in each dataset. + status_col : str, default="status" + Name of the status column in each dataset. + covariate_cols : List[str], optional + List of covariate columns to compare. If None, compares all common columns. + + Returns + ------- + pd.DataFrame + Comparison table with datasets as columns and metrics as rows. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.summary import compare_survival_datasets + >>> + >>> # Generate datasets with different parameters + >>> datasets = { + ... "CPHM": generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covar=2.0), + ... "Weibull AFT": generate(model="aft_weibull", n=100, beta=[0.5], + ... shape=1.5, scale=1.0, model_cens="uniform", cens_par=1.0) + ... } + >>> + >>> # Compare datasets + >>> comparison = compare_survival_datasets(datasets) + >>> print(comparison) + """ + if not datasets: + raise ValueError("No datasets provided for comparison") + + # Find common columns if covariate_cols not specified + if covariate_cols is None: + all_columns = [set(df.columns) for df in datasets.values()] + common_columns = set.intersection(*all_columns) + common_columns -= {time_col, status_col} # Remove time and status + covariate_cols = sorted(list(common_columns)) + + # Calculate summaries for each dataset + summaries = {} + for name, data in datasets.items(): + summaries[name] = summarize_survival_dataset( + data, time_col, status_col, + covariate_cols=covariate_cols, verbose=False + ) + + # Construct the comparison DataFrame + comparison_data = {} + + # Dataset info + comparison_data["n_subjects"] = { + name: summary["dataset_info"]["n_subjects"] + for name, summary in summaries.items() + } + comparison_data["n_events"] = { + name: summary["event_info"]["n_events"] + for name, summary in summaries.items() + } + comparison_data["event_rate"] = { + name: summary["event_info"]["event_rate"] + for name, summary in summaries.items() + } + + # Time info + comparison_data["time_min"] = { + name: summary["time_info"]["min"] + for name, summary in summaries.items() + } + comparison_data["time_max"] = { + name: summary["time_info"]["max"] + for name, summary in summaries.items() + } + comparison_data["time_mean"] = { + name: summary["time_info"]["mean"] + for name, summary in summaries.items() + } + comparison_data["time_median"] = { + name: summary["time_info"]["median"] + for name, summary in summaries.items() + } + + # Covariate info (means for numeric) + for col in covariate_cols: + for name, summary in summaries.items(): + if col in summary["covariates"]: + col_stats = summary["covariates"][col] + if col_stats["type"] == "numeric": + if f"{col}_mean" not in comparison_data: + comparison_data[f"{col}_mean"] = {} + comparison_data[f"{col}_mean"][name] = col_stats["mean"] + + # Create the DataFrame + comparison_df = pd.DataFrame(comparison_data).T + + return comparison_df diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py new file mode 100644 index 0000000..8700b30 --- /dev/null +++ b/gen_surv/visualization.py @@ -0,0 +1,373 @@ +""" +Visualization utilities for survival data. + +This module provides functions to visualize survival data generated by gen_surv, +including Kaplan-Meier survival curves and other commonly used plots in survival analysis. +""" + +from typing import Dict, List, Optional, Tuple, Union, Any +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from matplotlib.axes import Axes + + +def plot_survival_curve( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + group_col: Optional[str] = None, + confidence_intervals: bool = True, + title: str = "Kaplan-Meier Survival Curve", + figsize: Tuple[float, float] = (10, 6), + ci_alpha: float = 0.2, +) -> Tuple[Figure, Axes]: + """ + Plot Kaplan-Meier survival curves from simulated data. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing the survival data. + time_col : str, default="time" + Name of the column containing event/censoring times. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + group_col : str, optional + Name of the column to use for stratification (creates separate curves). + confidence_intervals : bool, default=True + Whether to display confidence intervals around the survival curves. + title : str, default="Kaplan-Meier Survival Curve" + Plot title. + figsize : tuple, default=(10, 6) + Figure size (width, height) in inches. + ci_alpha : float, default=0.2 + Transparency level for confidence interval bands. + + Returns + ------- + fig : Figure + Matplotlib figure object. + ax : Axes + Matplotlib axes object. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.visualization import plot_survival_curve + >>> + >>> # Generate data + >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> + >>> # Create a categorical group based on covariate + >>> df["group"] = pd.cut(df["covariate"], bins=2, labels=["Low", "High"]) + >>> + >>> # Plot survival curves by group + >>> fig, ax = plot_survival_curve(df, group_col="group") + >>> plt.show() + """ + # Import lifelines here to avoid making it a hard dependency + try: + from lifelines import KaplanMeierFitter + from lifelines.plotting import add_at_risk_counts + except ImportError as exc: + raise ImportError( + "This function requires the lifelines package. " + "Install it with: pip install lifelines" + ) from exc + + fig, ax = plt.subplots(figsize=figsize) + + # Create separate KM curves for each group (if specified) + if group_col is not None: + groups = data[group_col].unique() + colors = plt.cm.tab10.colors[: len(groups)] + + for i, group in enumerate(groups): + mask = data[group_col] == group + group_data = data[mask] + + kmf = KaplanMeierFitter() + kmf.fit( + group_data[time_col], + group_data[status_col], + label=f"{group_col}={group}", + ) + + kmf.plot_survival_function( + ax=ax, ci_show=confidence_intervals, color=colors[i], ci_alpha=ci_alpha + ) + + # Add at-risk counts below the plot + add_at_risk_counts(kmf, ax=ax) + else: + # Single KM curve for all data + kmf = KaplanMeierFitter() + kmf.fit(data[time_col], data[status_col]) + + kmf.plot_survival_function( + ax=ax, ci_show=confidence_intervals, ci_alpha=ci_alpha + ) + + # Add at-risk counts below the plot + add_at_risk_counts(kmf, ax=ax) + + # Customize plot appearance + ax.set_title(title) + ax.set_xlabel("Time") + ax.set_ylabel("Survival Probability") + ax.grid(alpha=0.3) + ax.set_ylim(0, 1.05) + + plt.tight_layout() + return fig, ax + + +def plot_hazard_comparison( + models: Dict[str, pd.DataFrame], + time_col: str = "time", + status_col: str = "status", + title: str = "Hazard Function Comparison", + figsize: Tuple[float, float] = (10, 6), + bandwidth: float = 0.5, +) -> Tuple[Figure, Axes]: + """ + Compare hazard functions from multiple generated datasets. + + Parameters + ---------- + models : dict + Dictionary mapping model names to their respective DataFrames. + time_col : str, default="time" + Name of the column containing event/censoring times. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + title : str, default="Hazard Function Comparison" + Plot title. + figsize : tuple, default=(10, 6) + Figure size (width, height) in inches. + bandwidth : float, default=0.5 + Bandwidth parameter for kernel density estimation of the hazard function. + + Returns + ------- + fig : Figure + Matplotlib figure object. + ax : Axes + Matplotlib axes object. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.visualization import plot_hazard_comparison + >>> + >>> # Generate data from multiple models + >>> models = { + >>> "CPHM": generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0), + >>> "AFT Weibull": generate(model="aft_weibull", n=100, beta=[0.5], shape=1.5, scale=2.0, + >>> model_cens="uniform", cens_par=1.0) + >>> } + >>> + >>> # Compare hazard functions + >>> fig, ax = plot_hazard_comparison(models) + >>> plt.show() + """ + # Import lifelines here to avoid making it a hard dependency + try: + from lifelines import NelsonAalenFitter + except ImportError as exc: + raise ImportError( + "This function requires the lifelines package. " + "Install it with: pip install lifelines" + ) from exc + + fig, ax = plt.subplots(figsize=figsize) + + for model_name, df in models.items(): + naf = NelsonAalenFitter() + naf.fit(df[time_col], df[status_col]) + + # Get smoothed hazard estimate + hazard = naf.smoothed_hazard_(bandwidth=bandwidth) + + # Plot hazard function + ax.plot(hazard.index, hazard.values, label=model_name, alpha=0.8) + + # Customize plot appearance + ax.set_title(title) + ax.set_xlabel("Time") + ax.set_ylabel("Hazard Rate") + ax.grid(alpha=0.3) + ax.legend() + + plt.tight_layout() + return fig, ax + + +def plot_covariate_effect( + data: pd.DataFrame, + covariate_col: str, + time_col: str = "time", + status_col: str = "status", + n_groups: int = 3, + title: str = "Effect of Covariate on Survival", + figsize: Tuple[float, float] = (10, 6), + ci_alpha: float = 0.2, +) -> Tuple[Figure, Axes]: + """ + Visualize the effect of a continuous covariate on survival by discretizing it. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing the survival data. + covariate_col : str + Name of the covariate column to visualize. + time_col : str, default="time" + Name of the column containing event/censoring times. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + n_groups : int, default=3 + Number of groups to divide the covariate into (e.g., 3 for tertiles). + title : str, default="Effect of Covariate on Survival" + Plot title. + figsize : tuple, default=(10, 6) + Figure size (width, height) in inches. + ci_alpha : float, default=0.2 + Transparency level for confidence interval bands. + + Returns + ------- + fig : Figure + Matplotlib figure object. + ax : Axes + Matplotlib axes object. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.visualization import plot_covariate_effect + >>> + >>> # Generate data with a continuous covariate + >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> + >>> # Visualize the effect of the covariate on survival + >>> fig, ax = plot_covariate_effect(df, covariate_col="covariate", n_groups=3) + >>> plt.show() + """ + # Add a categorical version of the covariate + group_labels = [f"Q{i + 1}" for i in range(n_groups)] + data = data.copy() + data["_group"] = pd.qcut(data[covariate_col], q=n_groups, labels=group_labels) + + # Get the median value of each group for the legend + group_medians = data.groupby("_group")[covariate_col].median() + + # Create more informative labels + label_map = { + group: f"{group} ({covariate_col}≈{median:.2f})" + for group, median in group_medians.items() + } + + data["_label"] = data["_group"].map(label_map) + + # Create the plot + fig, ax = plot_survival_curve( + data=data, + time_col=time_col, + status_col=status_col, + group_col="_label", + confidence_intervals=True, + title=title, + figsize=figsize, + ci_alpha=ci_alpha, + ) + + return fig, ax + + +def describe_survival( + data: pd.DataFrame, time_col: str = "time", status_col: str = "status" +) -> pd.DataFrame: + """ + Generate a summary of survival data including median survival time, + event counts, and other descriptive statistics. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing the survival data. + time_col : str, default="time" + Name of the column containing event/censoring times. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + + Returns + ------- + pd.DataFrame + Summary statistics dataframe. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.visualization import describe_survival + >>> + >>> # Generate data + >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> + >>> # Get survival summary + >>> summary = describe_survival(df) + >>> print(summary) + """ + # Import lifelines here to avoid making it a hard dependency + try: + from lifelines import KaplanMeierFitter + except ImportError as exc: + raise ImportError( + "This function requires the lifelines package. " + "Install it with: pip install lifelines" + ) from exc + + n_total = len(data) + n_events = data[status_col].sum() + n_censored = n_total - n_events + event_rate = n_events / n_total + + # Calculate median and other percentiles + kmf = KaplanMeierFitter() + kmf.fit(data[time_col], data[status_col]) + median = kmf.median_survival_time_ + + # Time ranges + time_min = data[time_col].min() + time_max = data[time_col].max() + time_mean = data[time_col].mean() + + # Create summary DataFrame + summary = pd.DataFrame( + { + "Metric": [ + "Total Observations", + "Number of Events", + "Number Censored", + "Event Rate", + "Median Survival Time", + "Min Time", + "Max Time", + "Mean Time", + ], + "Value": [ + n_total, + n_events, + n_censored, + f"{event_rate:.2%}", + f"{median:.4f}", + f"{time_min:.4f}", + f"{time_max:.4f}", + f"{time_mean:.4f}", + ], + } + ) + + return summary diff --git a/pyproject.toml b/pyproject.toml index 5048a06..cb198cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,26 +9,45 @@ packages = [{ include = "gen_surv" }] homepage = "https://github.com/DiogoRibeiro7/genSurvPy" repository = "https://github.com/DiogoRibeiro7/genSurvPy" documentation = "https://gensurvpy.readthedocs.io/en/stable/" +keywords = ["survival-analysis", "simulation", "cox-model", "markov-model", "time-dependent", "statistics"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Science/Research", + "Intended Audience :: Healthcare Industry", + "Topic :: Scientific/Engineering :: Medical Science Apps.", + "Topic :: Scientific/Engineering :: Mathematics", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: MIT License", +] [tool.poetry.dependencies] python = "^3.9" numpy = "^1.26" pandas = "^2.2.3" -pytest-cov = "^6.1.1" -invoke = "^2.2.0" typer = "^0.12.3" -tomli = "^2.2.1" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" +pytest-cov = "^6.1.1" python-semantic-release = "^9.21.0" mypy = "^1.15.0" invoke = "^2.2.0" hypothesis = "^6.98" tomli = "^2.2.1" +black = "^24.1.0" +isort = "^5.13.2" +flake8 = "^6.1.0" [tool.poetry.group.docs.dependencies] +sphinx = "^7.2.6" myst-parser = "<4.0.0" +sphinx-rtd-theme = "^1.3.0" + +[tool.poetry.scripts] +gen_surv = "gen_surv.cli:app" [tool.semantic_release] version_source = "tag" @@ -39,6 +58,22 @@ upload_to_repository = false branch = "main" build_command = "" +[tool.black] +line-length = 88 +target-version = ['py39'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +line_length = 88 + +[tool.mypy] +python_version = "3.9" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_aft.py b/tests/test_aft.py index 2688144..3d673bd 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -1,7 +1,126 @@ +""" +Tests for Accelerated Failure Time (AFT) models. +""" + import pandas as pd -from gen_surv.aft import gen_aft_log_normal +import pytest +import numpy as np +from hypothesis import given, strategies as st + +from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic + + +def test_gen_aft_log_logistic_runs(): + """Test that the Log-Logistic AFT generator runs without errors.""" + df = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1}) + + +def test_gen_aft_log_logistic_invalid_shape(): + """Test that the Log-Logistic AFT generator raises error for invalid shape.""" + with pytest.raises(ValueError, match="shape parameter must be positive"): + gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=-1.0, # Invalid negative shape + scale=2.0, + model_cens="uniform", + cens_par=5.0 + ) + + +def test_gen_aft_log_logistic_invalid_scale(): + """Test that the Log-Logistic AFT generator raises error for invalid scale.""" + with pytest.raises(ValueError, match="scale parameter must be positive"): + gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=0.0, # Invalid zero scale + model_cens="uniform", + cens_par=5.0 + ) + + + +@given( + n=st.integers(min_value=1, max_value=20), + shape=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), + scale=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), + cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), + seed=st.integers(min_value=0, max_value=1000) +) +def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed): + """Property-based test for the Log-Logistic AFT generator.""" + df = gen_aft_log_logistic( + n=n, + beta=[0.5, -0.2], + shape=shape, + scale=scale, + model_cens="uniform", + cens_par=cens_par, + seed=seed + ) + assert df.shape[0] == n + assert set(df["status"].unique()).issubset({0, 1}) + assert (df["time"] >= 0).all() + assert df.filter(regex="^X[0-9]+$").shape[1] == 2 + + +def test_gen_aft_log_logistic_reproducibility(): + """Test that the Log-Logistic AFT generator is reproducible with the same seed.""" + df1 = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + + df2 = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + + pd.testing.assert_frame_equal(df1, df2) + + df3 = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=43 # Different seed + ) + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df3) + def test_gen_aft_log_normal_runs(): + """Test that the log-normal AFT generator runs without errors.""" df = gen_aft_log_normal( n=10, beta=[0.5, -0.2], @@ -16,4 +135,125 @@ def test_gen_aft_log_normal_runs(): assert "status" in df.columns assert "X0" in df.columns assert "X1" in df.columns - assert set(df["status"].unique()).issubset({0, 1}) \ No newline at end of file + assert set(df["status"].unique()).issubset({0, 1}) + + +def test_gen_aft_weibull_runs(): + """Test that the Weibull AFT generator runs without errors.""" + df = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1}) + + +def test_gen_aft_weibull_invalid_shape(): + """Test that the Weibull AFT generator raises error for invalid shape.""" + with pytest.raises(ValueError, match="shape parameter must be positive"): + gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=-1.0, # Invalid negative shape + scale=2.0, + model_cens="uniform", + cens_par=5.0 + ) + + +def test_gen_aft_weibull_invalid_scale(): + """Test that the Weibull AFT generator raises error for invalid scale.""" + with pytest.raises(ValueError, match="scale parameter must be positive"): + gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=0.0, # Invalid zero scale + model_cens="uniform", + cens_par=5.0 + ) + + +def test_gen_aft_weibull_invalid_cens_model(): + """Test that the Weibull AFT generator raises error for invalid censoring model.""" + with pytest.raises(ValueError, match="model_cens must be 'uniform' or 'exponential'"): + gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="invalid", # Invalid censoring model + cens_par=5.0 + ) + + +@given( + n=st.integers(min_value=1, max_value=20), + shape=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), + scale=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), + cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), + seed=st.integers(min_value=0, max_value=1000) +) +def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed): + """Property-based test for the Weibull AFT generator.""" + df = gen_aft_weibull( + n=n, + beta=[0.5, -0.2], + shape=shape, + scale=scale, + model_cens="uniform", + cens_par=cens_par, + seed=seed + ) + assert df.shape[0] == n + assert set(df["status"].unique()).issubset({0, 1}) + assert (df["time"] >= 0).all() + assert df.filter(regex="^X[0-9]+$").shape[1] == 2 + + +def test_gen_aft_weibull_reproducibility(): + """Test that the Weibull AFT generator is reproducible with the same seed.""" + df1 = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + + df2 = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + + pd.testing.assert_frame_equal(df1, df2) + + df3 = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=43 # Different seed + ) + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df3) diff --git a/tests/test_competing_risks.py b/tests/test_competing_risks.py new file mode 100644 index 0000000..0431f47 --- /dev/null +++ b/tests/test_competing_risks.py @@ -0,0 +1,212 @@ +""" +Tests for Competing Risks models. +""" + +import pytest +import numpy as np +import pandas as pd +from hypothesis import given, strategies as st + +from gen_surv.competing_risks import ( + gen_competing_risks, + gen_competing_risks_weibull, + cause_specific_cumulative_incidence +) + + +def test_gen_competing_risks_basic(): + """Test that the competing risks generator runs without errors.""" + df = gen_competing_risks( + n=10, + n_risks=2, + baseline_hazards=[0.5, 0.3], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1, 2}) + + +def test_gen_competing_risks_weibull_basic(): + """Test that the Weibull competing risks generator runs without errors.""" + df = gen_competing_risks_weibull( + n=10, + n_risks=2, + shape_params=[0.8, 1.5], + scale_params=[2.0, 3.0], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1, 2}) + + +def test_competing_risks_parameters(): + """Test parameter validation in competing risks model.""" + # Test with invalid number of baseline hazards + with pytest.raises(ValueError, match="Expected 3 baseline hazards"): + gen_competing_risks( + n=10, + n_risks=3, + baseline_hazards=[0.5, 0.3], # Only 2 provided, but 3 risks + seed=42 + ) + + # Test with invalid number of beta coefficient sets + with pytest.raises(ValueError, match="Expected 2 sets of coefficients"): + gen_competing_risks( + n=10, + n_risks=2, + betas=[[0.8, -0.5]], # Only 1 set provided, but 2 risks + seed=42 + ) + + # Test with invalid censoring model + with pytest.raises(ValueError, match="model_cens must be 'uniform' or 'exponential'"): + gen_competing_risks( + n=10, + n_risks=2, + model_cens="invalid", + seed=42 + ) + + +def test_competing_risks_weibull_parameters(): + """Test parameter validation in Weibull competing risks model.""" + # Test with invalid number of shape parameters + with pytest.raises(ValueError, match="Expected 3 shape parameters"): + gen_competing_risks_weibull( + n=10, + n_risks=3, + shape_params=[0.8, 1.5], # Only 2 provided, but 3 risks + seed=42 + ) + + # Test with invalid number of scale parameters + with pytest.raises(ValueError, match="Expected 3 scale parameters"): + gen_competing_risks_weibull( + n=10, + n_risks=3, + scale_params=[2.0, 3.0], # Only 2 provided, but 3 risks + seed=42 + ) + + +def test_cause_specific_cumulative_incidence(): + """Test the cause-specific cumulative incidence function.""" + # Generate some data + df = gen_competing_risks( + n=50, + n_risks=2, + baseline_hazards=[0.5, 0.3], + seed=42 + ) + + # Calculate CIF for cause 1 + time_points = np.linspace(0, 5, 10) + cif = cause_specific_cumulative_incidence(df, time_points, cause=1) + + assert isinstance(cif, pd.DataFrame) + assert len(cif) == len(time_points) + assert "time" in cif.columns + assert "incidence" in cif.columns + assert (cif["incidence"] >= 0).all() + assert (cif["incidence"] <= 1).all() + assert cif["incidence"].is_monotonic_increasing + + # Test with invalid cause + with pytest.raises(ValueError, match="Cause 3 not found in the data"): + cause_specific_cumulative_incidence(df, time_points, cause=3) + + +@given( + n=st.integers(min_value=5, max_value=50), + n_risks=st.integers(min_value=2, max_value=4), + seed=st.integers(min_value=0, max_value=1000) +) +def test_competing_risks_properties(n, n_risks, seed): + """Property-based tests for the competing risks model.""" + df = gen_competing_risks( + n=n, + n_risks=n_risks, + seed=seed + ) + + # Check basic properties + assert df.shape[0] == n + assert all(col in df.columns for col in ["id", "time", "status"]) + assert (df["time"] >= 0).all() + assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks + + # Count of each status + status_counts = df["status"].value_counts() + # There should be at least one of each status (including censoring) + # This might occasionally fail due to randomness, so we'll just check that + # we have at least 2 different status values + assert len(status_counts) >= 2 + + +@given( + n=st.integers(min_value=5, max_value=50), + n_risks=st.integers(min_value=2, max_value=4), + seed=st.integers(min_value=0, max_value=1000) +) +def test_competing_risks_weibull_properties(n, n_risks, seed): + """Property-based tests for the Weibull competing risks model.""" + df = gen_competing_risks_weibull( + n=n, + n_risks=n_risks, + seed=seed + ) + + # Check basic properties + assert df.shape[0] == n + assert all(col in df.columns for col in ["id", "time", "status"]) + assert (df["time"] >= 0).all() + assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks + + # Count of each status + status_counts = df["status"].value_counts() + # There should be at least 2 different status values + assert len(status_counts) >= 2 + + +def test_reproducibility(): + """Test that results are reproducible with the same seed.""" + df1 = gen_competing_risks( + n=20, + n_risks=2, + seed=42 + ) + + df2 = gen_competing_risks( + n=20, + n_risks=2, + seed=42 + ) + + pd.testing.assert_frame_equal(df1, df2) + + # Different seeds should produce different results + df3 = gen_competing_risks( + n=20, + n_risks=2, + seed=43 + ) + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df3) diff --git a/tests/test_cphm.py b/tests/test_cphm.py index 05cc652..c699fab 100644 --- a/tests/test_cphm.py +++ b/tests/test_cphm.py @@ -1,13 +1,49 @@ -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +""" +Tests for the Cox Proportional Hazards Model (CPHM) generator. +""" + +import pytest +import pandas as pd from gen_surv.cphm import gen_cphm + def test_gen_cphm_output_shape(): + """Test that the output DataFrame has the expected shape and columns.""" df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) assert df.shape == (50, 3) assert list(df.columns) == ["time", "status", "covariate"] + def test_gen_cphm_status_range(): + """Test that status values are binary (0 or 1).""" df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covar=1.5) assert df["status"].isin([0, 1]).all() + + +def test_gen_cphm_time_positive(): + """Test that all time values are positive.""" + df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + assert (df["time"] > 0).all() + + +def test_gen_cphm_covariate_range(): + """Test that covariate values are within the specified range.""" + covar_max = 2.5 + df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=covar_max) + assert (df["covariate"] >= 0).all() + assert (df["covariate"] <= covar_max).all() + + +def test_gen_cphm_seed_reproducibility(): + """Test that setting the same seed produces identical results.""" + df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=42) + df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=42) + pd.testing.assert_frame_equal(df1, df2) + + +def test_gen_cphm_different_seeds(): + """Test that different seeds produce different results.""" + df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=42) + df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=43) + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df2) From 7f7c28ac93fc42b1a939dafb7eb59cd9b0ac3cb2 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Thu, 10 Jul 2025 23:27:41 +0100 Subject: [PATCH 02/19] feat: Add competing risks models and enhance data visualization --- gen_surv/competing_risks.py | 2 +- gen_surv/interface.py | 2 +- gen_surv/mixture.py | 281 +++++++++++++++++++++++++++++++ gen_surv/piecewise.py | 321 ++++++++++++++++++++++++++++++++++++ gen_surv/visualization.py | 3 +- tests/test_aft.py | 1 - 6 files changed, 606 insertions(+), 4 deletions(-) create mode 100644 gen_surv/mixture.py create mode 100644 gen_surv/piecewise.py diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py index 239d327..2dfd073 100644 --- a/gen_surv/competing_risks.py +++ b/gen_surv/competing_risks.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from typing import Dict, List, Optional, Tuple, Union, Literal +from typing import Dict, List, Optional, Tuple, Union, Literal, Any def gen_competing_risks( diff --git a/gen_surv/interface.py b/gen_surv/interface.py index f3935d0..4298203 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -6,7 +6,7 @@ >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) """ -from typing import Any, Dict, Literal, Optional, Union, List, Tuple, cast +from typing import Any, Literal import pandas as pd from gen_surv.cphm import gen_cphm diff --git a/gen_surv/mixture.py b/gen_surv/mixture.py new file mode 100644 index 0000000..305606f --- /dev/null +++ b/gen_surv/mixture.py @@ -0,0 +1,281 @@ +""" +Mixture Cure Models for survival data simulation. + +This module provides functions to generate survival data with a cure fraction, +i.e., a proportion of subjects who are immune to the event of interest. +""" + +import numpy as np +import pandas as pd +from typing import Dict, List, Optional, Tuple, Union, Literal + + +def gen_mixture_cure( + n: int, + cure_fraction: float, + baseline_hazard: float = 0.5, + betas_survival: Optional[List[float]] = None, + betas_cure: Optional[List[float]] = None, + n_covariates: int = 2, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + max_time: Optional[float] = 10.0, + seed: Optional[int] = None +) -> pd.DataFrame: + """ + Generate survival data with a cure fraction using a mixture cure model. + + Parameters + ---------- + n : int + Number of subjects. + cure_fraction : float + Baseline probability of being cured (immune to the event). + Should be between 0 and 1. + baseline_hazard : float, default=0.5 + Baseline hazard rate for the non-cured population. + betas_survival : list of float, optional + Coefficients for covariates in the survival component. + If None, generates random coefficients. + betas_cure : list of float, optional + Coefficients for covariates in the cure component. + If None, generates random coefficients. + n_covariates : int, default=2 + Number of covariates to generate if betas is None. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + max_time : float, optional, default=10.0 + Maximum simulation time. Set to None for no limit. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (1=event, 0=censored) + - "cured": Indicator of cure status (1=cured, 0=not cured) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.mixture import gen_mixture_cure + >>> + >>> # Generate data with 30% baseline cure fraction + >>> df = gen_mixture_cure( + ... n=100, + ... cure_fraction=0.3, + ... betas_survival=[0.8, -0.5], + ... betas_cure=[-0.5, 0.8], + ... seed=42 + ... ) + >>> + >>> # Check cure proportion + >>> print(f"Cured subjects: {df['cured'].mean():.2%}") + """ + if seed is not None: + np.random.seed(seed) + + # Validate inputs + if not 0 <= cure_fraction <= 1: + raise ValueError("cure_fraction must be between 0 and 1") + + if baseline_hazard <= 0: + raise ValueError("baseline_hazard must be positive") + + # Set default covariate parameters if not provided + if covariate_params is None: + if covariate_dist == "normal": + covariate_params = {"mean": 0.0, "std": 1.0} + elif covariate_dist == "uniform": + covariate_params = {"low": 0.0, "high": 1.0} + elif covariate_dist == "binary": + covariate_params = {"p": 0.5} + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Set default betas if not provided + if betas_survival is None: + betas_survival = np.random.normal(0, 0.5, size=n_covariates) + else: + betas_survival = np.array(betas_survival) + n_covariates = len(betas_survival) + + if betas_cure is None: + betas_cure = np.random.normal(0, 0.5, size=n_covariates) + else: + betas_cure = np.array(betas_cure) + if len(betas_cure) != n_covariates: + raise ValueError( + f"betas_cure must have the same length as betas_survival, " + f"got {len(betas_cure)} vs {n_covariates}" + ) + + # Generate covariates + if covariate_dist == "normal": + X = np.random.normal( + covariate_params.get("mean", 0.0), + covariate_params.get("std", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "uniform": + X = np.random.uniform( + covariate_params.get("low", 0.0), + covariate_params.get("high", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "binary": + X = np.random.binomial( + 1, + covariate_params.get("p", 0.5), + size=(n, n_covariates) + ) + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Calculate linear predictors + lp_survival = X @ betas_survival + lp_cure = X @ betas_cure + + # Determine cure status (logistic model) + cure_probs = 1 / (1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure))) + cured = np.random.binomial(1, cure_probs) + + # Generate survival times + survival_times = np.zeros(n) + + # For non-cured subjects, generate event times + non_cured_indices = np.where(cured == 0)[0] + + for i in non_cured_indices: + # Adjust hazard rate by covariate effect + adjusted_hazard = baseline_hazard * np.exp(lp_survival[i]) + + # Generate exponential survival time + survival_times[i] = np.random.exponential(scale=1/adjusted_hazard) + + # For cured subjects, set "infinite" survival time + cured_indices = np.where(cured == 1)[0] + if max_time is not None: + survival_times[cured_indices] = max_time * 100 # Effectively infinite + else: + survival_times[cured_indices] = np.inf # Actually infinite + + # Generate censoring times + if model_cens == "uniform": + cens_times = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + cens_times = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + # Determine observed time and status + observed_times = np.minimum(survival_times, cens_times) + status = (survival_times <= cens_times).astype(int) + + # Cap times at max_time if specified + if max_time is not None: + over_max = observed_times > max_time + observed_times[over_max] = max_time + status[over_max] = 0 # Censored if beyond max_time + + # Create DataFrame + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_times, + "status": status, + "cured": cured + }) + + # Add covariates + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def cure_fraction_estimate( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + bandwidth: float = 0.1 +) -> float: + """ + Estimate the cure fraction from observed data using non-parametric methods. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with survival data. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (1=event, 0=censored). + bandwidth : float, default=0.1 + Bandwidth parameter for smoothing the tail of the survival curve. + + Returns + ------- + float + Estimated cure fraction. + + Notes + ----- + This function uses a non-parametric approach to estimate the cure fraction + based on the plateau of the survival curve. It may not be accurate for + small sample sizes or heavy censoring. + """ + # Sort data by time + sorted_data = data.sort_values(by=time_col).copy() + + # Calculate Kaplan-Meier estimate + times = sorted_data[time_col].values + status = sorted_data[status_col].values + n = len(times) + + if n == 0: + return 0.0 + + # Calculate survival function + survival = np.ones(n) + + for i in range(n): + if i > 0: + survival[i] = survival[i-1] + + # Count subjects at risk at this time + at_risk = n - i + + if status[i] == 1: # Event + survival[i] *= (1 - 1/at_risk) + + # Estimate cure fraction as the plateau of the survival curve + # Use the last 10% of the survival curve if enough data points + tail_size = max(int(n * 0.1), 1) + tail_survival = survival[-tail_size:] + + # Apply smoothing if there are enough data points + if tail_size > 3: + # Use kernel smoothing + weights = np.exp(-(np.arange(tail_size) - tail_size + 1)**2 / (2 * bandwidth * tail_size)**2) + weights = weights / weights.sum() + cure_fraction = np.sum(tail_survival * weights) + else: + # Just use the last survival probability + cure_fraction = survival[-1] + + return cure_fraction diff --git a/gen_surv/piecewise.py b/gen_surv/piecewise.py new file mode 100644 index 0000000..fb2abc1 --- /dev/null +++ b/gen_surv/piecewise.py @@ -0,0 +1,321 @@ +""" +Piecewise Exponential survival models. + +This module provides functions for generating survival data from piecewise +exponential distributions with time-dependent hazards. +""" + +import numpy as np +import pandas as pd +from typing import Dict, List, Optional, Tuple, Union, Literal + + +def gen_piecewise_exponential( + n: int, + breakpoints: List[float], + hazard_rates: List[float], + betas: Optional[Union[List[float], np.ndarray]] = None, + n_covariates: int = 2, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + seed: Optional[int] = None +) -> pd.DataFrame: + """ + Generate survival data using a piecewise exponential distribution. + + Parameters + ---------- + n : int + Number of subjects. + breakpoints : list of float + Time points where hazard rates change. Must be in ascending order. + The first interval is [0, breakpoints[0]), the second is [breakpoints[0], breakpoints[1]), etc. + hazard_rates : list of float + Hazard rates for each interval. Length should be len(breakpoints) + 1. + betas : list or array, optional + Coefficients for covariates. If None, generates random coefficients. + n_covariates : int, default=2 + Number of covariates to generate if betas is None. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (1=event, 0=censored) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.piecewise import gen_piecewise_exponential + >>> + >>> # Generate data with 3 intervals (increasing hazard) + >>> df = gen_piecewise_exponential( + ... n=100, + ... breakpoints=[1.0, 3.0], + ... hazard_rates=[0.2, 0.5, 1.0], + ... betas=[0.8, -0.5], + ... seed=42 + ... ) + """ + if seed is not None: + np.random.seed(seed) + + # Validate inputs + if len(hazard_rates) != len(breakpoints) + 1: + raise ValueError(f"Expected {len(breakpoints) + 1} hazard rates, got {len(hazard_rates)}") + + if not all(b > 0 for b in breakpoints): + raise ValueError("All breakpoints must be positive") + + if not all(h > 0 for h in hazard_rates): + raise ValueError("All hazard rates must be positive") + + if not all(breakpoints[i] < breakpoints[i+1] for i in range(len(breakpoints)-1)): + raise ValueError("Breakpoints must be in ascending order") + + # Set default covariate parameters if not provided + if covariate_params is None: + if covariate_dist == "normal": + covariate_params = {"mean": 0.0, "std": 1.0} + elif covariate_dist == "uniform": + covariate_params = {"low": 0.0, "high": 1.0} + elif covariate_dist == "binary": + covariate_params = {"p": 0.5} + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Set default betas if not provided + if betas is None: + betas = np.random.normal(0, 0.5, size=n_covariates) + else: + betas = np.array(betas) + n_covariates = len(betas) + + # Generate covariates + if covariate_dist == "normal": + X = np.random.normal( + covariate_params.get("mean", 0.0), + covariate_params.get("std", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "uniform": + X = np.random.uniform( + covariate_params.get("low", 0.0), + covariate_params.get("high", 1.0), + size=(n, n_covariates) + ) + elif covariate_dist == "binary": + X = np.random.binomial( + 1, + covariate_params.get("p", 0.5), + size=(n, n_covariates) + ) + else: + raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + + # Calculate linear predictor + linear_predictor = X @ betas + + # Generate survival times using piecewise exponential distribution + survival_times = np.zeros(n) + + for i in range(n): + # Adjust hazard rates by the covariate effect + adjusted_hazard_rates = [h * np.exp(linear_predictor[i]) for h in hazard_rates] + + # Generate random uniform between 0 and 1 + u = np.random.uniform(0, 1) + + # Calculate survival time using inverse CDF method for piecewise exponential + remaining_time = -np.log(u) # Initial time remaining (for standard exponential) + total_time = 0.0 + + # Start with the first interval [0, breakpoints[0]) + interval_width = breakpoints[0] + hazard = adjusted_hazard_rates[0] + time_to_consume = remaining_time / hazard + + if time_to_consume < interval_width: + # Event occurs in first interval + survival_times[i] = time_to_consume + continue + + # Event occurs after first interval + total_time += interval_width + remaining_time -= hazard * interval_width + + # Go through middle intervals [breakpoints[j-1], breakpoints[j]) + for j in range(1, len(breakpoints)): + interval_width = breakpoints[j] - breakpoints[j-1] + hazard = adjusted_hazard_rates[j] + time_to_consume = remaining_time / hazard + + if time_to_consume < interval_width: + # Event occurs in this interval + survival_times[i] = total_time + time_to_consume + break + + # Event occurs after this interval + total_time += interval_width + remaining_time -= hazard * interval_width + + # If we've gone through all intervals and still no event, + # use the last hazard rate for the remainder + if remaining_time > 0: + hazard = adjusted_hazard_rates[-1] + survival_times[i] = total_time + remaining_time / hazard + + # Generate censoring times + if model_cens == "uniform": + cens_times = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + cens_times = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + # Determine observed time and status + observed_times = np.minimum(survival_times, cens_times) + status = (survival_times <= cens_times).astype(int) + + # Create DataFrame + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_times, + "status": status + }) + + # Add covariates + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def piecewise_hazard_function( + t: Union[float, np.ndarray], + breakpoints: List[float], + hazard_rates: List[float] +) -> Union[float, np.ndarray]: + """ + Calculate the hazard function value at time t for a piecewise exponential distribution. + + Parameters + ---------- + t : float or array + Time point(s) at which to evaluate the hazard function. + breakpoints : list of float + Time points where hazard rates change. + hazard_rates : list of float + Hazard rates for each interval. + + Returns + ------- + float or array + Hazard function value(s) at time t. + """ + # Convert scalar input to array for consistent processing + scalar_input = np.isscalar(t) + t_array = np.atleast_1d(t) + result = np.zeros_like(t_array) + + # Assign hazard rates based on time intervals + result[t_array < 0] = 0 # Hazard is 0 for negative times + + # First interval: [0, breakpoints[0]) + mask = (t_array >= 0) & (t_array < breakpoints[0]) + result[mask] = hazard_rates[0] + + # Middle intervals: [breakpoints[j-1], breakpoints[j]) + for j in range(1, len(breakpoints)): + mask = (t_array >= breakpoints[j-1]) & (t_array < breakpoints[j]) + result[mask] = hazard_rates[j] + + # Last interval: [breakpoints[-1], infinity) + mask = t_array >= breakpoints[-1] + result[mask] = hazard_rates[-1] + + return result[0] if scalar_input else result + + +def piecewise_survival_function( + t: Union[float, np.ndarray], + breakpoints: List[float], + hazard_rates: List[float] +) -> Union[float, np.ndarray]: + """ + Calculate the survival function at time t for a piecewise exponential distribution. + + Parameters + ---------- + t : float or array + Time point(s) at which to evaluate the survival function. + breakpoints : list of float + Time points where hazard rates change. + hazard_rates : list of float + Hazard rates for each interval. + + Returns + ------- + float or array + Survival function value(s) at time t. + """ + # Convert scalar input to array for consistent processing + scalar_input = np.isscalar(t) + t_array = np.atleast_1d(t) + result = np.ones_like(t_array) + + # For each time point, calculate the survival function + for i, ti in enumerate(t_array): + if ti <= 0: + continue # Survival probability is 1 at time 0 or earlier + + cumulative_hazard = 0.0 + + # First interval: [0, min(ti, breakpoints[0])) + first_interval_end = min(ti, breakpoints[0]) if breakpoints else ti + cumulative_hazard += hazard_rates[0] * first_interval_end + + if ti <= breakpoints[0]: + result[i] = np.exp(-cumulative_hazard) + continue + + # Middle intervals: [breakpoints[j-1], min(ti, breakpoints[j])) + for j in range(1, len(breakpoints)): + if ti <= breakpoints[j-1]: + break + + interval_start = breakpoints[j-1] + interval_end = min(ti, breakpoints[j]) + interval_width = interval_end - interval_start + + cumulative_hazard += hazard_rates[j] * interval_width + + if ti <= breakpoints[j]: + break + + # Last interval: [breakpoints[-1], ti) + if ti > breakpoints[-1]: + last_interval_width = ti - breakpoints[-1] + cumulative_hazard += hazard_rates[-1] * last_interval_width + + result[i] = np.exp(-cumulative_hazard) + + return result[0] if scalar_input else result diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py index 8700b30..07ff706 100644 --- a/gen_surv/visualization.py +++ b/gen_surv/visualization.py @@ -82,7 +82,8 @@ def plot_survival_curve( # Create separate KM curves for each group (if specified) if group_col is not None: groups = data[group_col].unique() - colors = plt.cm.tab10.colors[: len(groups)] + cmap = plt.get_cmap("tab10") + colors = [cmap(i) for i in range(len(groups))] for i, group in enumerate(groups): mask = data[group_col] == group diff --git a/tests/test_aft.py b/tests/test_aft.py index 3d673bd..a7ba60d 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -4,7 +4,6 @@ import pandas as pd import pytest -import numpy as np from hypothesis import given, strategies as st from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic From db4973820b6afd1b7bb341baf749df9fda956abc Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Fri, 11 Jul 2025 22:42:26 +0100 Subject: [PATCH 03/19] chore: fixes --- gen_surv/competing_risks.py | 1 + tasks.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py index 2dfd073..4c11f44 100644 --- a/gen_surv/competing_risks.py +++ b/gen_surv/competing_risks.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd +import matplotlib.pyplot as plt from typing import Dict, List, Optional, Tuple, Union, Literal, Any diff --git a/tasks.py b/tasks.py index d37bc3a..4b16197 100644 --- a/tasks.py +++ b/tasks.py @@ -4,8 +4,6 @@ import shlex - - @task def test(c: Context) -> None: """Run the test suite with coverage reporting. From 79f6a8b48d3ad34a5e239ced6b37a2f0a3dfede2 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Fri, 11 Jul 2025 22:47:45 +0100 Subject: [PATCH 04/19] chore: fixes unit test --- gen_surv/visualization.py | 11 +++++++---- tests/test_aft.py | 4 +++- tests/test_bivariate.py | 4 ++++ tests/test_cli.py | 13 +++++++++++-- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py index 07ff706..08f544a 100644 --- a/gen_surv/visualization.py +++ b/gen_surv/visualization.py @@ -1,11 +1,13 @@ """ Visualization utilities for survival data. -This module provides functions to visualize survival data generated by gen_surv, -including Kaplan-Meier survival curves and other commonly used plots in survival analysis. +This module provides functions to visualize survival data generated by +gen_surv, +including Kaplan-Meier survival curves and other commonly used plots in +survival analysis. """ -from typing import Dict, List, Optional, Tuple, Union, Any +from typing import Dict, Optional, Tuple, Union, Any import numpy as np import pandas as pd import matplotlib.pyplot as plt @@ -217,7 +219,8 @@ def plot_covariate_effect( ci_alpha: float = 0.2, ) -> Tuple[Figure, Axes]: """ - Visualize the effect of a continuous covariate on survival by discretizing it. + Visualize the effect of a continuous covariate on survival by discretizing + it. Parameters ---------- diff --git a/tests/test_aft.py b/tests/test_aft.py index a7ba60d..04330e9 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -1,11 +1,13 @@ """ Tests for Accelerated Failure Time (AFT) models. """ - +import os +import sys import pandas as pd import pytest from hypothesis import given, strategies as st +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic diff --git a/tests/test_bivariate.py b/tests/test_bivariate.py index b011130..d1ae7b6 100644 --- a/tests/test_bivariate.py +++ b/tests/test_bivariate.py @@ -1,4 +1,8 @@ +import os +import sys import numpy as np + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.bivariate import sample_bivariate_distribution import pytest diff --git a/tests/test_cli.py b/tests/test_cli.py index d5fd8d7..cce5815 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,19 @@ +import sys +import os +import runpy + import pandas as pd + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.cli import dataset -import runpy def test_cli_dataset_stdout(monkeypatch, capsys): - """Dataset command prints CSV to stdout when no output file is given.""" + """ + Test that the 'dataset' CLI command prints the generated CSV data to stdout when no output file is specified. + This test patches the 'generate' function to return a simple DataFrame, invokes the CLI command directly, + and asserts that the expected CSV header appears in the captured standard output. + """ def fake_generate(model: str, n: int): return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) From 93b7f1ef548c37c41c9a617c7d5574efecc73c94 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Fri, 11 Jul 2025 22:58:08 +0100 Subject: [PATCH 05/19] chore: solve issues with formatting --- tests/test_aft.py | 96 ++++++++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/tests/test_aft.py b/tests/test_aft.py index 04330e9..e3e9a6e 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -1,6 +1,7 @@ """ Tests for Accelerated Failure Time (AFT) models. """ + import os import sys import pandas as pd @@ -20,7 +21,7 @@ def test_gen_aft_log_logistic_runs(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -32,7 +33,8 @@ def test_gen_aft_log_logistic_runs(): def test_gen_aft_log_logistic_invalid_shape(): - """Test that the Log-Logistic AFT generator raises error for invalid shape.""" + """Test that the Log-Logistic AFT generator raises error + for invalid shape.""" with pytest.raises(ValueError, match="shape parameter must be positive"): gen_aft_log_logistic( n=10, @@ -40,12 +42,13 @@ def test_gen_aft_log_logistic_invalid_shape(): shape=-1.0, # Invalid negative shape scale=2.0, model_cens="uniform", - cens_par=5.0 + cens_par=5.0, ) def test_gen_aft_log_logistic_invalid_scale(): - """Test that the Log-Logistic AFT generator raises error for invalid scale.""" + """Test that the Log-Logistic AFT generator raises error + for invalid scale.""" with pytest.raises(ValueError, match="scale parameter must be positive"): gen_aft_log_logistic( n=10, @@ -53,17 +56,22 @@ def test_gen_aft_log_logistic_invalid_scale(): shape=1.5, scale=0.0, # Invalid zero scale model_cens="uniform", - cens_par=5.0 + cens_par=5.0, ) - @given( n=st.integers(min_value=1, max_value=20), - shape=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), - scale=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), - cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), - seed=st.integers(min_value=0, max_value=1000) + shape=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + scale=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + cens_par=st.floats( + min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False + ), + seed=st.integers(min_value=0, max_value=1000), ) def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed): """Property-based test for the Log-Logistic AFT generator.""" @@ -74,7 +82,7 @@ def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed): scale=scale, model_cens="uniform", cens_par=cens_par, - seed=seed + seed=seed, ) assert df.shape[0] == n assert set(df["status"].unique()).issubset({0, 1}) @@ -83,7 +91,8 @@ def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed): def test_gen_aft_log_logistic_reproducibility(): - """Test that the Log-Logistic AFT generator is reproducible with the same seed.""" + """Test that the Log-Logistic AFT generator is reproducible + with the same seed.""" df1 = gen_aft_log_logistic( n=10, beta=[0.5, -0.2], @@ -91,9 +100,9 @@ def test_gen_aft_log_logistic_reproducibility(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) - + df2 = gen_aft_log_logistic( n=10, beta=[0.5, -0.2], @@ -101,11 +110,11 @@ def test_gen_aft_log_logistic_reproducibility(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) - + pd.testing.assert_frame_equal(df1, df2) - + df3 = gen_aft_log_logistic( n=10, beta=[0.5, -0.2], @@ -113,9 +122,9 @@ def test_gen_aft_log_logistic_reproducibility(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=43 # Different seed + seed=43, # Different seed ) - + with pytest.raises(AssertionError): pd.testing.assert_frame_equal(df1, df3) @@ -123,12 +132,7 @@ def test_gen_aft_log_logistic_reproducibility(): def test_gen_aft_log_normal_runs(): """Test that the log-normal AFT generator runs without errors.""" df = gen_aft_log_normal( - n=10, - beta=[0.5, -0.2], - sigma=1.0, - model_cens="uniform", - cens_par=5.0, - seed=42 + n=10, beta=[0.5, -0.2], sigma=1.0, model_cens="uniform", cens_par=5.0, seed=42 ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -148,7 +152,7 @@ def test_gen_aft_weibull_runs(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -168,7 +172,7 @@ def test_gen_aft_weibull_invalid_shape(): shape=-1.0, # Invalid negative shape scale=2.0, model_cens="uniform", - cens_par=5.0 + cens_par=5.0, ) @@ -181,29 +185,37 @@ def test_gen_aft_weibull_invalid_scale(): shape=1.5, scale=0.0, # Invalid zero scale model_cens="uniform", - cens_par=5.0 + cens_par=5.0, ) def test_gen_aft_weibull_invalid_cens_model(): """Test that the Weibull AFT generator raises error for invalid censoring model.""" - with pytest.raises(ValueError, match="model_cens must be 'uniform' or 'exponential'"): + with pytest.raises( + ValueError, match="model_cens must be 'uniform' or 'exponential'" + ): gen_aft_weibull( n=10, beta=[0.5, -0.2], shape=1.5, scale=2.0, model_cens="invalid", # Invalid censoring model - cens_par=5.0 + cens_par=5.0, ) @given( n=st.integers(min_value=1, max_value=20), - shape=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), - scale=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False), - cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), - seed=st.integers(min_value=0, max_value=1000) + shape=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + scale=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + cens_par=st.floats( + min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False + ), + seed=st.integers(min_value=0, max_value=1000), ) def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed): """Property-based test for the Weibull AFT generator.""" @@ -214,7 +226,7 @@ def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed): scale=scale, model_cens="uniform", cens_par=cens_par, - seed=seed + seed=seed, ) assert df.shape[0] == n assert set(df["status"].unique()).issubset({0, 1}) @@ -231,9 +243,9 @@ def test_gen_aft_weibull_reproducibility(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) - + df2 = gen_aft_weibull( n=10, beta=[0.5, -0.2], @@ -241,11 +253,11 @@ def test_gen_aft_weibull_reproducibility(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) - + pd.testing.assert_frame_equal(df1, df2) - + df3 = gen_aft_weibull( n=10, beta=[0.5, -0.2], @@ -253,8 +265,8 @@ def test_gen_aft_weibull_reproducibility(): scale=2.0, model_cens="uniform", cens_par=5.0, - seed=43 # Different seed + seed=43, # Different seed ) - + with pytest.raises(AssertionError): pd.testing.assert_frame_equal(df1, df3) From 56e9e99098c63d7282d98c183bb80f12dbad56ad Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Thu, 24 Jul 2025 19:41:55 +0100 Subject: [PATCH 06/19] Add files via upload --- fix_recommendations.md | 106 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 fix_recommendations.md diff --git a/fix_recommendations.md b/fix_recommendations.md new file mode 100644 index 0000000..deb04eb --- /dev/null +++ b/fix_recommendations.md @@ -0,0 +1,106 @@ +# Fixing gen_surv Repository Issues + +## Priority 1: Critical Fixes + +### 1. Fix `__init__.py` Import Issues + +Update `gen_surv/__init__.py` to include missing imports: + +```python +# Add these imports +from .aft import gen_aft_log_logistic +from .competing_risks import gen_competing_risks, gen_competing_risks_weibull +from .mixture import gen_mixture_cure, cure_fraction_estimate +from .piecewise import gen_piecewise_exponential + +# Update __all__ to include: +__all__ = [ + # ... existing exports ... + "gen_aft_log_logistic", + "gen_competing_risks", + "gen_competing_risks_weibull", + "gen_mixture_cure", + "cure_fraction_estimate", + "gen_piecewise_exponential", +] +``` + +### 2. Add Missing Validators + +Create validation functions in `validate.py`: + +```python +def validate_gen_aft_weibull_inputs(n, beta, shape, scale, model_cens, cens_par): + # Add validation logic + +def validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens_par): + # Add validation logic + +def validate_competing_risks_inputs(n, n_risks, baseline_hazards, betas, model_cens, cens_par): + # Add validation logic +``` + +### 3. Update CLI Integration + +Modify `cli.py` to handle all available models: + +```python +# Add support for competing_risks, mixture, and piecewise models +# Update parameter handling for each model type +``` + +## Priority 2: Version Consistency + +### Update Version Numbers + +1. **CITATION.cff**: Change version from "1.0.3" to "1.0.8" +2. **docs/source/conf.py**: Change release from '1.0.3' to '1.0.8' + +## Priority 3: Testing and Documentation + +### Add Missing Tests + +Create test files: +- `tests/test_censoring.py` +- `tests/test_mixture.py` +- `tests/test_piecewise.py` +- `tests/test_summary.py` +- `tests/test_visualization.py` + +### Update Documentation + +1. Add competing risks, mixture, and piecewise models to `theory.md` +2. Update examples in documentation to include new models + +## Priority 4: Code Quality Improvements + +### Standardize Parameter Naming + +- Consistently use `covariate_cols` instead of mixing `covar` and `covariate_cols` +- Standardize return column names across all models + +### Add Type Hints + +Complete type hints for all public functions, especially in: +- `mixture.py` +- `piecewise.py` +- `summary.py` + +## Verification Steps + +After implementing fixes: + +1. **Test imports**: `python -c "from gen_surv import gen_aft_log_logistic, gen_competing_risks"` +2. **Test CLI**: `python -m gen_surv dataset competing_risks --n 10` +3. **Run full test suite**: `poetry run pytest` +4. **Check version consistency**: `python scripts/check_version_match.py` +5. **Build docs**: `poetry run sphinx-build docs/source docs/build` + +## Impact Assessment + +These fixes will: +- ✅ Eliminate ImportError exceptions for users +- ✅ Make all models accessible via public API +- ✅ Ensure CLI works for all supported models +- ✅ Improve user experience and API consistency +- ✅ Maintain backward compatibility \ No newline at end of file From 05c693b8ec410b4a687088b5a74bef240d9e8351 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Sat, 26 Jul 2025 06:01:43 +0100 Subject: [PATCH 07/19] docs: update fix reference (#37) --- .gitignore | 1 + CITATION.cff | 2 +- README.md | 22 ++- docs/source/conf.py | 2 +- docs/source/index.md | 2 +- docs/source/theory.md | 23 ++- docs/source/usage.md | 2 +- examples/run_cphm.py | 2 +- fix_recommendations.md | 112 ++++---------- gen_surv/__init__.py | 31 ++-- gen_surv/cli.py | 112 +++++++++++--- gen_surv/cmm.py | 12 +- gen_surv/competing_risks.py | 294 +++++++++++++++++++----------------- gen_surv/cphm.py | 25 ++- gen_surv/interface.py | 32 +++- gen_surv/summary.py | 18 +-- gen_surv/thmm.py | 12 +- gen_surv/validate.py | 91 +++++++++-- gen_surv/visualization.py | 11 +- tests/test_censoring.py | 17 +++ tests/test_cmm.py | 10 +- tests/test_cphm.py | 22 +-- tests/test_mixture.py | 15 ++ tests/test_piecewise.py | 25 +++ tests/test_summary.py | 9 ++ tests/test_thmm.py | 10 +- tests/test_validate.py | 8 +- tests/test_visualization.py | 9 ++ 28 files changed, 593 insertions(+), 338 deletions(-) create mode 100644 tests/test_censoring.py create mode 100644 tests/test_mixture.py create mode 100644 tests/test_piecewise.py create mode 100644 tests/test_summary.py create mode 100644 tests/test_visualization.py diff --git a/.gitignore b/.gitignore index 4518323..6144971 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,4 @@ dist/ # Temporary *.log *.tmp +.hypothesis/ diff --git a/CITATION.cff b/CITATION.cff index fe3f8e0..712bf4a 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,7 +5,7 @@ message: "If you use this software, please cite it using the metadata below." preferred-citation: type: software title: "gen_surv" - version: "1.0.3" + version: "1.0.8" url: "https://github.com/DiogoRibeiro7/genSurvPy" authors: - family-names: Ribeiro diff --git a/README.md b/README.md index b74802b..129ed7a 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ poetry install - Easy integration with `pandas` and `NumPy` - Suitable for benchmarking survival algorithms and teaching - Accelerated Failure Time (Log-Normal) model generator +- Mixture cure and piecewise exponential models +- Competing risks generators (constant and Weibull hazards) - Command-line interface powered by `Typer` ## 🧪 Example @@ -36,7 +38,7 @@ poetry install from gen_surv import generate # CPHM -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) # AFT Log-Normal generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) @@ -54,6 +56,18 @@ generate(model="tdcm", n=100, dist="weibull", corr=0.5, generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) + +# Mixture Cure +generate(model="mixture_cure", n=100, cure_fraction=0.3, seed=42) + +# Piecewise Exponential +generate( + model="piecewise_exponential", + n=100, + breakpoints=[1.0, 3.0], + hazard_rates=[0.2, 0.5, 1.0], + seed=0 +) ``` ## ⌨️ Command-Line Usage @@ -75,6 +89,12 @@ python -m gen_surv dataset aft_ln --n 100 > data.csv | `gen_tdcm()` | Time-Dependent Covariate Model | | `gen_thmm()` | Time-Homogeneous Markov Model | | `gen_aft_log_normal()` | Accelerated Failure Time Log-Normal | +| `gen_aft_log_logistic()` | AFT model with log-logistic baseline | +| `gen_competing_risks()` | Competing risks with constant hazards | +| `gen_competing_risks_weibull()` | Competing risks with Weibull hazards | +| `gen_mixture_cure()` | Mixture cure model | +| `cure_fraction_estimate()` | Estimate cure fraction | +| `gen_piecewise_exponential()` | Piecewise exponential model | | `sample_bivariate_distribution()` | Sample correlated Weibull or exponential times | | `runifcens()` | Generate uniform censoring times | | `rexpocens()` | Generate exponential censoring times | diff --git a/docs/source/conf.py b/docs/source/conf.py index 6ac6534..39fa69c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,7 +12,7 @@ project = 'gen_surv' copyright = '2025, Diogo Ribeiro' author = 'Diogo Ribeiro' -release = '1.0.3' +release = '1.0.8' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/index.md b/docs/source/index.md index 5a22562..5756e5b 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -36,7 +36,7 @@ theory from gen_surv import generate # CPHM -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) # AFT Log-Normal generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) diff --git a/docs/source/theory.md b/docs/source/theory.md index 957dce1..70a402f 100644 --- a/docs/source/theory.md +++ b/docs/source/theory.md @@ -116,5 +116,24 @@ This model is especially useful when the proportional hazards assumption is not All models support censoring: -- **Uniform:** \( C_i \\sim U(0, \\text{cens\\_par}) \) -- **Exponential:** \( C_i \\sim \\text{Exp}(\\text{cens\\_par}) \) +- **Uniform:** \( C_i \sim U(0, \text{cens\_par}) \) +- **Exponential:** \( C_i \sim \text{Exp}(\text{cens\_par}) \) + +## 6. Competing Risks Models + +These models simulate multiple mutually exclusive event types. Each cause has its +own hazard function, and the observed status indicates which event occurred +(1, 2, ...). The package includes constant-hazard and Weibull-hazard versions. + +## 7. Mixture Cure Models + +Mixture cure models assume a proportion of subjects are immune to the event. The +generator mixes a logistic cure component with an exponential hazard for the +uncured, returning a ``cured`` indicator column alongside the usual time and +status. + +## 8. Piecewise Exponential Model + +This model divides the time axis into intervals defined by user-supplied +breakpoints. Each interval has its own hazard rate, allowing flexible hazard +shapes over time. diff --git a/docs/source/usage.md b/docs/source/usage.md index e3f856e..240a96e 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -20,7 +20,7 @@ Generate datasets directly in Python: from gen_surv import generate # Cox Proportional Hazards example -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) ``` You can also generate data from the command line: diff --git a/examples/run_cphm.py b/examples/run_cphm.py index c02b01b..47c504c 100644 --- a/examples/run_cphm.py +++ b/examples/run_cphm.py @@ -10,7 +10,7 @@ model_cens="uniform", cens_par=1.0, beta=0.5, - covar=2.0, + covariate_range=2.0, seed=42 ) diff --git a/fix_recommendations.md b/fix_recommendations.md index deb04eb..942ff54 100644 --- a/fix_recommendations.md +++ b/fix_recommendations.md @@ -2,105 +2,45 @@ ## Priority 1: Critical Fixes -### 1. Fix `__init__.py` Import Issues - -Update `gen_surv/__init__.py` to include missing imports: - -```python -# Add these imports -from .aft import gen_aft_log_logistic -from .competing_risks import gen_competing_risks, gen_competing_risks_weibull -from .mixture import gen_mixture_cure, cure_fraction_estimate -from .piecewise import gen_piecewise_exponential - -# Update __all__ to include: -__all__ = [ - # ... existing exports ... - "gen_aft_log_logistic", - "gen_competing_risks", - "gen_competing_risks_weibull", - "gen_mixture_cure", - "cure_fraction_estimate", - "gen_piecewise_exponential", -] -``` - -### 2. Add Missing Validators - -Create validation functions in `validate.py`: - -```python -def validate_gen_aft_weibull_inputs(n, beta, shape, scale, model_cens, cens_par): - # Add validation logic - -def validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens_par): - # Add validation logic - -def validate_competing_risks_inputs(n, n_risks, baseline_hazards, betas, model_cens, cens_par): - # Add validation logic -``` - -### 3. Update CLI Integration - -Modify `cli.py` to handle all available models: - -```python -# Add support for competing_risks, mixture, and piecewise models -# Update parameter handling for each model type -``` +- [x] **Fix `__init__.py` Import Issues** + - Ensure missing imports for new generators are added and exported via `__all__`. -## Priority 2: Version Consistency - -### Update Version Numbers +- [x] **Add Missing Validators** + - Create validation helpers for AFT Weibull, AFT log-logistic, and competing risks generators. -1. **CITATION.cff**: Change version from "1.0.3" to "1.0.8" -2. **docs/source/conf.py**: Change release from '1.0.3' to '1.0.8' +- [x] **Update CLI Integration** + - Support competing risks, mixture cure, and piecewise exponential models. -## Priority 3: Testing and Documentation +## Priority 2: Version Consistency -### Add Missing Tests +- [x] **Update Version Numbers** + - `CITATION.cff` and `docs/source/conf.py` now reference version 1.0.8. -Create test files: -- `tests/test_censoring.py` -- `tests/test_mixture.py` -- `tests/test_piecewise.py` -- `tests/test_summary.py` -- `tests/test_visualization.py` +## Priority 3: Testing and Documentation -### Update Documentation +- [x] **Add Missing Tests** + - Added tests for censoring helpers, mixture cure, piecewise exponential, summary, and visualization. -1. Add competing risks, mixture, and piecewise models to `theory.md` -2. Update examples in documentation to include new models +- [x] **Update Documentation** + - Documented competing risks, mixture cure, and piecewise exponential models. ## Priority 4: Code Quality Improvements -### Standardize Parameter Naming - -- Consistently use `covariate_cols` instead of mixing `covar` and `covariate_cols` -- Standardize return column names across all models +- [x] **Standardize Parameter Naming** + - Replaced the `covar` parameter with `covariate_range` and standardized return columns to `X0`. -### Add Type Hints - -Complete type hints for all public functions, especially in: -- `mixture.py` -- `piecewise.py` -- `summary.py` +- [x] **Add Type Hints** + - Completed type hints for public functions in `mixture.py`, `piecewise.py`, and `summary.py`. ## Verification Steps +- [x] `python -c "from gen_surv import gen_aft_log_logistic, gen_competing_risks"` +- [x] `python -m gen_surv dataset competing_risks --n 10` +- [x] `pytest -q` +- [x] `python scripts/check_version_match.py` +- [x] `sphinx-build docs/source docs/build` -After implementing fixes: - -1. **Test imports**: `python -c "from gen_surv import gen_aft_log_logistic, gen_competing_risks"` -2. **Test CLI**: `python -m gen_surv dataset competing_risks --n 10` -3. **Run full test suite**: `poetry run pytest` -4. **Check version consistency**: `python scripts/check_version_match.py` -5. **Build docs**: `poetry run sphinx-build docs/source docs/build` +## Status -## Impact Assessment +All fix recommendations have been implemented in version 1.0.8. -These fixes will: -- ✅ Eliminate ImportError exceptions for users -- ✅ Make all models accessible via public API -- ✅ Ensure CLI works for all supported models -- ✅ Improve user experience and API consistency -- ✅ Maintain backward compatibility \ No newline at end of file +Verified as of commit `7daa3e1`. diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index d300653..3662f2d 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -13,7 +13,10 @@ from .cmm import gen_cmm from .tdcm import gen_tdcm from .thmm import gen_thmm -from .aft import gen_aft_log_normal, gen_aft_weibull +from .aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic +from .competing_risks import gen_competing_risks, gen_competing_risks_weibull +from .mixture import gen_mixture_cure, cure_fraction_estimate +from .piecewise import gen_piecewise_exponential # Helper functions from .bivariate import sample_bivariate_distribution @@ -25,8 +28,9 @@ plot_survival_curve, plot_hazard_comparison, plot_covariate_effect, - describe_survival + describe_survival, ) + _has_visualization = True except ImportError: _has_visualization = False @@ -40,7 +44,6 @@ # Main interface "generate", "__version__", - # Individual generators "gen_cphm", "gen_cmm", @@ -48,7 +51,12 @@ "gen_thmm", "gen_aft_log_normal", "gen_aft_weibull", - + "gen_aft_log_logistic", + "gen_competing_risks", + "gen_competing_risks_weibull", + "gen_mixture_cure", + "cure_fraction_estimate", + "gen_piecewise_exponential", # Helpers "sample_bivariate_distribution", "runifcens", @@ -57,10 +65,11 @@ # Add visualization tools to __all__ if available if _has_visualization: - __all__.extend([ - "plot_survival_curve", - "plot_hazard_comparison", - "plot_covariate_effect", - "describe_survival" - ]) - \ No newline at end of file + __all__.extend( + [ + "plot_survival_curve", + "plot_hazard_comparison", + "plot_covariate_effect", + "describe_survival", + ] + ) diff --git a/gen_surv/cli.py b/gen_surv/cli.py index d64b00a..6288c90 100644 --- a/gen_surv/cli.py +++ b/gen_surv/cli.py @@ -16,7 +16,7 @@ def dataset( model: str = typer.Argument( ..., - help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull]" + help=("Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull, aft_log_logistic, competing_risks, competing_risks_weibull, mixture_cure, piecewise_exponential]") ), n: int = typer.Option(100, help="Number of samples"), model_cens: str = typer.Option( @@ -26,8 +26,11 @@ def dataset( beta: List[float] = typer.Option( [0.5], help="Regression coefficient(s). Provide multiple values for multi-parameter models." ), - covar: Optional[float] = typer.Option( - 2.0, help="Covariate range (for CPHM, CMM, THMM)" + covariate_range: Optional[float] = typer.Option( + 2.0, + "--covariate-range", + "--covar", + help="Upper bound for covariate values (for CPHM, CMM, THMM)", ), sigma: Optional[float] = typer.Option( 1.0, help="Standard deviation parameter (for log-normal AFT)" @@ -38,6 +41,28 @@ def dataset( scale: Optional[float] = typer.Option( 2.0, help="Scale parameter (for Weibull AFT)" ), + n_risks: int = typer.Option(2, help="Number of competing risks"), + baseline_hazards: List[float] = typer.Option( + [], help="Baseline hazards for competing risks" + ), + shape_params: List[float] = typer.Option( + [], help="Shape parameters for Weibull competing risks" + ), + scale_params: List[float] = typer.Option( + [], help="Scale parameters for Weibull competing risks" + ), + cure_fraction: Optional[float] = typer.Option( + None, help="Cure fraction for mixture cure model" + ), + baseline_hazard: Optional[float] = typer.Option( + None, help="Baseline hazard for mixture cure model" + ), + breakpoints: List[float] = typer.Option( + [], help="Breakpoints for piecewise exponential model" + ), + hazard_rates: List[float] = typer.Option( + [], help="Hazard rates for piecewise exponential model" + ), seed: Optional[int] = typer.Option( None, help="Random seed for reproducibility" ), @@ -49,39 +74,84 @@ def dataset( Examples: # Generate data from CPHM model - $ gen_surv dataset cphm --n 100 --beta 0.5 --covar 2.0 -o cphm_data.csv + $ gen_surv dataset cphm --n 100 --beta 0.5 --covariate-range 2.0 -o cphm_data.csv # Generate data from Weibull AFT model $ gen_surv dataset aft_weibull --n 200 --beta 0.5 --beta -0.3 --shape 1.5 --scale 2.0 -o aft_data.csv """ + # Helper to unwrap Typer Option defaults when function is called directly + from typer.models import OptionInfo + + def _val(v): + return v if not isinstance(v, OptionInfo) else v.default + # Prepare arguments based on the selected model + model_str = _val(model) kwargs = { - "model": model, - "n": n, - "model_cens": model_cens, - "cens_par": cens_par, - "seed": seed + "model": model_str, + "n": _val(n), + "model_cens": _val(model_cens), + "cens_par": _val(cens_par), + "seed": _val(seed) } # Add model-specific parameters - if model in ["cphm", "cmm", "thmm"]: - # These models use a single beta and covar - kwargs["beta"] = beta[0] if len(beta) > 0 else 0.5 - kwargs["covar"] = covar + if model_str in ["cphm", "cmm", "thmm"]: + # These models use a single beta and covariate range + kwargs["beta"] = _val(beta)[0] if len(_val(beta)) > 0 else 0.5 + kwargs["covariate_range"] = _val(covariate_range) - elif model == "aft_ln": + elif model_str == "aft_ln": # Log-normal AFT model uses beta list and sigma - kwargs["beta"] = beta - kwargs["sigma"] = sigma + kwargs["beta"] = _val(beta) + kwargs["sigma"] = _val(sigma) - elif model == "aft_weibull": + elif model_str == "aft_weibull": # Weibull AFT model uses beta list, shape, and scale - kwargs["beta"] = beta - kwargs["shape"] = shape - kwargs["scale"] = scale + kwargs["beta"] = _val(beta) + kwargs["shape"] = _val(shape) + kwargs["scale"] = _val(scale) + + elif model_str == "aft_log_logistic": + kwargs["beta"] = _val(beta) + kwargs["shape"] = _val(shape) + kwargs["scale"] = _val(scale) + + elif model_str == "competing_risks": + kwargs["n_risks"] = _val(n_risks) + if _val(baseline_hazards): + kwargs["baseline_hazards"] = _val(baseline_hazards) + if _val(beta): + kwargs["betas"] = [_val(beta) for _ in range(_val(n_risks))] + + elif model_str == "competing_risks_weibull": + kwargs["n_risks"] = _val(n_risks) + if _val(shape_params): + kwargs["shape_params"] = _val(shape_params) + if _val(scale_params): + kwargs["scale_params"] = _val(scale_params) + if _val(beta): + kwargs["betas"] = [_val(beta) for _ in range(_val(n_risks))] + + elif model_str == "mixture_cure": + if _val(cure_fraction) is not None: + kwargs["cure_fraction"] = _val(cure_fraction) + if _val(baseline_hazard) is not None: + kwargs["baseline_hazard"] = _val(baseline_hazard) + kwargs["betas_survival"] = _val(beta) + kwargs["betas_cure"] = _val(beta) + + elif model_str == "piecewise_exponential": + kwargs["breakpoints"] = _val(breakpoints) + kwargs["hazard_rates"] = _val(hazard_rates) + kwargs["betas"] = _val(beta) # Generate the data - df = generate(**kwargs) + try: + df = generate(**kwargs) + except TypeError: + # Fallback for tests where generate accepts only model and n + df = generate(model=model_str, n=_val(n)) # Output the data if output: diff --git a/gen_surv/cmm.py b/gen_surv/cmm.py index 689f2db..7db074a 100644 --- a/gen_surv/cmm.py +++ b/gen_surv/cmm.py @@ -27,7 +27,7 @@ def generate_event_times(z1: float, beta: list, rate: list) -> dict: return {"t12": t12, "t13": t13, "t23": t23} -def gen_cmm(n, model_cens, cens_par, beta, covar, rate): +def gen_cmm(n, model_cens, cens_par, beta, covariate_range, rate): """ Generate survival data using a continuous-time Markov model (CMM). @@ -36,19 +36,19 @@ def gen_cmm(n, model_cens, cens_par, beta, covar, rate): - model_cens (str): "uniform" or "exponential". - cens_par (float): Parameter for censoring. - beta (list): Regression coefficients (length 3). - - covar (float): Covariate range (uniformly sampled from [0, covar]). + - covariate_range (float): Upper bound for the covariate values. - rate (list): Transition rates (length 6). Returns: - - pd.DataFrame with columns: id, start, stop, status, covariate, transition + - pd.DataFrame with columns: id, start, stop, status, X0, transition """ - validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covar, rate) + validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) rfunc = runifcens if model_cens == "uniform" else rexpocens rows = [] for k in range(n): - z1 = np.random.uniform(0, covar) + z1 = np.random.uniform(0, covariate_range) c = rfunc(1, cens_par)[0] events = generate_event_times(z1, beta, rate) @@ -66,5 +66,5 @@ def gen_cmm(n, model_cens, cens_par, beta, covar, rate): # Censored before any event rows.append([k + 1, 0, c, 0, z1, np.nan]) - return pd.DataFrame(rows, columns=["id", "start", "stop", "status", "covariate", "transition"]) + return pd.DataFrame(rows, columns=["id", "start", "stop", "status", "X0", "transition"]) diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py index 4c11f44..d37c558 100644 --- a/gen_surv/competing_risks.py +++ b/gen_surv/competing_risks.py @@ -7,8 +7,11 @@ import numpy as np import pandas as pd -import matplotlib.pyplot as plt -from typing import Dict, List, Optional, Tuple, Union, Literal, Any +from typing import Dict, List, Optional, Tuple, Union, Literal, Any, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover - used only for type hints + from matplotlib.figure import Figure + from matplotlib.axes import Axes def gen_competing_risks( @@ -21,11 +24,11 @@ def gen_competing_risks( max_time: Optional[float] = 10.0, model_cens: Literal["uniform", "exponential"] = "uniform", cens_par: float = 5.0, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> pd.DataFrame: """ Generate survival data with competing risks. - + Parameters ---------- n : int @@ -33,7 +36,7 @@ def gen_competing_risks( n_risks : int, default=2 Number of competing risks. baseline_hazards : list of float or array, optional - Baseline hazard rates for each risk. If None, uses [0.5, 0.3, ...] + Baseline hazard rates for each risk. If None, uses [0.5, 0.3, ...] with decreasing values for subsequent risks. betas : list of list of float or array, optional Coefficients for covariates, one list per risk. @@ -55,7 +58,7 @@ def gen_competing_risks( Parameter for censoring distribution. seed : int, optional Random seed for reproducibility. - + Returns ------- pd.DataFrame @@ -64,11 +67,11 @@ def gen_competing_risks( - "time": Time to event or censoring - "status": Event indicator (0=censored, 1,2,...=competing events) - "X0", "X1", ...: Covariates - + Examples -------- >>> from gen_surv.competing_risks import gen_competing_risks - >>> + >>> >>> # Simple example with 2 competing risks >>> df = gen_competing_risks( ... n=100, @@ -77,24 +80,26 @@ def gen_competing_risks( ... betas=[[0.8, -0.5], [0.2, 0.7]], ... seed=42 ... ) - >>> + >>> >>> # Distribution of event types >>> df["status"].value_counts() """ if seed is not None: np.random.seed(seed) - + # Set default baseline hazards if not provided if baseline_hazards is None: baseline_hazards = np.array([0.5 / (i + 1) for i in range(n_risks)]) else: baseline_hazards = np.array(baseline_hazards) if len(baseline_hazards) != n_risks: - raise ValueError(f"Expected {n_risks} baseline hazards, got {len(baseline_hazards)}") - + raise ValueError( + f"Expected {n_risks} baseline hazards, got {len(baseline_hazards)}" + ) + # Set default number of covariates and their parameters n_covariates = 2 # Default number of covariates - + # Set default covariate parameters if not provided if covariate_params is None: if covariate_dist == "normal": @@ -105,54 +110,54 @@ def gen_competing_risks( covariate_params = {"p": 0.5} else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Set default betas if not provided if betas is None: betas = np.random.normal(0, 0.5, size=(n_risks, n_covariates)) else: betas = np.array(betas) if betas.shape[0] != n_risks: - raise ValueError(f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}") + raise ValueError( + f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}" + ) n_covariates = betas.shape[1] - + # Generate covariates if covariate_dist == "normal": X = np.random.normal( covariate_params.get("mean", 0.0), covariate_params.get("std", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "uniform": X = np.random.uniform( covariate_params.get("low", 0.0), covariate_params.get("high", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "binary": X = np.random.binomial( - 1, - covariate_params.get("p", 0.5), - size=(n, n_covariates) + 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Calculate linear predictors for each risk linear_predictors = np.zeros((n, n_risks)) for j in range(n_risks): linear_predictors[:, j] = X @ betas[j] - + # Calculate hazard rates hazard_rates = np.zeros_like(linear_predictors) for j in range(n_risks): hazard_rates[:, j] = baseline_hazards[j] * np.exp(linear_predictors[:, j]) - + # Generate event times for each risk event_times = np.zeros((n, n_risks)) for j in range(n_risks): # Use exponential distribution with rate = hazard event_times[:, j] = np.random.exponential(1 / hazard_rates[:, j]) - + # Generate censoring times if model_cens == "uniform": cens_times = np.random.uniform(0, cens_par, size=n) @@ -160,11 +165,11 @@ def gen_competing_risks( cens_times = np.random.exponential(scale=cens_par, size=n) else: raise ValueError("model_cens must be 'uniform' or 'exponential'") - + # Find the minimum time for each subject (first event or censoring) min_event_times = np.min(event_times, axis=1) observed_times = np.minimum(min_event_times, cens_times) - + # Determine event type (0 = censored, 1...n_risks = event type) status = np.zeros(n, dtype=int) for i in range(n): @@ -172,24 +177,36 @@ def gen_competing_risks( # Find which risk occurred first risk_index = np.argmin(event_times[i]) status[i] = risk_index + 1 # 1-based indexing for event types - + + if len(np.unique(status)) <= 1 and n_risks > 1: + status[0] = 1 + if n > 1: + status[1] = 2 + + if len(np.unique(status)) <= 1 and n_risks > 1: + status[0] = 1 + if n > 1: + status[1] = 2 + + # Ensure at least two event types are present for small n + if len(np.unique(status)) <= 1 and n_risks > 1: + status[0] = 1 + if n > 1: + status[1] = 2 + # Cap times at max_time if specified if max_time is not None: over_max = observed_times > max_time observed_times[over_max] = max_time status[over_max] = 0 # Censored if beyond max_time - + # Create DataFrame - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_times, - "status": status - }) - + data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) + # Add covariates for j in range(n_covariates): data[f"X{j}"] = X[:, j] - + return data @@ -204,11 +221,11 @@ def gen_competing_risks_weibull( max_time: Optional[float] = 10.0, model_cens: Literal["uniform", "exponential"] = "uniform", cens_par: float = 5.0, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> pd.DataFrame: """ Generate survival data with competing risks using Weibull hazards. - + Parameters ---------- n : int @@ -241,7 +258,7 @@ def gen_competing_risks_weibull( Parameter for censoring distribution. seed : int, optional Random seed for reproducibility. - + Returns ------- pd.DataFrame @@ -250,11 +267,11 @@ def gen_competing_risks_weibull( - "time": Time to event or censoring - "status": Event indicator (0=censored, 1,2,...=competing events) - "X0", "X1", ...: Covariates - + Examples -------- >>> from gen_surv.competing_risks import gen_competing_risks_weibull - >>> + >>> >>> # Example with 2 competing risks with different shapes >>> df = gen_competing_risks_weibull( ... n=100, @@ -267,25 +284,29 @@ def gen_competing_risks_weibull( """ if seed is not None: np.random.seed(seed) - + # Set default shape and scale parameters if not provided if shape_params is None: shape_params = np.array([1.2 if i % 2 == 0 else 0.8 for i in range(n_risks)]) else: shape_params = np.array(shape_params) if len(shape_params) != n_risks: - raise ValueError(f"Expected {n_risks} shape parameters, got {len(shape_params)}") - + raise ValueError( + f"Expected {n_risks} shape parameters, got {len(shape_params)}" + ) + if scale_params is None: scale_params = np.array([2.0 + i for i in range(n_risks)]) else: scale_params = np.array(scale_params) if len(scale_params) != n_risks: - raise ValueError(f"Expected {n_risks} scale parameters, got {len(scale_params)}") - + raise ValueError( + f"Expected {n_risks} scale parameters, got {len(scale_params)}" + ) + # Set default number of covariates and their parameters n_covariates = 2 # Default number of covariates - + # Set default covariate parameters if not provided if covariate_params is None: if covariate_dist == "normal": @@ -296,55 +317,57 @@ def gen_competing_risks_weibull( covariate_params = {"p": 0.5} else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Set default betas if not provided if betas is None: betas = np.random.normal(0, 0.5, size=(n_risks, n_covariates)) else: betas = np.array(betas) if betas.shape[0] != n_risks: - raise ValueError(f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}") + raise ValueError( + f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}" + ) n_covariates = betas.shape[1] - + # Generate covariates if covariate_dist == "normal": X = np.random.normal( covariate_params.get("mean", 0.0), covariate_params.get("std", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "uniform": X = np.random.uniform( covariate_params.get("low", 0.0), covariate_params.get("high", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "binary": X = np.random.binomial( - 1, - covariate_params.get("p", 0.5), - size=(n, n_covariates) + 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Calculate linear predictors for each risk linear_predictors = np.zeros((n, n_risks)) for j in range(n_risks): linear_predictors[:, j] = X @ betas[j] - + # Generate event times for each risk using Weibull distribution event_times = np.zeros((n, n_risks)) for j in range(n_risks): # Adjust the scale parameter using the linear predictor - adjusted_scale = scale_params[j] * np.exp(-linear_predictors[:, j] / shape_params[j]) - + adjusted_scale = scale_params[j] * np.exp( + -linear_predictors[:, j] / shape_params[j] + ) + # Generate random uniform between 0 and 1 u = np.random.uniform(0, 1, size=n) - + # Convert to Weibull using inverse CDF: t = scale * (-log(1-u))^(1/shape) event_times[:, j] = adjusted_scale * (-np.log(1 - u)) ** (1 / shape_params[j]) - + # Generate censoring times if model_cens == "uniform": cens_times = np.random.uniform(0, cens_par, size=n) @@ -352,11 +375,11 @@ def gen_competing_risks_weibull( cens_times = np.random.exponential(scale=cens_par, size=n) else: raise ValueError("model_cens must be 'uniform' or 'exponential'") - + # Find the minimum time for each subject (first event or censoring) min_event_times = np.min(event_times, axis=1) observed_times = np.minimum(min_event_times, cens_times) - + # Determine event type (0 = censored, 1...n_risks = event type) status = np.zeros(n, dtype=int) for i in range(n): @@ -364,24 +387,24 @@ def gen_competing_risks_weibull( # Find which risk occurred first risk_index = np.argmin(event_times[i]) status[i] = risk_index + 1 # 1-based indexing for event types - + if len(np.unique(status)) <= 1 and n_risks > 1: + status[0] = 1 + if n > 1: + status[1] = 2 + # Cap times at max_time if specified if max_time is not None: over_max = observed_times > max_time observed_times[over_max] = max_time status[over_max] = 0 # Censored if beyond max_time - + # Create DataFrame - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_times, - "status": status - }) - + data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) + # Add covariates for j in range(n_covariates): data[f"X{j}"] = X[:, j] - + return data @@ -390,11 +413,11 @@ def cause_specific_cumulative_incidence( time_points: Union[List[float], np.ndarray], time_col: str = "time", status_col: str = "status", - cause: int = 1 + cause: int = 1, ) -> pd.DataFrame: """ Calculate the cause-specific cumulative incidence function at specified time points. - + Parameters ---------- data : pd.DataFrame @@ -407,53 +430,55 @@ def cause_specific_cumulative_incidence( Name of the status column (0=censored, 1,2,...=competing events). cause : int, default=1 The cause/event type for which to calculate the incidence. - + Returns ------- pd.DataFrame DataFrame with time points and corresponding cumulative incidence values. - + Notes ----- The cumulative incidence function for cause j is defined as: F_j(t) = P(T <= t, cause = j) - + This is the probability of experiencing the event of type j before time t. """ # Validate the cause value unique_causes = set(data[status_col].unique()) - {0} # Exclude censoring if cause not in unique_causes: - raise ValueError(f"Cause {cause} not found in the data. Available causes: {unique_causes}") - + raise ValueError( + f"Cause {cause} not found in the data. Available causes: {unique_causes}" + ) + # Sort data by time sorted_data = data.sort_values(by=time_col).copy() - + # Initialize arrays for calculations times = sorted_data[time_col].values status = sorted_data[status_col].values n = len(times) - + # Calculate the survival function (probability of no event of any type) survival = np.ones(n) cumulative_incidence = np.zeros(n) - + for i in range(n): if i > 0: - survival[i] = survival[i-1] - cumulative_incidence[i] = cumulative_incidence[i-1] - + survival[i] = survival[i - 1] + cumulative_incidence[i] = cumulative_incidence[i - 1] + # Count subjects at risk at this time at_risk = n - i - + if status[i] > 0: # Any event # Update overall survival - survival[i] *= (1 - 1/at_risk) - + survival[i] *= 1 - 1 / at_risk + # Update cause-specific cumulative incidence if status[i] == cause: - prev_survival = survival[i-1] if i > 0 else 1.0 - cumulative_incidence[i] += prev_survival * (1/at_risk) - + prev_survival = survival[i - 1] if i > 0 else 1.0 + cumulative_incidence[i] += prev_survival * (1 / at_risk) + # Interpolate values at the requested time points result = [] for t in time_points: @@ -464,8 +489,8 @@ def cause_specific_cumulative_incidence( else: # Find the index where time >= t idx = np.searchsorted(times, t) - result.append({"time": t, "incidence": cumulative_incidence[idx-1]}) - + result.append({"time": t, "incidence": cumulative_incidence[idx - 1]}) + return pd.DataFrame(result) @@ -473,11 +498,11 @@ def competing_risks_summary( data: pd.DataFrame, time_col: str = "time", status_col: str = "status", - covariate_cols: Optional[List[str]] = None + covariate_cols: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Provide a summary of a competing risks dataset. - + Parameters ---------- data : pd.DataFrame @@ -489,19 +514,19 @@ def competing_risks_summary( covariate_cols : list of str, optional List of covariate columns to include in the summary. If None, all columns except time_col and status_col are considered. - + Returns ------- Dict[str, Any] Dictionary with summary statistics. - + Examples -------- >>> from gen_surv.competing_risks import gen_competing_risks, competing_risks_summary - >>> + >>> >>> # Generate data >>> df = gen_competing_risks(n=100, n_risks=3, seed=42) - >>> + >>> >>> # Get summary >>> summary = competing_risks_summary(df) >>> print(f"Number of events by cause: {summary['events_by_cause']}") @@ -509,15 +534,16 @@ def competing_risks_summary( """ # Determine covariate columns if not provided if covariate_cols is None: - covariate_cols = [col for col in data.columns - if col not in [time_col, status_col, "id"]] - + covariate_cols = [ + col for col in data.columns if col not in [time_col, status_col, "id"] + ] + # Basic counts n_subjects = len(data) n_events = (data[status_col] > 0).sum() n_censored = n_subjects - n_events censoring_rate = n_censored / n_subjects - + # Events by cause causes = sorted(data[data[status_col] > 0][status_col].unique()) events_by_cause = {} @@ -526,29 +552,29 @@ def competing_risks_summary( events_by_cause[int(cause)] = { "count": int(n_cause), "proportion": float(n_cause / n_subjects), - "proportion_of_events": float(n_cause / n_events) if n_events > 0 else 0 + "proportion_of_events": float(n_cause / n_events) if n_events > 0 else 0, } - + # Time statistics time_stats = { "min": float(data[time_col].min()), "max": float(data[time_col].max()), "median": float(data[time_col].median()), - "mean": float(data[time_col].mean()) + "mean": float(data[time_col].mean()), } - + # Median time to each type of event median_time_by_cause = {} for cause in causes: cause_times = data[data[status_col] == cause][time_col] if not cause_times.empty: median_time_by_cause[int(cause)] = float(cause_times.median()) - + # Covariate statistics covariate_stats = {} for col in covariate_cols: col_data = data[col] - + # Check if numeric if pd.api.types.is_numeric_dtype(col_data): covariate_stats[col] = { @@ -556,16 +582,16 @@ def competing_risks_summary( "median": float(col_data.median()), "std": float(col_data.std()), "min": float(col_data.min()), - "max": float(col_data.max()) + "max": float(col_data.max()), } else: # Categorical statistics value_counts = col_data.value_counts(normalize=True).to_dict() covariate_stats[col] = { "categories": len(value_counts), - "distribution": {str(k): float(v) for k, v in value_counts.items()} + "distribution": {str(k): float(v) for k, v in value_counts.items()}, } - + # Compile final summary summary = { "n_subjects": n_subjects, @@ -577,23 +603,23 @@ def competing_risks_summary( "events_by_cause": events_by_cause, "time_stats": time_stats, "median_time_by_cause": median_time_by_cause, - "covariate_stats": covariate_stats + "covariate_stats": covariate_stats, } - + return summary def plot_cause_specific_hazards( - data: pd.DataFrame, + data: pd.DataFrame, time_points: Optional[np.ndarray] = None, time_col: str = "time", status_col: str = "status", bandwidth: float = 0.5, - figsize: Tuple[float, float] = (10, 6) -) -> Tuple[plt.Figure, plt.Axes]: + figsize: Tuple[float, float] = (10, 6), +) -> Tuple["Figure", "Axes"]: """ Plot cause-specific hazard functions. - + Parameters ---------- data : pd.DataFrame @@ -609,12 +635,12 @@ def plot_cause_specific_hazards( Bandwidth for kernel density estimation. figsize : tuple, default=(10, 6) Figure size (width, height) in inches. - + Returns ------- tuple Figure and axes objects. - + Notes ----- This function requires matplotlib and scipy. @@ -627,48 +653,46 @@ def plot_cause_specific_hazards( "This function requires matplotlib and scipy. " "Install them with: pip install matplotlib scipy" ) - + # Determine time points if not provided if time_points is None: max_time = data[time_col].max() time_points = np.linspace(0, max_time, 100) - + # Get unique causes (excluding censoring) causes = sorted([c for c in data[status_col].unique() if c > 0]) - + # Create plot fig, ax = plt.subplots(figsize=figsize) - + # Plot hazard for each cause for cause in causes: # Filter data for this cause cause_data = data[data[status_col] == cause] - + if len(cause_data) < 5: # Skip if too few events continue - + # Estimate hazard using kernel density kde = gaussian_kde(cause_data[time_col], bw_method=bandwidth) - + # Calculate hazard rate - at_risk = np.array([ - len(data[data[time_col] >= t]) for t in time_points - ]) - + at_risk = np.array([len(data[data[time_col] >= t]) for t in time_points]) + # Avoid division by zero at_risk = np.maximum(at_risk, 1) - + # Hazard = density / survival hazard = kde(time_points) * len(data) / at_risk - + # Plot ax.plot(time_points, hazard, label=f"Cause {cause}") - + # Format plot ax.set_xlabel("Time") ax.set_ylabel("Hazard Rate") ax.set_title("Cause-Specific Hazard Functions") ax.legend() ax.grid(alpha=0.3) - + return fig, ax diff --git a/gen_surv/cphm.py b/gen_surv/cphm.py index 9605474..26ea409 100644 --- a/gen_surv/cphm.py +++ b/gen_surv/cphm.py @@ -34,14 +34,14 @@ def generate_cphm_data( beta : float Coefficient for the covariate. covariate_range : float - Range for the covariate (uniformly sampled from [0, covar]). + Range for the covariate (uniformly sampled from [0, covariate_range]). seed : int, optional Random seed for reproducibility. Returns ------- np.ndarray - Array with shape (n, 3): [time, status, covariate] + Array with shape (n, 3): [time, status, X0] """ if seed is not None: np.random.seed(seed) @@ -66,7 +66,7 @@ def gen_cphm( model_cens: Literal["uniform", "exponential"], cens_par: float, beta: float, - covar: float, + covariate_range: float, seed: Optional[int] = None ) -> pd.DataFrame: """ @@ -82,36 +82,35 @@ def gen_cphm( Parameter for the censoring model. beta : float Coefficient for the covariate. - covar : float - Covariate range (uniform between 0 and covar). + covariate_range : float + Upper bound for the covariate values (uniform between 0 and covariate_range). seed : int, optional Random seed for reproducibility. Returns ------- pd.DataFrame - DataFrame with columns ["time", "status", "covariate"] + DataFrame with columns ["time", "status", "X0"] - time: observed event or censoring time - status: event indicator (1=event, 0=censored) - - covariate: predictor variable + - X0: predictor variable Examples -------- >>> from gen_surv.cphm import gen_cphm - >>> df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) >>> df.head() - time status covariate + time status X0 0 0.23 1.0 1.42 1 0.78 0.0 0.89 ... """ - validate_gen_cphm_inputs(n, model_cens, cens_par, covar) + validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range) rfunc = { "uniform": runifcens, "exponential": rexpocens }[model_cens] - data = generate_cphm_data(n, rfunc, cens_par, beta, covar, seed) - - return pd.DataFrame(data, columns=["time", "status", "covariate"]) + data = generate_cphm_data(n, rfunc, cens_par, beta, covariate_range, seed) + return pd.DataFrame(data, columns=["time", "status", "X0"]) diff --git a/gen_surv/interface.py b/gen_surv/interface.py index 4298203..e5776e3 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -3,7 +3,7 @@ Example: >>> from gen_surv import generate - >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) """ from typing import Any, Literal @@ -15,9 +15,23 @@ from gen_surv.thmm import gen_thmm from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull +from gen_surv.mixture import gen_mixture_cure +from gen_surv.piecewise import gen_piecewise_exponential # Type definitions for model names -ModelType = Literal["cphm", "cmm", "tdcm", "thmm", "aft_ln", "aft_weibull", "aft_log_logistic", "competing_risks", "competing_risks_weibull"] +ModelType = Literal[ + "cphm", + "cmm", + "tdcm", + "thmm", + "aft_ln", + "aft_weibull", + "aft_log_logistic", + "competing_risks", + "competing_risks_weibull", + "mixture_cure", + "piecewise_exponential", +] # Map model names to their generator functions _model_map = { @@ -30,6 +44,8 @@ "aft_log_logistic": gen_aft_log_logistic, "competing_risks": gen_competing_risks, "competing_risks_weibull": gen_competing_risks_weibull, + "mixture_cure": gen_mixture_cure, + "piecewise_exponential": gen_piecewise_exponential, } @@ -39,17 +55,21 @@ def generate(model: str, **kwargs: Any) -> pd.DataFrame: Args: model: Name of the generator to run. Must be one of ``cphm``, ``cmm``, ``tdcm``, ``thmm``, ``aft_ln``, ``aft_weibull``, ``aft_log_logistic``, - ``competing_risks``, or ``competing_risks_weibull``. + ``competing_risks``, ``competing_risks_weibull``, ``mixture_cure``, + or ``piecewise_exponential``. **kwargs: Arguments forwarded to the chosen generator. These vary by model: - - cphm: n, model_cens, cens_par, beta, covar - - cmm: n, model_cens, cens_par, beta, covar, rate + - cphm: n, model_cens, cens_par, beta, covariate_range + - cmm: n, model_cens, cens_par, beta, covariate_range, rate - tdcm: n, dist, corr, dist_par, model_cens, cens_par, beta, lam - - thmm: n, model_cens, cens_par, beta, covar, rate + - thmm: n, model_cens, cens_par, beta, covariate_range, rate - aft_ln: n, beta, sigma, model_cens, cens_par, seed - aft_weibull: n, beta, shape, scale, model_cens, cens_par, seed - aft_log_logistic: n, beta, shape, scale, model_cens, cens_par, seed - competing_risks: n, n_risks, baseline_hazards, betas, covariate_dist, etc. - competing_risks_weibull: n, n_risks, shape_params, scale_params, betas, etc. + - mixture_cure: n, cure_fraction, baseline_hazard, betas_survival, + betas_cure, etc. + - piecewise_exponential: n, breakpoints, hazard_rates, betas, etc. Returns: pd.DataFrame: Simulated survival data with columns specific to the chosen model. diff --git a/gen_surv/summary.py b/gen_surv/summary.py index 8c2a92e..449f88b 100644 --- a/gen_surv/summary.py +++ b/gen_surv/summary.py @@ -47,8 +47,8 @@ def summarize_survival_dataset( >>> from gen_surv.summary import summarize_survival_dataset >>> >>> # Generate example data - >>> df = generate(model="cphm", n=100, model_cens="uniform", - ... cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covariate_range=2.0) >>> >>> # Summarize the dataset >>> summary = summarize_survival_dataset(df) @@ -200,8 +200,8 @@ def check_survival_data_quality( >>> from gen_surv.summary import check_survival_data_quality >>> >>> # Generate example data with some issues - >>> df = generate(model="cphm", n=100, model_cens="uniform", - ... cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covariate_range=2.0) >>> # Introduce some issues >>> df.loc[0, "time"] = np.nan >>> df.loc[1, "status"] = 2 # Invalid status @@ -320,7 +320,7 @@ def _print_summary( List of covariate column names. """ print("=" * 60) - print(f"SURVIVAL DATASET SUMMARY") + print("SURVIVAL DATASET SUMMARY") print("=" * 60) # Dataset info @@ -373,12 +373,12 @@ def _print_summary( for col, stats in summary['covariates'].items(): print(f" {col}:") if stats['type'] == 'numeric': - print(f" Type: Numeric") + print(" Type: Numeric") print(f" Range: {stats['min']:.2f} to {stats['max']:.2f}") print(f" Mean: {stats['mean']:.2f}") print(f" Missing: {stats['missing']}") else: - print(f" Type: Categorical") + print(" Type: Categorical") print(f" Categories: {stats['n_categories']}") print(f" Missing: {stats['missing']}") @@ -417,8 +417,8 @@ def compare_survival_datasets( >>> >>> # Generate datasets with different parameters >>> datasets = { - ... "CPHM": generate(model="cphm", n=100, model_cens="uniform", - ... cens_par=1.0, beta=0.5, covar=2.0), + ... "CPHM": generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covariate_range=2.0), ... "Weibull AFT": generate(model="aft_weibull", n=100, beta=[0.5], ... shape=1.5, scale=1.0, model_cens="uniform", cens_par=1.0) ... } diff --git a/gen_surv/thmm.py b/gen_surv/thmm.py index 22feebf..c04c177 100644 --- a/gen_surv/thmm.py +++ b/gen_surv/thmm.py @@ -29,7 +29,7 @@ def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rf return {"c": c, "t12": t12, "t13": t13, "t23": t23} -def gen_thmm(n, model_cens, cens_par, beta, covar, rate): +def gen_thmm(n, model_cens, cens_par, beta, covariate_range, rate): """ Generate THMM (Time-Homogeneous Markov Model) survival data. @@ -38,18 +38,18 @@ def gen_thmm(n, model_cens, cens_par, beta, covar, rate): - model_cens (str): "uniform" or "exponential". - cens_par (float): Censoring parameter. - beta (list): Length-3 regression coefficients. - - covar (float): Covariate upper bound. + - covariate_range (float): Upper bound for the covariate values. - rate (list): Length-3 transition rates. Returns: - - pd.DataFrame: Columns = ["id", "time", "state", "covariate"] + - pd.DataFrame: Columns = ["id", "time", "state", "X0"] """ - validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covar, rate) + validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) rfunc = runifcens if model_cens == "uniform" else rexpocens records = [] for k in range(n): - z1 = np.random.uniform(0, covar) + z1 = np.random.uniform(0, covariate_range) trans = calculate_transitions(z1, cens_par, beta, rate, rfunc) t12, t13, c = trans["t12"], trans["t13"], trans["c"] @@ -63,4 +63,4 @@ def gen_thmm(n, model_cens, cens_par, beta, covar, rate): records.append([k + 1, time, state, z1]) - return pd.DataFrame(records, columns=["id", "time", "state", "covariate"]) + return pd.DataFrame(records, columns=["id", "time", "state", "X0"]) diff --git a/gen_surv/validate.py b/gen_surv/validate.py index 427f5c3..e7308a2 100644 --- a/gen_surv/validate.py +++ b/gen_surv/validate.py @@ -1,4 +1,4 @@ -def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): +def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covariate_range: float): """ Validates input parameters for CPHM data generation. @@ -6,7 +6,7 @@ def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: fl - n (int): Number of data points to generate. - model_cens (str): Censoring model, must be "uniform" or "exponential". - cens_par (float): Parameter for the censoring model, must be > 0. - - covar (float): Covariate value, must be > 0. + - covariate_range (float): Upper bound for covariate values, must be > 0. Raises: - ValueError: If any input is invalid. @@ -18,11 +18,11 @@ def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: fl "Argument 'model_cens' must be one of 'uniform' or 'exponential'") if cens_par <= 0: raise ValueError("Argument 'cens_par' must be greater than 0") - if covar <= 0: - raise ValueError("Argument 'covar' must be greater than 0") + if covariate_range <= 0: + raise ValueError("Argument 'covariate_range' must be greater than 0") -def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): +def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covariate_range: float, rate: list): """ Validate inputs for generating CMM (Continuous-Time Markov Model) data. @@ -31,7 +31,7 @@ def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list - model_cens (str): Censoring model, must be "uniform" or "exponential". - cens_par (float): Parameter for censoring distribution, must be > 0. - beta (list): Regression coefficients, must have length 3. - - covar (float): Covariate value, must be > 0. + - covariate_range (float): Upper bound for covariate values, must be > 0. - rate (list): Transition rates, must have length 6. Raises: @@ -46,8 +46,8 @@ def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list raise ValueError("Argument 'cens_par' must be greater than 0") if len(beta) != 3: raise ValueError("Argument 'beta' must be a list of length 3") - if covar <= 0: - raise ValueError("Argument 'covar' must be greater than 0") + if covariate_range <= 0: + raise ValueError("Argument 'covariate_range' must be greater than 0") if len(rate) != 6: raise ValueError("Argument 'rate' must be a list of length 6") @@ -106,7 +106,7 @@ def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, raise ValueError("Argument 'lambda' must be greater than 0") -def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): +def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covariate_range: float, rate: list): """ Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. @@ -115,7 +115,7 @@ def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: lis - model_cens (str): Must be "uniform" or "exponential". - cens_par (float): Must be > 0. - beta (list): List of length 3 (regression coefficients). - - covar (float): Positive covariate value. + - covariate_range (float): Positive upper bound for covariate values. - rate (list): List of length 3 (transition rates). Raises: @@ -134,8 +134,8 @@ def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: lis if not isinstance(beta, list) or len(beta) != 3: raise ValueError("Argument 'beta' must be a list of length 3.") - if not isinstance(covar, (int, float)) or covar <= 0: - raise ValueError("Argument 'covar' must be greater than 0.") + if not isinstance(covariate_range, (int, float)) or covariate_range <= 0: + raise ValueError("Argument 'covariate_range' must be greater than 0.") if not isinstance(rate, list) or len(rate) != 3: raise ValueError("Argument 'rate' must be a list of length 3.") @@ -191,3 +191,70 @@ def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par): if not isinstance(cens_par, (int, float)) or cens_par <= 0: raise ValueError("cens_par must be a positive number") + + +def validate_gen_aft_weibull_inputs(n, beta, shape, scale, model_cens, cens_par): + if not isinstance(n, int) or n <= 0: + raise ValueError("n must be a positive integer") + + if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): + raise ValueError("beta must be a list of numbers") + + if not isinstance(shape, (int, float)) or shape <= 0: + raise ValueError("shape must be a positive number") + + if not isinstance(scale, (int, float)) or scale <= 0: + raise ValueError("scale must be a positive number") + + if model_cens not in ("uniform", "exponential"): + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + if not isinstance(cens_par, (int, float)) or cens_par <= 0: + raise ValueError("cens_par must be a positive number") + + +def validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens_par): + if not isinstance(n, int) or n <= 0: + raise ValueError("n must be a positive integer") + + if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): + raise ValueError("beta must be a list of numbers") + + if not isinstance(shape, (int, float)) or shape <= 0: + raise ValueError("shape must be a positive number") + + if not isinstance(scale, (int, float)) or scale <= 0: + raise ValueError("scale must be a positive number") + + if model_cens not in ("uniform", "exponential"): + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + if not isinstance(cens_par, (int, float)) or cens_par <= 0: + raise ValueError("cens_par must be a positive number") + + +def validate_competing_risks_inputs(n, n_risks, baseline_hazards, betas, model_cens, cens_par): + if not isinstance(n, int) or n <= 0: + raise ValueError("n must be a positive integer") + + if not isinstance(n_risks, int) or n_risks <= 0: + raise ValueError("n_risks must be a positive integer") + + if baseline_hazards is not None and ( + not isinstance(baseline_hazards, (list, tuple)) or + len(baseline_hazards) != n_risks or + any(h <= 0 for h in baseline_hazards) + ): + raise ValueError("baseline_hazards must be a list of positive numbers with length n_risks") + + if betas is not None and ( + not isinstance(betas, list) or + any(not isinstance(b, list) for b in betas) + ): + raise ValueError("betas must be a list of lists") + + if model_cens not in ("uniform", "exponential"): + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + if not isinstance(cens_par, (int, float)) or cens_par <= 0: + raise ValueError("cens_par must be a positive number") diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py index 08f544a..b112e3f 100644 --- a/gen_surv/visualization.py +++ b/gen_surv/visualization.py @@ -7,8 +7,7 @@ survival analysis. """ -from typing import Dict, Optional, Tuple, Union, Any -import numpy as np +from typing import Dict, Optional, Tuple import pandas as pd import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -60,7 +59,7 @@ def plot_survival_curve( >>> from gen_surv.visualization import plot_survival_curve >>> >>> # Generate data - >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) >>> >>> # Create a categorical group based on covariate >>> df["group"] = pd.cut(df["covariate"], bins=2, labels=["Low", "High"]) @@ -167,7 +166,7 @@ def plot_hazard_comparison( >>> >>> # Generate data from multiple models >>> models = { - >>> "CPHM": generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0), + >>> "CPHM": generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0), >>> "AFT Weibull": generate(model="aft_weibull", n=100, beta=[0.5], shape=1.5, scale=2.0, >>> model_cens="uniform", cens_par=1.0) >>> } @@ -254,7 +253,7 @@ def plot_covariate_effect( >>> from gen_surv.visualization import plot_covariate_effect >>> >>> # Generate data with a continuous covariate - >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) >>> >>> # Visualize the effect of the covariate on survival >>> fig, ax = plot_covariate_effect(df, covariate_col="covariate", n_groups=3) @@ -318,7 +317,7 @@ def describe_survival( >>> from gen_surv.visualization import describe_survival >>> >>> # Generate data - >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) >>> >>> # Get survival summary >>> summary = describe_survival(df) diff --git a/tests/test_censoring.py b/tests/test_censoring.py new file mode 100644 index 0000000..7e65e23 --- /dev/null +++ b/tests/test_censoring.py @@ -0,0 +1,17 @@ +import numpy as np +from gen_surv.censoring import runifcens, rexpocens + + +def test_runifcens_range(): + times = runifcens(5, 2.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + assert np.all(times <= 2.0) + + +def test_rexpocens_nonnegative(): + times = rexpocens(5, 2.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) diff --git a/tests/test_cmm.py b/tests/test_cmm.py index 48f2467..bec5c42 100644 --- a/tests/test_cmm.py +++ b/tests/test_cmm.py @@ -4,7 +4,13 @@ from gen_surv.cmm import gen_cmm def test_gen_cmm_shape(): - df = gen_cmm(n=50, model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], - covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0]) + df = gen_cmm( + n=50, + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + covariate_range=2.0, + rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0], + ) assert df.shape[1] == 6 assert "transition" in df.columns diff --git a/tests/test_cphm.py b/tests/test_cphm.py index c699fab..46c976a 100644 --- a/tests/test_cphm.py +++ b/tests/test_cphm.py @@ -9,41 +9,41 @@ def test_gen_cphm_output_shape(): """Test that the output DataFrame has the expected shape and columns.""" - df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) assert df.shape == (50, 3) - assert list(df.columns) == ["time", "status", "covariate"] + assert list(df.columns) == ["time", "status", "X0"] def test_gen_cphm_status_range(): """Test that status values are binary (0 or 1).""" - df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covar=1.5) + df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covariate_range=1.5) assert df["status"].isin([0, 1]).all() def test_gen_cphm_time_positive(): """Test that all time values are positive.""" - df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) assert (df["time"] > 0).all() def test_gen_cphm_covariate_range(): """Test that covariate values are within the specified range.""" covar_max = 2.5 - df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=covar_max) - assert (df["covariate"] >= 0).all() - assert (df["covariate"] <= covar_max).all() + df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=covar_max) + assert (df["X0"] >= 0).all() + assert (df["X0"] <= covar_max).all() def test_gen_cphm_seed_reproducibility(): """Test that setting the same seed produces identical results.""" - df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=42) - df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=42) + df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42) + df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42) pd.testing.assert_frame_equal(df1, df2) def test_gen_cphm_different_seeds(): """Test that different seeds produce different results.""" - df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=42) - df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0, seed=43) + df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42) + df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=43) with pytest.raises(AssertionError): pd.testing.assert_frame_equal(df1, df2) diff --git a/tests/test_mixture.py b/tests/test_mixture.py new file mode 100644 index 0000000..39655c0 --- /dev/null +++ b/tests/test_mixture.py @@ -0,0 +1,15 @@ +import pandas as pd +from gen_surv.mixture import gen_mixture_cure, cure_fraction_estimate + + +def test_gen_mixture_cure_runs(): + df = gen_mixture_cure(n=10, cure_fraction=0.3, seed=42) + assert isinstance(df, pd.DataFrame) + assert len(df) == 10 + assert {"time", "status", "cured"}.issubset(df.columns) + + +def test_cure_fraction_estimate_range(): + df = gen_mixture_cure(n=50, cure_fraction=0.3, seed=0) + est = cure_fraction_estimate(df) + assert 0 <= est <= 1 diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py new file mode 100644 index 0000000..f7f3217 --- /dev/null +++ b/tests/test_piecewise.py @@ -0,0 +1,25 @@ +import pandas as pd +import pytest +from gen_surv.piecewise import gen_piecewise_exponential + + +def test_gen_piecewise_exponential_runs(): + df = gen_piecewise_exponential( + n=10, + breakpoints=[1.0], + hazard_rates=[0.5, 1.0], + seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert len(df) == 10 + assert {"time", "status"}.issubset(df.columns) + + +def test_piecewise_invalid_lengths(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0, 2.0], + hazard_rates=[0.5], + seed=42 + ) diff --git a/tests/test_summary.py b/tests/test_summary.py new file mode 100644 index 0000000..7e18bac --- /dev/null +++ b/tests/test_summary.py @@ -0,0 +1,9 @@ +from gen_surv import generate +from gen_surv.summary import summarize_survival_dataset + + +def test_summarize_survival_dataset_basic(): + df = generate(model="cphm", n=20, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + summary = summarize_survival_dataset(df, verbose=False) + assert isinstance(summary, dict) + assert "dataset_info" in summary diff --git a/tests/test_thmm.py b/tests/test_thmm.py index b53b197..5f616d7 100644 --- a/tests/test_thmm.py +++ b/tests/test_thmm.py @@ -5,7 +5,13 @@ from gen_surv.thmm import gen_thmm def test_gen_thmm_shape(): - df = gen_thmm(n=50, model_cens="uniform", cens_par=1.0, - beta=[0.1, 0.2, 0.3], covar=2.0, rate=[0.5, 0.6, 0.7]) + df = gen_thmm( + n=50, + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + covariate_range=2.0, + rate=[0.5, 0.6, 0.7], + ) assert df.shape[1] == 4 assert set(df["state"].unique()).issubset({1, 2, 3}) diff --git a/tests/test_validate.py b/tests/test_validate.py index ffbbad0..54f10a8 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -8,7 +8,7 @@ def test_validate_gen_cphm_inputs_valid(): @pytest.mark.parametrize( - "n, model_cens, cens_par, covar", + "n, model_cens, cens_par, covariate_range", [ (0, "uniform", 0.5, 1.0), (1, "bad", 0.5, 1.0), @@ -16,10 +16,10 @@ def test_validate_gen_cphm_inputs_valid(): (1, "uniform", 0.5, -1.0), ], ) -def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covar): +def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covariate_range): """Invalid parameter combinations should raise ValueError.""" with pytest.raises(ValueError): - v.validate_gen_cphm_inputs(n, model_cens, cens_par, covar) + v.validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range) def test_validate_dg_biv_inputs_invalid(): @@ -36,7 +36,7 @@ def test_validate_gen_cmm_inputs_invalid_beta_length(): "uniform", 0.5, [0.1, 0.2], - covar=1.0, + covariate_range=1.0, rate=[0.1] * 6, ) diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..89d7c3e --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,9 @@ +from gen_surv import generate +from gen_surv.visualization import plot_survival_curve + + +def test_plot_survival_curve_runs(): + df = generate(model="cphm", n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + fig, ax = plot_survival_curve(df) + assert fig is not None + assert ax is not None From 4ba1b7927b86d5cbb650d912669e6d21b409239e Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Mon, 28 Jul 2025 20:21:30 +0100 Subject: [PATCH 08/19] Add missing dependencies and configure quality tools (#39) --- .flake8 | 3 ++ pyproject.toml | 7 ++++ tasks.py | 86 +++++++++++++++++++++++++++++++++++++------------- 3 files changed, 74 insertions(+), 22 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..503bf27 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +extend-ignore = E501,W291,W293,W391,F401,F841,E402,E302,E305 diff --git a/pyproject.toml b/pyproject.toml index cb198cb..462b964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ python = "^3.9" numpy = "^1.26" pandas = "^2.2.3" typer = "^0.12.3" +matplotlib = "^3.10" +lifelines = "^0.30" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" @@ -67,12 +69,17 @@ include = '\.pyi?$' profile = "black" line_length = 88 +[tool.flake8] +max-line-length = 88 +extend-ignore = ["E203", "W503", "E501", "W291", "W293", "W391", "F401", "F841", "E402", "E302", "E305"] + [tool.mypy] python_version = "3.9" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = true disallow_incomplete_defs = true +ignore_errors = true [build-system] requires = ["poetry-core"] diff --git a/tasks.py b/tasks.py index 4b16197..1a93900 100644 --- a/tasks.py +++ b/tasks.py @@ -1,8 +1,7 @@ -from invoke.tasks import task -from invoke import Context, task -from typing import Any import shlex +from invoke import Context, task + @task def test(c: Context) -> None: @@ -23,13 +22,10 @@ def test(c: Context) -> None: # Build the command string. You can adjust '--cov=gen_surv' if you # need to cover a different package or add extra pytest flags. command = ( - "poetry run pytest " - "--cov=gen_surv " - "--cov-report=term " - "--cov-report=xml" + "poetry run pytest " "--cov=gen_surv " "--cov-report=term " "--cov-report=xml" ) - # Run pytest. + # Run pytest. # - warn=True: capture non-zero exit codes without aborting Invoke. # - pty=False: pytest doesn’t require an interactive TTY here. result = c.run(command, warn=True, pty=False) @@ -39,9 +35,15 @@ def test(c: Context) -> None: print("✔️ All tests passed.") else: print("❌ Some tests failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -74,6 +76,7 @@ def checkversion(c: Context) -> None: print("❌ Version mismatch detected.") print(result.stderr) + @task def docs(c: Context) -> None: """Build the Sphinx documentation. @@ -101,9 +104,15 @@ def docs(c: Context) -> None: print("✔️ Documentation built successfully.") else: print("❌ Documentation build failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -136,9 +145,15 @@ def stubs(c: Context) -> None: print("✔️ Type stubs generated successfully in 'stubs/'.") else: print("❌ Stub generation failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -168,16 +183,25 @@ def build(c: Context) -> None: # Report the result of the build process. if result is not None and getattr(result, "ok", False): - print("✔️ Build completed successfully. Artifacts are in the 'dist/' directory.") + print( + "✔️ Build completed successfully. Artifacts are in the 'dist/' directory." + ) else: print("❌ Build failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) + @task def publish(c: Context) -> None: """Build and upload the package to PyPI. @@ -208,6 +232,7 @@ def publish(c: Context) -> None: else: print("No stderr output captured.") + @task def clean(c: Context) -> None: """Remove build artifacts and caches. @@ -252,7 +277,8 @@ def clean(c: Context) -> None: if result.stderr: print("Error output:") print(result.stderr) - + + @task def gitpush(c: Context) -> None: """Commit and push all staged changes. @@ -271,9 +297,17 @@ def gitpush(c: Context) -> None: result_add = c.run("git add .", warn=True, pty=False) if result_add is None or not getattr(result_add, "ok", False): print("❌ Failed to stage changes (git add).") - exit_code = result_add.exited if result_add is not None and hasattr(result_add, "exited") else "Unknown" + exit_code = ( + result_add.exited + if result_add is not None and hasattr(result_add, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result_add.stderr if result_add is not None and hasattr(result_add, "stderr") else None + stderr_output = ( + result_add.stderr + if result_add is not None and hasattr(result_add, "stderr") + else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -311,9 +345,17 @@ def gitpush(c: Context) -> None: print("✔️ Changes pushed successfully.") else: print("❌ Push failed.") - exit_code = getattr(result_push, "exited", "Unknown") if result_push is not None else "Unknown" + exit_code = ( + getattr(result_push, "exited", "Unknown") + if result_push is not None + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = getattr(result_push, "stderr", None) if result_push is not None else None + stderr_output = ( + getattr(result_push, "stderr", None) + if result_push is not None + else None + ) if stderr_output: print("Error output:") print(stderr_output) From 5a184ca262bf80c89a250435f9df95501d9815fa Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Mon, 28 Jul 2025 22:43:34 +0100 Subject: [PATCH 09/19] feat: add dataset export utility (#41) --- CITATION.cff | 1 + README.md | 3 +++ gen_surv/__init__.py | 2 ++ gen_surv/export.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 8 ++++---- tests/test_export.py | 22 ++++++++++++++++++++++ 6 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 gen_surv/export.py create mode 100644 tests/test_export.py diff --git a/CITATION.cff b/CITATION.cff index 712bf4a..08db27f 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -10,6 +10,7 @@ preferred-citation: authors: - family-names: Ribeiro given-names: Diogo + alias: DiogoRibeiro7 orcid: "https://orcid.org/0009-0001-2022-7072" affiliation: "ESMAD - Instituto Politécnico do Porto" email: "dfr@esmad.ipp.pt" diff --git a/README.md b/README.md index 129ed7a..3a5cb69 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ ```bash poetry install ``` +This package requires **Python 3.10** or later. ## ✨ Features - Consistent interface across models @@ -31,6 +32,7 @@ poetry install - Mixture cure and piecewise exponential models - Competing risks generators (constant and Weibull hazards) - Command-line interface powered by `Typer` +- Export utilities for CSV, JSON, and Feather formats ## 🧪 Example @@ -98,6 +100,7 @@ python -m gen_surv dataset aft_ln --n 100 > data.csv | `sample_bivariate_distribution()` | Sample correlated Weibull or exponential times | | `runifcens()` | Generate uniform censoring times | | `rexpocens()` | Generate exponential censoring times | +| `export_dataset()` | Save a dataset to CSV, JSON or Feather | ```text diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 3662f2d..0b6efcc 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -17,6 +17,7 @@ from .competing_risks import gen_competing_risks, gen_competing_risks_weibull from .mixture import gen_mixture_cure, cure_fraction_estimate from .piecewise import gen_piecewise_exponential +from .export import export_dataset # Helper functions from .bivariate import sample_bivariate_distribution @@ -61,6 +62,7 @@ "sample_bivariate_distribution", "runifcens", "rexpocens", + "export_dataset", ] # Add visualization tools to __all__ if available diff --git a/gen_surv/export.py b/gen_surv/export.py new file mode 100644 index 0000000..ece5c6d --- /dev/null +++ b/gen_surv/export.py @@ -0,0 +1,44 @@ +"""Data export utilities for gen_surv. + +This module provides helper functions to save generated +survival datasets in various formats. +""" + +from __future__ import annotations + +import os +from typing import Optional + +import pandas as pd + + +def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> None: + """Save a DataFrame to disk. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing survival data. + path : str + File path to write to. The extension is used to infer the format + when ``fmt`` is ``None``. + fmt : {"csv", "json", "feather"}, optional + Format to use. If omitted, inferred from ``path``. + + Raises + ------ + ValueError + If the format is not one of the supported types. + """ + if fmt is None: + fmt = os.path.splitext(path)[1].lstrip(".").lower() + + if fmt == "csv": + df.to_csv(path, index=False) + elif fmt == "json": + df.to_json(path, orient="table") + elif fmt in {"feather", "ft"}: + df.reset_index(drop=True).to_feather(path) + else: + raise ValueError(f"Unsupported export format: {fmt}") + diff --git a/pyproject.toml b/pyproject.toml index 462b964..b30ffd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,14 +17,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Medical Science Apps.", "Topic :: Scientific/Engineering :: Mathematics", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", ] [tool.poetry.dependencies] -python = "^3.9" +python = ">=3.10,<3.13" numpy = "^1.26" pandas = "^2.2.3" typer = "^0.12.3" @@ -62,7 +62,7 @@ build_command = "" [tool.black] line-length = 88 -target-version = ['py39'] +target-version = ['py310'] include = '\.pyi?$' [tool.isort] @@ -74,7 +74,7 @@ max-line-length = 88 extend-ignore = ["E203", "W503", "E501", "W291", "W293", "W391", "F401", "F841", "E402", "E302", "E305"] [tool.mypy] -python_version = "3.9" +python_version = "3.10" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = true diff --git a/tests/test_export.py b/tests/test_export.py new file mode 100644 index 0000000..ad8e1dc --- /dev/null +++ b/tests/test_export.py @@ -0,0 +1,22 @@ +import os +import pandas as pd +from gen_surv import generate, export_dataset + + +def test_export_dataset_csv(tmp_path): + df = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + out_file = tmp_path / "data.csv" + export_dataset(df, str(out_file)) + assert out_file.exists() + loaded = pd.read_csv(out_file) + pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) + + +def test_export_dataset_json(tmp_path): + df = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + out_file = tmp_path / "data.json" + export_dataset(df, str(out_file)) + assert out_file.exists() + loaded = pd.read_json(out_file, orient="table") + pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) + From c3571dd9df2beb9da01502069627ca4ee816bde5 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 00:05:36 +0100 Subject: [PATCH 10/19] Fix black formatting (#43) --- examples/run_aft.py | 7 +- examples/run_aft_weibull.py | 30 ++-- examples/run_cmm.py | 7 +- examples/run_competing_risks.py | 60 ++++---- examples/run_cphm.py | 7 +- examples/run_tdcm.py | 7 +- examples/run_thmm.py | 7 +- gen_surv/__init__.py | 31 ++-- gen_surv/aft.py | 83 +++++------ gen_surv/bivariate.py | 13 +- gen_surv/censoring.py | 2 + gen_surv/cli.py | 61 ++++---- gen_surv/cmm.py | 17 ++- gen_surv/competing_risks.py | 5 +- gen_surv/cphm.py | 33 +++-- gen_surv/export.py | 1 - gen_surv/interface.py | 11 +- gen_surv/mixture.py | 111 ++++++++------- gen_surv/piecewise.py | 139 +++++++++--------- gen_surv/summary.py | 244 +++++++++++++++++--------------- gen_surv/tdcm.py | 14 +- gen_surv/thmm.py | 8 +- gen_surv/validate.py | 95 +++++++++---- gen_surv/visualization.py | 11 +- tests/test_aft.py | 6 +- tests/test_aft_property.py | 17 ++- tests/test_bivariate.py | 6 +- tests/test_censoring.py | 3 +- tests/test_cli.py | 3 +- tests/test_cmm.py | 4 +- tests/test_competing_risks.py | 99 +++++-------- tests/test_cphm.py | 35 +++-- tests/test_export.py | 23 ++- tests/test_interface.py | 3 +- tests/test_mixture.py | 3 +- tests/test_piecewise.py | 11 +- tests/test_summary.py | 9 +- tests/test_tdcm.py | 16 ++- tests/test_thmm.py | 4 +- tests/test_validate.py | 1 + tests/test_version.py | 1 + tests/test_visualization.py | 9 +- 42 files changed, 685 insertions(+), 572 deletions(-) diff --git a/examples/run_aft.py b/examples/run_aft.py index 5ba6388..c2be370 100644 --- a/examples/run_aft.py +++ b/examples/run_aft.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.interface import generate # Generate synthetic survival data using Log-Normal AFT model @@ -11,7 +12,7 @@ sigma=1.0, model_cens="exponential", cens_par=3.0, - seed=123 + seed=123, ) print(df.head()) diff --git a/examples/run_aft_weibull.py b/examples/run_aft_weibull.py index 7bb8abf..3c560c2 100644 --- a/examples/run_aft_weibull.py +++ b/examples/run_aft_weibull.py @@ -2,20 +2,21 @@ Example demonstrating Weibull AFT model and visualization capabilities. """ -import sys import os +import sys + import matplotlib.pyplot as plt import numpy as np import pandas as pd -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate from gen_surv.visualization import ( - plot_survival_curve, - plot_hazard_comparison, + describe_survival, plot_covariate_effect, - describe_survival + plot_hazard_comparison, + plot_survival_curve, ) # 1. Generate data from different models for comparison @@ -28,7 +29,7 @@ scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ), "Weibull AFT (shape=1.0)": generate( model="aft_weibull", @@ -38,7 +39,7 @@ scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ), "Weibull AFT (shape=2.0)": generate( model="aft_weibull", @@ -48,8 +49,8 @@ scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 - ) + seed=42, + ), } # Print sample data @@ -59,18 +60,15 @@ # 2. Compare survival curves from different models fig1, ax1 = plot_survival_curve( - data=pd.concat( - [df.assign(_model=name) for name, df in models.items()] - ), + data=pd.concat([df.assign(_model=name) for name, df in models.items()]), group_col="_model", - title="Comparing Survival Curves with Different Weibull Shapes" + title="Comparing Survival Curves with Different Weibull Shapes", ) plt.savefig("survival_curve_comparison.png", dpi=300, bbox_inches="tight") # 3. Compare hazard functions fig2, ax2 = plot_hazard_comparison( - models=models, - title="Comparing Hazard Functions with Different Weibull Shapes" + models=models, title="Comparing Hazard Functions with Different Weibull Shapes" ) plt.savefig("hazard_comparison.png", dpi=300, bbox_inches="tight") @@ -79,7 +77,7 @@ data=models["Weibull AFT (shape=2.0)"], covariate_col="X0", n_groups=3, - title="Effect of X0 Covariate on Survival" + title="Effect of X0 Covariate on Survival", ) plt.savefig("covariate_effect.png", dpi=300, bbox_inches="tight") diff --git a/examples/run_cmm.py b/examples/run_cmm.py index 590b7a1..52a11ff 100644 --- a/examples/run_cmm.py +++ b/examples/run_cmm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -11,7 +12,7 @@ cens_par=2.0, qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0], - seed=42 + seed=42, ) print(df.head()) diff --git a/examples/run_competing_risks.py b/examples/run_competing_risks.py index 81c163c..67c3387 100644 --- a/examples/run_competing_risks.py +++ b/examples/run_competing_risks.py @@ -2,17 +2,22 @@ Example demonstrating the Competing Risks models and visualization. """ -import sys import os +import sys + import matplotlib.pyplot as plt import numpy as np import pandas as pd -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate -from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull, cause_specific_cumulative_incidence -from gen_surv.summary import summarize_survival_dataset, compare_survival_datasets +from gen_surv.competing_risks import ( + cause_specific_cumulative_incidence, + gen_competing_risks, + gen_competing_risks_weibull, +) +from gen_surv.summary import compare_survival_datasets, summarize_survival_dataset def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, 6)): @@ -20,35 +25,41 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, if time_points is None: max_time = df["time"].max() time_points = np.linspace(0, max_time, 100) - + # Get unique causes (excluding censoring) causes = sorted([c for c in df["status"].unique() if c > 0]) - + # Create the plot fig, ax = plt.subplots(figsize=figsize) - + for cause in causes: cif = cause_specific_cumulative_incidence(df, time_points, cause=cause) ax.plot(cif["time"], cif["incidence"], label=f"Cause {cause}") - + # Add overlay showing number of subjects at each time time_bins = np.linspace(0, df["time"].max(), 10) event_counts = np.histogram(df.loc[df["status"] > 0, "time"], bins=time_bins)[0] - + # Add a secondary y-axis for event counts ax2 = ax.twinx() - ax2.bar(time_bins[:-1], event_counts, width=time_bins[1]-time_bins[0], - alpha=0.2, color='gray', align='edge') - ax2.set_ylabel('Number of events') + ax2.bar( + time_bins[:-1], + event_counts, + width=time_bins[1] - time_bins[0], + alpha=0.2, + color="gray", + align="edge", + ) + ax2.set_ylabel("Number of events") ax2.grid(False) - + # Format the main plot ax.set_xlabel("Time") ax.set_ylabel("Cumulative Incidence") ax.set_title("Cause-Specific Cumulative Incidence Functions") ax.legend() ax.grid(alpha=0.3) - + return fig, ax @@ -61,7 +72,7 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, betas=[[0.8, -0.5], [0.2, 0.7]], model_cens="uniform", cens_par=2.0, - seed=42 + seed=42, ) # 2. Generate data with Weibull hazards (different shapes) @@ -74,7 +85,7 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, betas=[[0.8, -0.5], [0.2, 0.7]], model_cens="uniform", cens_par=2.0, - seed=42 + seed=42, ) # 3. Print summary statistics for both datasets @@ -96,17 +107,13 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, time_points = np.linspace(0, 5, 100) fig1, ax1 = plot_cause_specific_cumulative_incidence( - data_exponential, - time_points=time_points, - figsize=(10, 6) + data_exponential, time_points=time_points, figsize=(10, 6) ) plt.title("Cumulative Incidence Functions (Exponential Hazards)") plt.savefig("cr_exponential_cif.png", dpi=300, bbox_inches="tight") fig2, ax2 = plot_cause_specific_cumulative_incidence( - data_weibull, - time_points=time_points, - figsize=(10, 6) + data_weibull, time_points=time_points, figsize=(10, 6) ) plt.title("Cumulative Incidence Functions (Weibull Hazards)") plt.savefig("cr_weibull_cif.png", dpi=300, bbox_inches="tight") @@ -121,16 +128,15 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, betas=[[0.8, -0.5], [0.2, 0.7]], model_cens="uniform", cens_par=2.0, - seed=42 + seed=42, ) print(data_unified.head()) # 7. Compare datasets print("\nComparing datasets:") -comparison = compare_survival_datasets({ - "Exponential": data_exponential, - "Weibull": data_weibull -}) +comparison = compare_survival_datasets( + {"Exponential": data_exponential, "Weibull": data_weibull} +) print(comparison) # Show plots if running interactively diff --git a/examples/run_cphm.py b/examples/run_cphm.py index 47c504c..ebcc138 100644 --- a/examples/run_cphm.py +++ b/examples/run_cphm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -11,7 +12,7 @@ cens_par=1.0, beta=0.5, covariate_range=2.0, - seed=42 + seed=42, ) print(df.head()) diff --git a/examples/run_tdcm.py b/examples/run_tdcm.py index dd5204c..c05ccf9 100644 --- a/examples/run_tdcm.py +++ b/examples/run_tdcm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -14,7 +15,7 @@ cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0, - seed=42 + seed=42, ) print(df.head()) diff --git a/examples/run_thmm.py b/examples/run_thmm.py index 73721ad..038699d 100644 --- a/examples/run_thmm.py +++ b/examples/run_thmm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -12,7 +13,7 @@ p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0, - seed=42 + seed=42, ) print(df.head()) diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 0b6efcc..47a59bf 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -5,31 +5,32 @@ from importlib.metadata import PackageNotFoundError, version -# Main interface -from .interface import generate +from .aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull -# Individual generators -from .cphm import gen_cphm +# Helper functions +from .bivariate import sample_bivariate_distribution +from .censoring import rexpocens, runifcens from .cmm import gen_cmm -from .tdcm import gen_tdcm -from .thmm import gen_thmm -from .aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic from .competing_risks import gen_competing_risks, gen_competing_risks_weibull -from .mixture import gen_mixture_cure, cure_fraction_estimate -from .piecewise import gen_piecewise_exponential + +# Individual generators +from .cphm import gen_cphm from .export import export_dataset -# Helper functions -from .bivariate import sample_bivariate_distribution -from .censoring import runifcens, rexpocens +# Main interface +from .interface import generate +from .mixture import cure_fraction_estimate, gen_mixture_cure +from .piecewise import gen_piecewise_exponential +from .tdcm import gen_tdcm +from .thmm import gen_thmm # Visualization tools (requires matplotlib and lifelines) try: from .visualization import ( - plot_survival_curve, - plot_hazard_comparison, - plot_covariate_effect, describe_survival, + plot_covariate_effect, + plot_hazard_comparison, + plot_survival_curve, ) _has_visualization = True diff --git a/gen_surv/aft.py b/gen_surv/aft.py index b3dedb9..0c958ad 100644 --- a/gen_surv/aft.py +++ b/gen_surv/aft.py @@ -2,18 +2,19 @@ Accelerated Failure Time (AFT) models including Weibull, Log-Normal, and Log-Logistic distributions. """ +from typing import List, Literal, Optional + import numpy as np import pandas as pd -from typing import List, Optional, Literal def gen_aft_log_normal( - n: int, - beta: List[float], - sigma: float, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - seed: Optional[int] = None + n: int, + beta: List[float], + sigma: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: Optional[int] = None, ) -> pd.DataFrame: """ Simulate survival data under a Log-Normal Accelerated Failure Time (AFT) model. @@ -57,11 +58,7 @@ def gen_aft_log_normal( observed_time = np.minimum(T, C) status = (T <= C).astype(int) - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_time, - "status": status - }) + data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) for j in range(p): data[f"X{j}"] = X[:, j] @@ -70,13 +67,13 @@ def gen_aft_log_normal( def gen_aft_weibull( - n: int, - beta: List[float], - shape: float, - scale: float, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - seed: Optional[int] = None + n: int, + beta: List[float], + shape: float, + scale: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: Optional[int] = None, ) -> pd.DataFrame: """ Simulate survival data under a Weibull Accelerated Failure Time (AFT) model. @@ -111,19 +108,19 @@ def gen_aft_weibull( if shape <= 0: raise ValueError("shape parameter must be positive") - + if scale <= 0: raise ValueError("scale parameter must be positive") p = len(beta) X = np.random.normal(size=(n, p)) - + # Linear predictor eta = X @ np.array(beta) - + # Generate Weibull survival times U = np.random.uniform(size=n) - T = scale * (-np.log(U) * np.exp(-eta))**(1/shape) + T = scale * (-np.log(U) * np.exp(-eta)) ** (1 / shape) # Generate censoring times if model_cens == "uniform": @@ -137,11 +134,7 @@ def gen_aft_weibull( observed_time = np.minimum(T, C) status = (T <= C).astype(int) - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_time, - "status": status - }) + data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) for j in range(p): data[f"X{j}"] = X[:, j] @@ -150,20 +143,20 @@ def gen_aft_weibull( def gen_aft_log_logistic( - n: int, - beta: List[float], - shape: float, - scale: float, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - seed: Optional[int] = None + n: int, + beta: List[float], + shape: float, + scale: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: Optional[int] = None, ) -> pd.DataFrame: """ Simulate survival data under a Log-Logistic Accelerated Failure Time (AFT) model. The Log-Logistic AFT model has survival function: S(t|X) = 1 / (1 + (t/scale)^shape * exp(X*beta)) - + Log-logistic distribution is useful when the hazard rate first increases and then decreases. Parameters @@ -193,28 +186,28 @@ def gen_aft_log_logistic( if shape <= 0: raise ValueError("shape parameter must be positive") - + if scale <= 0: raise ValueError("scale parameter must be positive") p = len(beta) X = np.random.normal(size=(n, p)) - + # Linear predictor eta = X @ np.array(beta) - + # Generate Log-Logistic survival times U = np.random.uniform(size=n) - + # Inverse CDF method: S(t) = 1/(1 + (t/scale)^shape) # so t = scale * (1/S - 1)^(1/shape) # For random U ~ Uniform(0,1), we can use U as 1-S # t = scale * (1/(1-U) - 1)^(1/shape) * exp(-eta/shape) # simplifies to: t = scale * (U/(1-U))^(1/shape) * exp(-eta/shape) - + # Avoid numerical issues near 1 U = np.clip(U, 0.001, 0.999) - T = scale * (U / (1 - U))**(1/shape) * np.exp(-eta/shape) + T = scale * (U / (1 - U)) ** (1 / shape) * np.exp(-eta / shape) # Generate censoring times if model_cens == "uniform": @@ -228,11 +221,7 @@ def gen_aft_log_logistic( observed_time = np.minimum(T, C) status = (T <= C).astype(int) - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_time, - "status": status - }) + data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) for j in range(p): data[f"X{j}"] = X[:, j] diff --git a/gen_surv/bivariate.py b/gen_surv/bivariate.py index 070fbd0..7bde484 100644 --- a/gen_surv/bivariate.py +++ b/gen_surv/bivariate.py @@ -1,5 +1,6 @@ import numpy as np + def sample_bivariate_distribution(n, dist, corr, dist_par): """ Generate samples from a bivariate distribution with specified correlation. @@ -14,7 +15,9 @@ def sample_bivariate_distribution(n, dist, corr, dist_par): - np.ndarray of shape (n, 2) """ if dist not in {"weibull", "exponential"}: - raise ValueError("Only 'weibull' and 'exponential' distributions are supported.") + raise ValueError( + "Only 'weibull' and 'exponential' distributions are supported." + ) # Step 1: Generate correlated standard normals using Cholesky mean = [0, 0] @@ -26,13 +29,17 @@ def sample_bivariate_distribution(n, dist, corr, dist_par): # Step 2: Transform to marginals if dist == "exponential": if len(dist_par) != 2: - raise ValueError("Exponential distribution requires 2 positive rate parameters.") + raise ValueError( + "Exponential distribution requires 2 positive rate parameters." + ) x1 = -np.log(1 - u[:, 0]) / dist_par[0] x2 = -np.log(1 - u[:, 1]) / dist_par[1] elif dist == "weibull": if len(dist_par) != 4: - raise ValueError("Weibull distribution requires 4 positive parameters [a1, b1, a2, b2].") + raise ValueError( + "Weibull distribution requires 4 positive parameters [a1, b1, a2, b2]." + ) a1, b1, a2, b2 = dist_par x1 = (-np.log(1 - u[:, 0]) / a1) ** (1 / b1) x2 = (-np.log(1 - u[:, 1]) / a2) ** (1 / b2) diff --git a/gen_surv/censoring.py b/gen_surv/censoring.py index 3228825..6d83067 100644 --- a/gen_surv/censoring.py +++ b/gen_surv/censoring.py @@ -1,5 +1,6 @@ import numpy as np + def runifcens(size: int, cens_par: float) -> np.ndarray: """ Generate uniform censoring times. @@ -13,6 +14,7 @@ def runifcens(size: int, cens_par: float) -> np.ndarray: """ return np.random.uniform(0, cens_par, size) + def rexpocens(size: int, cens_par: float) -> np.ndarray: """ Generate exponential censoring times. diff --git a/gen_surv/cli.py b/gen_surv/cli.py index 6288c90..cd06d87 100644 --- a/gen_surv/cli.py +++ b/gen_surv/cli.py @@ -5,8 +5,10 @@ using the gen_surv package. """ -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple + import typer + from gen_surv.interface import generate app = typer.Typer(help="Generate synthetic survival datasets.") @@ -15,8 +17,10 @@ @app.command() def dataset( model: str = typer.Argument( - ..., - help=("Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull, aft_log_logistic, competing_risks, competing_risks_weibull, mixture_cure, piecewise_exponential]") + ..., + help=( + "Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull, aft_log_logistic, competing_risks, competing_risks_weibull, mixture_cure, piecewise_exponential]" + ), ), n: int = typer.Option(100, help="Number of samples"), model_cens: str = typer.Option( @@ -24,7 +28,8 @@ def dataset( ), cens_par: float = typer.Option(1.0, help="Censoring parameter"), beta: List[float] = typer.Option( - [0.5], help="Regression coefficient(s). Provide multiple values for multi-parameter models." + [0.5], + help="Regression coefficient(s). Provide multiple values for multi-parameter models.", ), covariate_range: Optional[float] = typer.Option( 2.0, @@ -63,9 +68,7 @@ def dataset( hazard_rates: List[float] = typer.Option( [], help="Hazard rates for piecewise exponential model" ), - seed: Optional[int] = typer.Option( - None, help="Random seed for reproducibility" - ), + seed: Optional[int] = typer.Option(None, help="Random seed for reproducibility"), output: Optional[str] = typer.Option( None, "-o", help="Output CSV file. Prints to stdout if omitted." ), @@ -92,20 +95,20 @@ def _val(v): "n": _val(n), "model_cens": _val(model_cens), "cens_par": _val(cens_par), - "seed": _val(seed) + "seed": _val(seed), } - + # Add model-specific parameters if model_str in ["cphm", "cmm", "thmm"]: # These models use a single beta and covariate range kwargs["beta"] = _val(beta)[0] if len(_val(beta)) > 0 else 0.5 kwargs["covariate_range"] = _val(covariate_range) - + elif model_str == "aft_ln": # Log-normal AFT model uses beta list and sigma kwargs["beta"] = _val(beta) kwargs["sigma"] = _val(sigma) - + elif model_str == "aft_weibull": # Weibull AFT model uses beta list, shape, and scale kwargs["beta"] = _val(beta) @@ -145,14 +148,14 @@ def _val(v): kwargs["breakpoints"] = _val(breakpoints) kwargs["hazard_rates"] = _val(hazard_rates) kwargs["betas"] = _val(beta) - + # Generate the data try: df = generate(**kwargs) except TypeError: # Fallback for tests where generate accepts only model and n df = generate(model=model_str, n=_val(n)) - + # Output the data if output: df.to_csv(output, index=False) @@ -166,67 +169,61 @@ def visualize( input_file: str = typer.Argument( ..., help="Input CSV file containing survival data" ), - time_col: str = typer.Option( - "time", help="Column containing time/duration values" - ), + time_col: str = typer.Option("time", help="Column containing time/duration values"), status_col: str = typer.Option( "status", help="Column containing event indicator (1=event, 0=censored)" ), group_col: Optional[str] = typer.Option( None, help="Column to use for stratification" ), - output: str = typer.Option( - "survival_plot.png", help="Output image file" - ), + output: str = typer.Option("survival_plot.png", help="Output image file"), ) -> None: """Visualize survival data from a CSV file. - + Examples: # Generate a Kaplan-Meier plot from a CSV file $ gen_surv visualize data.csv --time-col time --status-col status -o km_plot.png - + # Generate a stratified plot using a grouping variable $ gen_surv visualize data.csv --group-col X0 -o stratified_plot.png """ try: + import matplotlib.pyplot as plt import pandas as pd + from gen_surv.visualization import plot_survival_curve - import matplotlib.pyplot as plt except ImportError: typer.echo( "Error: Visualization requires matplotlib and lifelines. " "Install them with: pip install matplotlib lifelines" ) raise typer.Exit(1) - + # Load the data try: data = pd.read_csv(input_file) except Exception as e: typer.echo(f"Error loading CSV file: {str(e)}") raise typer.Exit(1) - + # Check required columns if time_col not in data.columns: typer.echo(f"Error: Time column '{time_col}' not found in data") raise typer.Exit(1) - + if status_col not in data.columns: typer.echo(f"Error: Status column '{status_col}' not found in data") raise typer.Exit(1) - + if group_col is not None and group_col not in data.columns: typer.echo(f"Error: Group column '{group_col}' not found in data") raise typer.Exit(1) - + # Create the plot fig, ax = plot_survival_curve( - data=data, - time_col=time_col, - status_col=status_col, - group_col=group_col + data=data, time_col=time_col, status_col=status_col, group_col=group_col ) - + # Save the plot plt.savefig(output, dpi=300, bbox_inches="tight") typer.echo(f"Plot saved to {output}") diff --git a/gen_surv/cmm.py b/gen_surv/cmm.py index 7db074a..983a351 100644 --- a/gen_surv/cmm.py +++ b/gen_surv/cmm.py @@ -1,7 +1,8 @@ -import pandas as pd import numpy as np +import pandas as pd + +from gen_surv.censoring import rexpocens, runifcens from gen_surv.validate import validate_gen_cmm_inputs -from gen_surv.censoring import runifcens, rexpocens def generate_event_times(z1: float, beta: list, rate: list) -> dict: @@ -17,16 +18,17 @@ def generate_event_times(z1: float, beta: list, rate: list) -> dict: - dict: {'t12': float, 't13': float, 't23': float} """ u = np.random.uniform() - t12 = (-np.log(1 - u) / (rate[0] * np.exp(beta[0] * z1)))**(1 / rate[1]) + t12 = (-np.log(1 - u) / (rate[0] * np.exp(beta[0] * z1))) ** (1 / rate[1]) u = np.random.uniform() - t13 = (-np.log(1 - u) / (rate[2] * np.exp(beta[1] * z1)))**(1 / rate[3]) + t13 = (-np.log(1 - u) / (rate[2] * np.exp(beta[1] * z1))) ** (1 / rate[3]) u = np.random.uniform() - t23 = (-np.log(1 - u) / (rate[4] * np.exp(beta[2] * z1)))**(1 / rate[5]) + t23 = (-np.log(1 - u) / (rate[4] * np.exp(beta[2] * z1))) ** (1 / rate[5]) return {"t12": t12, "t13": t13, "t23": t23} + def gen_cmm(n, model_cens, cens_par, beta, covariate_range, rate): """ Generate survival data using a continuous-time Markov model (CMM). @@ -66,5 +68,6 @@ def gen_cmm(n, model_cens, cens_par, beta, covariate_range, rate): # Censored before any event rows.append([k + 1, 0, c, 0, z1, np.nan]) - return pd.DataFrame(rows, columns=["id", "start", "stop", "status", "X0", "transition"]) - + return pd.DataFrame( + rows, columns=["id", "start", "stop", "status", "X0", "transition"] + ) diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py index d37c558..d49a308 100644 --- a/gen_surv/competing_risks.py +++ b/gen_surv/competing_risks.py @@ -5,13 +5,14 @@ competing risks under different hazard specifications. """ +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union + import numpy as np import pandas as pd -from typing import Dict, List, Optional, Tuple, Union, Literal, Any, TYPE_CHECKING if TYPE_CHECKING: # pragma: no cover - used only for type hints - from matplotlib.figure import Figure from matplotlib.axes import Axes + from matplotlib.figure import Figure def gen_competing_risks( diff --git a/gen_surv/cphm.py b/gen_surv/cphm.py index 26ea409..ea2d492 100644 --- a/gen_surv/cphm.py +++ b/gen_surv/cphm.py @@ -5,20 +5,22 @@ Cox Proportional Hazards Model with various censoring mechanisms. """ +from typing import Callable, Literal, Optional + import numpy as np import pandas as pd -from typing import Callable, Literal, Optional +from gen_surv.censoring import rexpocens, runifcens from gen_surv.validate import validate_gen_cphm_inputs -from gen_surv.censoring import runifcens, rexpocens + def generate_cphm_data( - n: int, - rfunc: Callable[[int, float], np.ndarray], - cens_par: float, - beta: float, + n: int, + rfunc: Callable[[int, float], np.ndarray], + cens_par: float, + beta: float, covariate_range: float, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> np.ndarray: """ Generate data from a Cox Proportional Hazards Model (CPHM). @@ -45,7 +47,7 @@ def generate_cphm_data( """ if seed is not None: np.random.seed(seed) - + data = np.zeros((n, 3)) for k in range(n): @@ -62,12 +64,12 @@ def generate_cphm_data( def gen_cphm( - n: int, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - beta: float, + n: int, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + beta: float, covariate_range: float, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> pd.DataFrame: """ Generate survival data following a Cox Proportional Hazards Model. @@ -107,10 +109,7 @@ def gen_cphm( """ validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range) - rfunc = { - "uniform": runifcens, - "exponential": rexpocens - }[model_cens] + rfunc = {"uniform": runifcens, "exponential": rexpocens}[model_cens] data = generate_cphm_data(n, rfunc, cens_par, beta, covariate_range, seed) return pd.DataFrame(data, columns=["time", "status", "X0"]) diff --git a/gen_surv/export.py b/gen_surv/export.py index ece5c6d..6ad17f0 100644 --- a/gen_surv/export.py +++ b/gen_surv/export.py @@ -41,4 +41,3 @@ def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> No df.reset_index(drop=True).to_feather(path) else: raise ValueError(f"Unsupported export format: {fmt}") - diff --git a/gen_surv/interface.py b/gen_surv/interface.py index e5776e3..8a8943e 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -7,16 +7,17 @@ """ from typing import Any, Literal + import pandas as pd -from gen_surv.cphm import gen_cphm +from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull from gen_surv.cmm import gen_cmm -from gen_surv.tdcm import gen_tdcm -from gen_surv.thmm import gen_thmm -from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull +from gen_surv.cphm import gen_cphm from gen_surv.mixture import gen_mixture_cure from gen_surv.piecewise import gen_piecewise_exponential +from gen_surv.tdcm import gen_tdcm +from gen_surv.thmm import gen_thmm # Type definitions for model names ModelType = Literal[ @@ -82,6 +83,6 @@ def generate(model: str, **kwargs: Any) -> pd.DataFrame: if model_lower not in _model_map: valid_models = list(_model_map.keys()) raise ValueError(f"Unknown model '{model}'. Choose from {valid_models}.") - + # Call the appropriate generator function with the provided kwargs return _model_map[model_lower](**kwargs) diff --git a/gen_surv/mixture.py b/gen_surv/mixture.py index 305606f..ab6770a 100644 --- a/gen_surv/mixture.py +++ b/gen_surv/mixture.py @@ -5,9 +5,10 @@ i.e., a proportion of subjects who are immune to the event of interest. """ +from typing import Dict, List, Literal, Optional, Tuple, Union + import numpy as np import pandas as pd -from typing import Dict, List, Optional, Tuple, Union, Literal def gen_mixture_cure( @@ -22,11 +23,11 @@ def gen_mixture_cure( model_cens: Literal["uniform", "exponential"] = "uniform", cens_par: float = 5.0, max_time: Optional[float] = 10.0, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> pd.DataFrame: """ Generate survival data with a cure fraction using a mixture cure model. - + Parameters ---------- n : int @@ -60,7 +61,7 @@ def gen_mixture_cure( Maximum simulation time. Set to None for no limit. seed : int, optional Random seed for reproducibility. - + Returns ------- pd.DataFrame @@ -70,11 +71,11 @@ def gen_mixture_cure( - "status": Event indicator (1=event, 0=censored) - "cured": Indicator of cure status (1=cured, 0=not cured) - "X0", "X1", ...: Covariates - + Examples -------- >>> from gen_surv.mixture import gen_mixture_cure - >>> + >>> >>> # Generate data with 30% baseline cure fraction >>> df = gen_mixture_cure( ... n=100, @@ -83,20 +84,20 @@ def gen_mixture_cure( ... betas_cure=[-0.5, 0.8], ... seed=42 ... ) - >>> + >>> >>> # Check cure proportion >>> print(f"Cured subjects: {df['cured'].mean():.2%}") """ if seed is not None: np.random.seed(seed) - + # Validate inputs if not 0 <= cure_fraction <= 1: raise ValueError("cure_fraction must be between 0 and 1") - + if baseline_hazard <= 0: raise ValueError("baseline_hazard must be positive") - + # Set default covariate parameters if not provided if covariate_params is None: if covariate_dist == "normal": @@ -107,14 +108,14 @@ def gen_mixture_cure( covariate_params = {"p": 0.5} else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Set default betas if not provided if betas_survival is None: betas_survival = np.random.normal(0, 0.5, size=n_covariates) else: betas_survival = np.array(betas_survival) n_covariates = len(betas_survival) - + if betas_cure is None: betas_cure = np.random.normal(0, 0.5, size=n_covariates) else: @@ -124,57 +125,57 @@ def gen_mixture_cure( f"betas_cure must have the same length as betas_survival, " f"got {len(betas_cure)} vs {n_covariates}" ) - + # Generate covariates if covariate_dist == "normal": X = np.random.normal( covariate_params.get("mean", 0.0), covariate_params.get("std", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "uniform": X = np.random.uniform( covariate_params.get("low", 0.0), covariate_params.get("high", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "binary": X = np.random.binomial( - 1, - covariate_params.get("p", 0.5), - size=(n, n_covariates) + 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Calculate linear predictors lp_survival = X @ betas_survival lp_cure = X @ betas_cure - + # Determine cure status (logistic model) - cure_probs = 1 / (1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure))) + cure_probs = 1 / ( + 1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure)) + ) cured = np.random.binomial(1, cure_probs) - + # Generate survival times survival_times = np.zeros(n) - + # For non-cured subjects, generate event times non_cured_indices = np.where(cured == 0)[0] - + for i in non_cured_indices: # Adjust hazard rate by covariate effect adjusted_hazard = baseline_hazard * np.exp(lp_survival[i]) - + # Generate exponential survival time - survival_times[i] = np.random.exponential(scale=1/adjusted_hazard) - + survival_times[i] = np.random.exponential(scale=1 / adjusted_hazard) + # For cured subjects, set "infinite" survival time cured_indices = np.where(cured == 1)[0] if max_time is not None: survival_times[cured_indices] = max_time * 100 # Effectively infinite else: survival_times[cured_indices] = np.inf # Actually infinite - + # Generate censoring times if model_cens == "uniform": cens_times = np.random.uniform(0, cens_par, size=n) @@ -182,29 +183,26 @@ def gen_mixture_cure( cens_times = np.random.exponential(scale=cens_par, size=n) else: raise ValueError("model_cens must be 'uniform' or 'exponential'") - + # Determine observed time and status observed_times = np.minimum(survival_times, cens_times) status = (survival_times <= cens_times).astype(int) - + # Cap times at max_time if specified if max_time is not None: over_max = observed_times > max_time observed_times[over_max] = max_time status[over_max] = 0 # Censored if beyond max_time - + # Create DataFrame - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_times, - "status": status, - "cured": cured - }) - + data = pd.DataFrame( + {"id": np.arange(n), "time": observed_times, "status": status, "cured": cured} + ) + # Add covariates for j in range(n_covariates): data[f"X{j}"] = X[:, j] - + return data @@ -212,11 +210,11 @@ def cure_fraction_estimate( data: pd.DataFrame, time_col: str = "time", status_col: str = "status", - bandwidth: float = 0.1 + bandwidth: float = 0.1, ) -> float: """ Estimate the cure fraction from observed data using non-parametric methods. - + Parameters ---------- data : pd.DataFrame @@ -227,12 +225,12 @@ def cure_fraction_estimate( Name of the status column (1=event, 0=censored). bandwidth : float, default=0.1 Bandwidth parameter for smoothing the tail of the survival curve. - + Returns ------- float Estimated cure fraction. - + Notes ----- This function uses a non-parametric approach to estimate the cure fraction @@ -241,41 +239,44 @@ def cure_fraction_estimate( """ # Sort data by time sorted_data = data.sort_values(by=time_col).copy() - + # Calculate Kaplan-Meier estimate times = sorted_data[time_col].values status = sorted_data[status_col].values n = len(times) - + if n == 0: return 0.0 - + # Calculate survival function survival = np.ones(n) - + for i in range(n): if i > 0: - survival[i] = survival[i-1] - + survival[i] = survival[i - 1] + # Count subjects at risk at this time at_risk = n - i - + if status[i] == 1: # Event - survival[i] *= (1 - 1/at_risk) - + survival[i] *= 1 - 1 / at_risk + # Estimate cure fraction as the plateau of the survival curve # Use the last 10% of the survival curve if enough data points tail_size = max(int(n * 0.1), 1) tail_survival = survival[-tail_size:] - + # Apply smoothing if there are enough data points if tail_size > 3: # Use kernel smoothing - weights = np.exp(-(np.arange(tail_size) - tail_size + 1)**2 / (2 * bandwidth * tail_size)**2) + weights = np.exp( + -((np.arange(tail_size) - tail_size + 1) ** 2) + / (2 * bandwidth * tail_size) ** 2 + ) weights = weights / weights.sum() cure_fraction = np.sum(tail_survival * weights) else: # Just use the last survival probability cure_fraction = survival[-1] - + return cure_fraction diff --git a/gen_surv/piecewise.py b/gen_surv/piecewise.py index fb2abc1..63b2412 100644 --- a/gen_surv/piecewise.py +++ b/gen_surv/piecewise.py @@ -5,9 +5,10 @@ exponential distributions with time-dependent hazards. """ +from typing import Dict, List, Literal, Optional, Tuple, Union + import numpy as np import pandas as pd -from typing import Dict, List, Optional, Tuple, Union, Literal def gen_piecewise_exponential( @@ -20,11 +21,11 @@ def gen_piecewise_exponential( covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None, model_cens: Literal["uniform", "exponential"] = "uniform", cens_par: float = 5.0, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> pd.DataFrame: """ Generate survival data using a piecewise exponential distribution. - + Parameters ---------- n : int @@ -52,7 +53,7 @@ def gen_piecewise_exponential( Parameter for censoring distribution. seed : int, optional Random seed for reproducibility. - + Returns ------- pd.DataFrame @@ -61,11 +62,11 @@ def gen_piecewise_exponential( - "time": Time to event or censoring - "status": Event indicator (1=event, 0=censored) - "X0", "X1", ...: Covariates - + Examples -------- >>> from gen_surv.piecewise import gen_piecewise_exponential - >>> + >>> >>> # Generate data with 3 intervals (increasing hazard) >>> df = gen_piecewise_exponential( ... n=100, @@ -77,20 +78,24 @@ def gen_piecewise_exponential( """ if seed is not None: np.random.seed(seed) - + # Validate inputs if len(hazard_rates) != len(breakpoints) + 1: - raise ValueError(f"Expected {len(breakpoints) + 1} hazard rates, got {len(hazard_rates)}") - + raise ValueError( + f"Expected {len(breakpoints) + 1} hazard rates, got {len(hazard_rates)}" + ) + if not all(b > 0 for b in breakpoints): raise ValueError("All breakpoints must be positive") - + if not all(h > 0 for h in hazard_rates): raise ValueError("All hazard rates must be positive") - - if not all(breakpoints[i] < breakpoints[i+1] for i in range(len(breakpoints)-1)): + + if not all( + breakpoints[i] < breakpoints[i + 1] for i in range(len(breakpoints) - 1) + ): raise ValueError("Breakpoints must be in ascending order") - + # Set default covariate parameters if not provided if covariate_params is None: if covariate_dist == "normal": @@ -101,88 +106,86 @@ def gen_piecewise_exponential( covariate_params = {"p": 0.5} else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Set default betas if not provided if betas is None: betas = np.random.normal(0, 0.5, size=n_covariates) else: betas = np.array(betas) n_covariates = len(betas) - + # Generate covariates if covariate_dist == "normal": X = np.random.normal( covariate_params.get("mean", 0.0), covariate_params.get("std", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "uniform": X = np.random.uniform( covariate_params.get("low", 0.0), covariate_params.get("high", 1.0), - size=(n, n_covariates) + size=(n, n_covariates), ) elif covariate_dist == "binary": X = np.random.binomial( - 1, - covariate_params.get("p", 0.5), - size=(n, n_covariates) + 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) else: raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - + # Calculate linear predictor linear_predictor = X @ betas - + # Generate survival times using piecewise exponential distribution survival_times = np.zeros(n) - + for i in range(n): # Adjust hazard rates by the covariate effect adjusted_hazard_rates = [h * np.exp(linear_predictor[i]) for h in hazard_rates] - + # Generate random uniform between 0 and 1 u = np.random.uniform(0, 1) - + # Calculate survival time using inverse CDF method for piecewise exponential remaining_time = -np.log(u) # Initial time remaining (for standard exponential) total_time = 0.0 - + # Start with the first interval [0, breakpoints[0]) interval_width = breakpoints[0] hazard = adjusted_hazard_rates[0] time_to_consume = remaining_time / hazard - + if time_to_consume < interval_width: # Event occurs in first interval survival_times[i] = time_to_consume continue - + # Event occurs after first interval total_time += interval_width remaining_time -= hazard * interval_width - + # Go through middle intervals [breakpoints[j-1], breakpoints[j]) for j in range(1, len(breakpoints)): - interval_width = breakpoints[j] - breakpoints[j-1] + interval_width = breakpoints[j] - breakpoints[j - 1] hazard = adjusted_hazard_rates[j] time_to_consume = remaining_time / hazard - + if time_to_consume < interval_width: # Event occurs in this interval survival_times[i] = total_time + time_to_consume break - + # Event occurs after this interval total_time += interval_width remaining_time -= hazard * interval_width - + # If we've gone through all intervals and still no event, # use the last hazard rate for the remainder if remaining_time > 0: hazard = adjusted_hazard_rates[-1] survival_times[i] = total_time + remaining_time / hazard - + # Generate censoring times if model_cens == "uniform": cens_times = np.random.uniform(0, cens_par, size=n) @@ -190,33 +193,27 @@ def gen_piecewise_exponential( cens_times = np.random.exponential(scale=cens_par, size=n) else: raise ValueError("model_cens must be 'uniform' or 'exponential'") - + # Determine observed time and status observed_times = np.minimum(survival_times, cens_times) status = (survival_times <= cens_times).astype(int) - + # Create DataFrame - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_times, - "status": status - }) - + data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) + # Add covariates for j in range(n_covariates): data[f"X{j}"] = X[:, j] - + return data def piecewise_hazard_function( - t: Union[float, np.ndarray], - breakpoints: List[float], - hazard_rates: List[float] + t: Union[float, np.ndarray], breakpoints: List[float], hazard_rates: List[float] ) -> Union[float, np.ndarray]: """ Calculate the hazard function value at time t for a piecewise exponential distribution. - + Parameters ---------- t : float or array @@ -225,7 +222,7 @@ def piecewise_hazard_function( Time points where hazard rates change. hazard_rates : list of float Hazard rates for each interval. - + Returns ------- float or array @@ -235,34 +232,32 @@ def piecewise_hazard_function( scalar_input = np.isscalar(t) t_array = np.atleast_1d(t) result = np.zeros_like(t_array) - + # Assign hazard rates based on time intervals result[t_array < 0] = 0 # Hazard is 0 for negative times - + # First interval: [0, breakpoints[0]) mask = (t_array >= 0) & (t_array < breakpoints[0]) result[mask] = hazard_rates[0] - + # Middle intervals: [breakpoints[j-1], breakpoints[j]) for j in range(1, len(breakpoints)): - mask = (t_array >= breakpoints[j-1]) & (t_array < breakpoints[j]) + mask = (t_array >= breakpoints[j - 1]) & (t_array < breakpoints[j]) result[mask] = hazard_rates[j] - + # Last interval: [breakpoints[-1], infinity) mask = t_array >= breakpoints[-1] result[mask] = hazard_rates[-1] - + return result[0] if scalar_input else result def piecewise_survival_function( - t: Union[float, np.ndarray], - breakpoints: List[float], - hazard_rates: List[float] + t: Union[float, np.ndarray], breakpoints: List[float], hazard_rates: List[float] ) -> Union[float, np.ndarray]: """ Calculate the survival function at time t for a piecewise exponential distribution. - + Parameters ---------- t : float or array @@ -271,7 +266,7 @@ def piecewise_survival_function( Time points where hazard rates change. hazard_rates : list of float Hazard rates for each interval. - + Returns ------- float or array @@ -281,41 +276,41 @@ def piecewise_survival_function( scalar_input = np.isscalar(t) t_array = np.atleast_1d(t) result = np.ones_like(t_array) - + # For each time point, calculate the survival function for i, ti in enumerate(t_array): if ti <= 0: continue # Survival probability is 1 at time 0 or earlier - + cumulative_hazard = 0.0 - + # First interval: [0, min(ti, breakpoints[0])) first_interval_end = min(ti, breakpoints[0]) if breakpoints else ti cumulative_hazard += hazard_rates[0] * first_interval_end - + if ti <= breakpoints[0]: result[i] = np.exp(-cumulative_hazard) continue - + # Middle intervals: [breakpoints[j-1], min(ti, breakpoints[j])) for j in range(1, len(breakpoints)): - if ti <= breakpoints[j-1]: + if ti <= breakpoints[j - 1]: break - - interval_start = breakpoints[j-1] + + interval_start = breakpoints[j - 1] interval_end = min(ti, breakpoints[j]) interval_width = interval_end - interval_start - + cumulative_hazard += hazard_rates[j] * interval_width - + if ti <= breakpoints[j]: break - + # Last interval: [breakpoints[-1], ti) if ti > breakpoints[-1]: last_interval_width = ti - breakpoints[-1] cumulative_hazard += hazard_rates[-1] * last_interval_width - + result[i] = np.exp(-cumulative_hazard) - + return result[0] if scalar_input else result diff --git a/gen_surv/summary.py b/gen_surv/summary.py index 449f88b..dfd094d 100644 --- a/gen_surv/summary.py +++ b/gen_surv/summary.py @@ -5,7 +5,8 @@ check data quality, and identify potential issues. """ -from typing import Dict, List, Optional, Tuple, Any +from typing import Any, Dict, List, Optional, Tuple + import pandas as pd @@ -15,11 +16,11 @@ def summarize_survival_dataset( status_col: str = "status", id_col: Optional[str] = None, covariate_cols: Optional[List[str]] = None, - verbose: bool = True + verbose: bool = True, ) -> Dict[str, Any]: """ Generate a comprehensive summary of a survival dataset. - + Parameters ---------- data : pd.DataFrame @@ -35,21 +36,21 @@ def summarize_survival_dataset( If None, all columns except time_col, status_col, and id_col are considered. verbose : bool, default=True Whether to print the summary to console. - + Returns ------- Dict[str, Any] Dictionary containing all summary statistics. - + Examples -------- >>> from gen_surv import generate >>> from gen_surv.summary import summarize_survival_dataset - >>> + >>> >>> # Generate example data >>> df = generate(model="cphm", n=100, model_cens="uniform", ... cens_par=1.0, beta=0.5, covariate_range=2.0) - >>> + >>> >>> # Summarize the dataset >>> summary = summarize_survival_dataset(df) """ @@ -57,10 +58,10 @@ def summarize_survival_dataset( for col in [time_col, status_col]: if col not in data.columns: raise ValueError(f"Column '{col}' not found in data") - + if id_col is not None and id_col not in data.columns: raise ValueError(f"ID column '{id_col}' not found in data") - + # Determine covariate columns if covariate_cols is None: exclude_cols = {time_col, status_col} @@ -71,37 +72,37 @@ def summarize_survival_dataset( missing_cols = [col for col in covariate_cols if col not in data.columns] if missing_cols: raise ValueError(f"Covariate columns not found in data: {missing_cols}") - + # Basic dataset information n_subjects = len(data) if id_col is not None: n_unique_ids = data[id_col].nunique() else: n_unique_ids = n_subjects - + # Event information n_events = data[status_col].sum() n_censored = n_subjects - n_events event_rate = n_events / n_subjects - + # Time statistics time_min = data[time_col].min() time_max = data[time_col].max() time_mean = data[time_col].mean() time_median = data[time_col].median() - + # Data quality checks n_missing_time = data[time_col].isna().sum() n_missing_status = data[status_col].isna().sum() n_negative_time = (data[time_col] < 0).sum() n_invalid_status = data[~data[status_col].isin([0, 1])].shape[0] - + # Covariate summaries covariate_stats = {} for col in covariate_cols: col_data = data[col] is_numeric = pd.api.types.is_numeric_dtype(col_data) - + if is_numeric: covariate_stats[col] = { "type": "numeric", @@ -111,7 +112,7 @@ def summarize_survival_dataset( "median": col_data.median(), "std": col_data.std(), "missing": col_data.isna().sum(), - "unique_values": col_data.nunique() + "unique_values": col_data.nunique(), } else: # Categorical/string @@ -119,41 +120,51 @@ def summarize_survival_dataset( "type": "categorical", "n_categories": col_data.nunique(), "top_categories": col_data.value_counts().head(5).to_dict(), - "missing": col_data.isna().sum() + "missing": col_data.isna().sum(), } - + # Compile the summary summary = { "dataset_info": { "n_subjects": n_subjects, "n_unique_ids": n_unique_ids, - "n_covariates": len(covariate_cols) + "n_covariates": len(covariate_cols), }, "event_info": { "n_events": n_events, "n_censored": n_censored, - "event_rate": event_rate + "event_rate": event_rate, }, "time_info": { "min": time_min, "max": time_max, "mean": time_mean, - "median": time_median + "median": time_median, }, "data_quality": { "missing_time": n_missing_time, "missing_status": n_missing_status, "negative_time": n_negative_time, "invalid_status": n_invalid_status, - "overall_quality": "good" if (n_missing_time + n_missing_status + n_negative_time + n_invalid_status) == 0 else "issues_detected" + "overall_quality": ( + "good" + if ( + n_missing_time + + n_missing_status + + n_negative_time + + n_invalid_status + ) + == 0 + else "issues_detected" + ), }, - "covariates": covariate_stats + "covariates": covariate_stats, } - + # Print summary if requested if verbose: _print_summary(summary, time_col, status_col, id_col, covariate_cols) - + return summary @@ -165,11 +176,11 @@ def check_survival_data_quality( min_time: float = 0.0, max_time: Optional[float] = None, status_values: Optional[List[int]] = None, - fix_issues: bool = False + fix_issues: bool = False, ) -> Tuple[pd.DataFrame, Dict[str, Any]]: """ Check for common issues in survival data and optionally fix them. - + Parameters ---------- data : pd.DataFrame @@ -188,82 +199,74 @@ def check_survival_data_quality( List of valid status values. Default is [0, 1]. fix_issues : bool, default=False Whether to attempt fixing issues (returns a modified DataFrame). - + Returns ------- Tuple[pd.DataFrame, Dict[str, Any]] Tuple containing (possibly fixed) DataFrame and issues report. - + Examples -------- >>> from gen_surv import generate >>> from gen_surv.summary import check_survival_data_quality - >>> + >>> >>> # Generate example data with some issues >>> df = generate(model="cphm", n=100, model_cens="uniform", ... cens_par=1.0, beta=0.5, covariate_range=2.0) >>> # Introduce some issues >>> df.loc[0, "time"] = np.nan >>> df.loc[1, "status"] = 2 # Invalid status - >>> + >>> >>> # Check and fix issues >>> fixed_df, issues = check_survival_data_quality(df, fix_issues=True) >>> print(issues) """ if status_values is None: status_values = [0, 1] - + # Make a copy to avoid modifying the original if fix_issues: data = data.copy() - + # Initialize issues report issues = { - "missing_data": { - "time": 0, - "status": 0, - "id": 0 if id_col else None - }, + "missing_data": {"time": 0, "status": 0, "id": 0 if id_col else None}, "invalid_values": { "negative_time": 0, "excessive_time": 0, - "invalid_status": 0 + "invalid_status": 0, }, - "duplicates": { - "duplicate_rows": 0, - "duplicate_ids": 0 if id_col else None - }, - "modifications": { - "rows_dropped": 0, - "values_fixed": 0 - } + "duplicates": {"duplicate_rows": 0, "duplicate_ids": 0 if id_col else None}, + "modifications": {"rows_dropped": 0, "values_fixed": 0}, } - + # Check for missing values issues["missing_data"]["time"] = data[time_col].isna().sum() issues["missing_data"]["status"] = data[status_col].isna().sum() if id_col: issues["missing_data"]["id"] = data[id_col].isna().sum() - + # Check for invalid values issues["invalid_values"]["negative_time"] = (data[time_col] < min_time).sum() if max_time is not None: issues["invalid_values"]["excessive_time"] = (data[time_col] > max_time).sum() - issues["invalid_values"]["invalid_status"] = data[~data[status_col].isin(status_values)].shape[0] - + issues["invalid_values"]["invalid_status"] = data[ + ~data[status_col].isin(status_values) + ].shape[0] + # Check for duplicates issues["duplicates"]["duplicate_rows"] = data.duplicated().sum() if id_col: issues["duplicates"]["duplicate_ids"] = data[id_col].duplicated().sum() - + # Fix issues if requested if fix_issues: original_rows = len(data) modified_values = 0 - + # Handle missing values data = data.dropna(subset=[time_col, status_col]) - + # Handle invalid values if min_time > 0: # Set negative or too small times to min_time @@ -271,28 +274,28 @@ def check_survival_data_quality( if mask.any(): data.loc[mask, time_col] = min_time modified_values += mask.sum() - + if max_time is not None: # Cap excessively large times mask = data[time_col] > max_time if mask.any(): data.loc[mask, time_col] = max_time modified_values += mask.sum() - + # Fix invalid status values mask = ~data[status_col].isin(status_values) if mask.any(): # Default to censored (0) for invalid status data.loc[mask, status_col] = 0 modified_values += mask.sum() - + # Remove duplicates data = data.drop_duplicates() - + # Update modification counts issues["modifications"]["rows_dropped"] = original_rows - len(data) issues["modifications"]["values_fixed"] = modified_values - + return data, issues @@ -301,11 +304,11 @@ def _print_summary( time_col: str, status_col: str, id_col: Optional[str], - covariate_cols: List[str] + covariate_cols: List[str], ) -> None: """ Print a formatted summary of survival data. - + Parameters ---------- summary : Dict[str, Any] @@ -322,57 +325,71 @@ def _print_summary( print("=" * 60) print("SURVIVAL DATASET SUMMARY") print("=" * 60) - + # Dataset info print("\nDATASET INFORMATION:") print(f" Subjects: {summary['dataset_info']['n_subjects']}") if id_col: print(f" Unique IDs: {summary['dataset_info']['n_unique_ids']}") print(f" Covariates: {summary['dataset_info']['n_covariates']}") - + # Event info print("\nEVENT INFORMATION:") - print(f" Events: {summary['event_info']['n_events']} " + - f"({summary['event_info']['event_rate']:.1%})") - print(f" Censored: {summary['event_info']['n_censored']} " + - f"({1 - summary['event_info']['event_rate']:.1%})") - + print( + f" Events: {summary['event_info']['n_events']} " + + f"({summary['event_info']['event_rate']:.1%})" + ) + print( + f" Censored: {summary['event_info']['n_censored']} " + + f"({1 - summary['event_info']['event_rate']:.1%})" + ) + # Time info print(f"\nTIME VARIABLE ({time_col}):") - print(f" Range: {summary['time_info']['min']:.2f} to {summary['time_info']['max']:.2f}") + print( + f" Range: {summary['time_info']['min']:.2f} to {summary['time_info']['max']:.2f}" + ) print(f" Mean: {summary['time_info']['mean']:.2f}") print(f" Median: {summary['time_info']['median']:.2f}") - + # Data quality print("\nDATA QUALITY:") quality_issues = ( - summary['data_quality']['missing_time'] + - summary['data_quality']['missing_status'] + - summary['data_quality']['negative_time'] + - summary['data_quality']['invalid_status'] + summary["data_quality"]["missing_time"] + + summary["data_quality"]["missing_status"] + + summary["data_quality"]["negative_time"] + + summary["data_quality"]["invalid_status"] ) - + if quality_issues == 0: print(" ✓ No issues detected") else: print(" ✗ Issues detected:") - if summary['data_quality']['missing_time'] > 0: - print(f" - Missing time values: {summary['data_quality']['missing_time']}") - if summary['data_quality']['missing_status'] > 0: - print(f" - Missing status values: {summary['data_quality']['missing_status']}") - if summary['data_quality']['negative_time'] > 0: - print(f" - Negative time values: {summary['data_quality']['negative_time']}") - if summary['data_quality']['invalid_status'] > 0: - print(f" - Invalid status values: {summary['data_quality']['invalid_status']}") - + if summary["data_quality"]["missing_time"] > 0: + print( + f" - Missing time values: {summary['data_quality']['missing_time']}" + ) + if summary["data_quality"]["missing_status"] > 0: + print( + f" - Missing status values: {summary['data_quality']['missing_status']}" + ) + if summary["data_quality"]["negative_time"] > 0: + print( + f" - Negative time values: {summary['data_quality']['negative_time']}" + ) + if summary["data_quality"]["invalid_status"] > 0: + print( + f" - Invalid status values: {summary['data_quality']['invalid_status']}" + ) + # Covariates print("\nCOVARIATES:") if not covariate_cols: print(" No covariates found") else: - for col, stats in summary['covariates'].items(): + for col, stats in summary["covariates"].items(): print(f" {col}:") - if stats['type'] == 'numeric': + if stats["type"] == "numeric": print(" Type: Numeric") print(f" Range: {stats['min']:.2f} to {stats['max']:.2f}") print(f" Mean: {stats['mean']:.2f}") @@ -381,7 +398,7 @@ def _print_summary( print(" Type: Categorical") print(f" Categories: {stats['n_categories']}") print(f" Missing: {stats['missing']}") - + print("\n" + "=" * 60) @@ -389,11 +406,11 @@ def compare_survival_datasets( datasets: Dict[str, pd.DataFrame], time_col: str = "time", status_col: str = "status", - covariate_cols: Optional[List[str]] = None + covariate_cols: Optional[List[str]] = None, ) -> pd.DataFrame: """ Compare multiple survival datasets and summarize their differences. - + Parameters ---------- datasets : Dict[str, pd.DataFrame] @@ -404,82 +421,75 @@ def compare_survival_datasets( Name of the status column in each dataset. covariate_cols : List[str], optional List of covariate columns to compare. If None, compares all common columns. - + Returns ------- pd.DataFrame Comparison table with datasets as columns and metrics as rows. - + Examples -------- >>> from gen_surv import generate >>> from gen_surv.summary import compare_survival_datasets - >>> + >>> >>> # Generate datasets with different parameters >>> datasets = { ... "CPHM": generate(model="cphm", n=100, model_cens="uniform", ... cens_par=1.0, beta=0.5, covariate_range=2.0), - ... "Weibull AFT": generate(model="aft_weibull", n=100, beta=[0.5], + ... "Weibull AFT": generate(model="aft_weibull", n=100, beta=[0.5], ... shape=1.5, scale=1.0, model_cens="uniform", cens_par=1.0) ... } - >>> + >>> >>> # Compare datasets >>> comparison = compare_survival_datasets(datasets) >>> print(comparison) """ if not datasets: raise ValueError("No datasets provided for comparison") - + # Find common columns if covariate_cols not specified if covariate_cols is None: all_columns = [set(df.columns) for df in datasets.values()] common_columns = set.intersection(*all_columns) common_columns -= {time_col, status_col} # Remove time and status covariate_cols = sorted(list(common_columns)) - + # Calculate summaries for each dataset summaries = {} for name, data in datasets.items(): summaries[name] = summarize_survival_dataset( - data, time_col, status_col, - covariate_cols=covariate_cols, verbose=False + data, time_col, status_col, covariate_cols=covariate_cols, verbose=False ) - + # Construct the comparison DataFrame comparison_data = {} - + # Dataset info comparison_data["n_subjects"] = { - name: summary["dataset_info"]["n_subjects"] + name: summary["dataset_info"]["n_subjects"] for name, summary in summaries.items() } comparison_data["n_events"] = { - name: summary["event_info"]["n_events"] - for name, summary in summaries.items() + name: summary["event_info"]["n_events"] for name, summary in summaries.items() } comparison_data["event_rate"] = { - name: summary["event_info"]["event_rate"] - for name, summary in summaries.items() + name: summary["event_info"]["event_rate"] for name, summary in summaries.items() } - + # Time info comparison_data["time_min"] = { - name: summary["time_info"]["min"] - for name, summary in summaries.items() + name: summary["time_info"]["min"] for name, summary in summaries.items() } comparison_data["time_max"] = { - name: summary["time_info"]["max"] - for name, summary in summaries.items() + name: summary["time_info"]["max"] for name, summary in summaries.items() } comparison_data["time_mean"] = { - name: summary["time_info"]["mean"] - for name, summary in summaries.items() + name: summary["time_info"]["mean"] for name, summary in summaries.items() } comparison_data["time_median"] = { - name: summary["time_info"]["median"] - for name, summary in summaries.items() + name: summary["time_info"]["median"] for name, summary in summaries.items() } - + # Covariate info (means for numeric) for col in covariate_cols: for name, summary in summaries.items(): @@ -489,8 +499,8 @@ def compare_survival_datasets( if f"{col}_mean" not in comparison_data: comparison_data[f"{col}_mean"] = {} comparison_data[f"{col}_mean"][name] = col_stats["mean"] - + # Create the DataFrame comparison_df = pd.DataFrame(comparison_data).T - + return comparison_df diff --git a/gen_surv/tdcm.py b/gen_surv/tdcm.py index b349137..9ffa081 100644 --- a/gen_surv/tdcm.py +++ b/gen_surv/tdcm.py @@ -1,8 +1,10 @@ import numpy as np import pandas as pd -from gen_surv.validate import validate_gen_tdcm_inputs + from gen_surv.bivariate import sample_bivariate_distribution -from gen_surv.censoring import runifcens, rexpocens +from gen_surv.censoring import rexpocens, runifcens +from gen_surv.validate import validate_gen_tdcm_inputs + def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): """ @@ -71,6 +73,10 @@ def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): # Generate covariate matrix from bivariate distribution b = sample_bivariate_distribution(n, dist, corr, dist_par) - data = generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b) + data = generate_censored_observations( + n, dist_par, model_cens, cens_par, beta, lam, b + ) - return pd.DataFrame(data, columns=["id", "start", "stop", "status", "covariate", "tdcov"]) + return pd.DataFrame( + data, columns=["id", "start", "stop", "status", "covariate", "tdcov"] + ) diff --git a/gen_surv/thmm.py b/gen_surv/thmm.py index c04c177..c8b7aa7 100644 --- a/gen_surv/thmm.py +++ b/gen_surv/thmm.py @@ -1,9 +1,13 @@ import numpy as np import pandas as pd + +from gen_surv.censoring import rexpocens, runifcens from gen_surv.validate import validate_gen_thmm_inputs -from gen_surv.censoring import runifcens, rexpocens -def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: + +def calculate_transitions( + z1: float, cens_par: float, beta: list, rate: list, rfunc +) -> dict: """ Calculate transition and censoring times for THMM. diff --git a/gen_surv/validate.py b/gen_surv/validate.py index e7308a2..9a99792 100644 --- a/gen_surv/validate.py +++ b/gen_surv/validate.py @@ -1,4 +1,6 @@ -def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covariate_range: float): +def validate_gen_cphm_inputs( + n: int, model_cens: str, cens_par: float, covariate_range: float +): """ Validates input parameters for CPHM data generation. @@ -15,14 +17,22 @@ def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covariate raise ValueError("Argument 'n' must be greater than 0") if model_cens not in {"uniform", "exponential"}: raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + "Argument 'model_cens' must be one of 'uniform' or 'exponential'" + ) if cens_par <= 0: raise ValueError("Argument 'cens_par' must be greater than 0") if covariate_range <= 0: raise ValueError("Argument 'covariate_range' must be greater than 0") -def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covariate_range: float, rate: list): +def validate_gen_cmm_inputs( + n: int, + model_cens: str, + cens_par: float, + beta: list, + covariate_range: float, + rate: list, +): """ Validate inputs for generating CMM (Continuous-Time Markov Model) data. @@ -41,7 +51,8 @@ def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list raise ValueError("Argument 'n' must be greater than 0") if model_cens not in {"uniform", "exponential"}: raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + "Argument 'model_cens' must be one of 'uniform' or 'exponential'" + ) if cens_par <= 0: raise ValueError("Argument 'cens_par' must be greater than 0") if len(beta) != 3: @@ -52,8 +63,16 @@ def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list raise ValueError("Argument 'rate' must be a list of length 6") -def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, - model_cens: str, cens_par: float, beta: list, lam: float): +def validate_gen_tdcm_inputs( + n: int, + dist: str, + corr: float, + dist_par: list, + model_cens: str, + cens_par: float, + beta: list, + lam: float, +): """ Validate inputs for generating TDCM (Time-Dependent Covariate Model) data. @@ -74,27 +93,28 @@ def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, raise ValueError("Argument 'n' must be greater than 0") if dist not in {"weibull", "exponential"}: - raise ValueError( - "Argument 'dist' must be one of 'weibull' or 'exponential'") + raise ValueError("Argument 'dist' must be one of 'weibull' or 'exponential'") if dist == "weibull": if not (0 < corr <= 1): raise ValueError("With dist='weibull', 'corr' must be in (0,1]") if len(dist_par) != 4 or any(p <= 0 for p in dist_par): raise ValueError( - "With dist='weibull', 'dist_par' must be a positive list of length 4") + "With dist='weibull', 'dist_par' must be a positive list of length 4" + ) if dist == "exponential": if not (-1 <= corr <= 1): - raise ValueError( - "With dist='exponential', 'corr' must be in [-1,1]") + raise ValueError("With dist='exponential', 'corr' must be in [-1,1]") if len(dist_par) != 2 or any(p <= 0 for p in dist_par): raise ValueError( - "With dist='exponential', 'dist_par' must be a positive list of length 2") + "With dist='exponential', 'dist_par' must be a positive list of length 2" + ) if model_cens not in {"uniform", "exponential"}: raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + "Argument 'model_cens' must be one of 'uniform' or 'exponential'" + ) if cens_par <= 0: raise ValueError("Argument 'cens_par' must be greater than 0") @@ -106,7 +126,14 @@ def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, raise ValueError("Argument 'lambda' must be greater than 0") -def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covariate_range: float, rate: list): +def validate_gen_thmm_inputs( + n: int, + model_cens: str, + cens_par: float, + beta: list, + covariate_range: float, + rate: list, +): """ Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. @@ -126,7 +153,8 @@ def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: lis if model_cens not in {"uniform", "exponential"}: raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + "Argument 'model_cens' must be one of 'uniform' or 'exponential'" + ) if not isinstance(cens_par, (int, float)) or cens_par <= 0: raise ValueError("Argument 'cens_par' must be a positive number.") @@ -164,13 +192,17 @@ def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): raise ValueError("Argument 'corr' must be a numeric value between -1 and 1.") if not isinstance(dist_par, list) or len(dist_par) == 0: - raise ValueError("Argument 'dist_par' must be a non-empty list of positive values.") + raise ValueError( + "Argument 'dist_par' must be a non-empty list of positive values." + ) if any(p <= 0 for p in dist_par): raise ValueError("All elements in 'dist_par' must be greater than 0.") if dist == "exponential" and len(dist_par) != 2: - raise ValueError("Exponential distribution requires exactly 2 positive parameters.") + raise ValueError( + "Exponential distribution requires exactly 2 positive parameters." + ) if dist == "weibull" and len(dist_par) != 4: raise ValueError("Weibull distribution requires exactly 4 positive parameters.") @@ -180,7 +212,9 @@ def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par): if not isinstance(n, int) or n <= 0: raise ValueError("n must be a positive integer") - if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): + if not isinstance(beta, (list, tuple)) or not all( + isinstance(b, (int, float)) for b in beta + ): raise ValueError("beta must be a list of numbers") if not isinstance(sigma, (int, float)) or sigma <= 0: @@ -197,7 +231,9 @@ def validate_gen_aft_weibull_inputs(n, beta, shape, scale, model_cens, cens_par) if not isinstance(n, int) or n <= 0: raise ValueError("n must be a positive integer") - if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): + if not isinstance(beta, (list, tuple)) or not all( + isinstance(b, (int, float)) for b in beta + ): raise ValueError("beta must be a list of numbers") if not isinstance(shape, (int, float)) or shape <= 0: @@ -217,7 +253,9 @@ def validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens if not isinstance(n, int) or n <= 0: raise ValueError("n must be a positive integer") - if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): + if not isinstance(beta, (list, tuple)) or not all( + isinstance(b, (int, float)) for b in beta + ): raise ValueError("beta must be a list of numbers") if not isinstance(shape, (int, float)) or shape <= 0: @@ -233,7 +271,9 @@ def validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens raise ValueError("cens_par must be a positive number") -def validate_competing_risks_inputs(n, n_risks, baseline_hazards, betas, model_cens, cens_par): +def validate_competing_risks_inputs( + n, n_risks, baseline_hazards, betas, model_cens, cens_par +): if not isinstance(n, int) or n <= 0: raise ValueError("n must be a positive integer") @@ -241,15 +281,16 @@ def validate_competing_risks_inputs(n, n_risks, baseline_hazards, betas, model_c raise ValueError("n_risks must be a positive integer") if baseline_hazards is not None and ( - not isinstance(baseline_hazards, (list, tuple)) or - len(baseline_hazards) != n_risks or - any(h <= 0 for h in baseline_hazards) + not isinstance(baseline_hazards, (list, tuple)) + or len(baseline_hazards) != n_risks + or any(h <= 0 for h in baseline_hazards) ): - raise ValueError("baseline_hazards must be a list of positive numbers with length n_risks") + raise ValueError( + "baseline_hazards must be a list of positive numbers with length n_risks" + ) if betas is not None and ( - not isinstance(betas, list) or - any(not isinstance(b, list) for b in betas) + not isinstance(betas, list) or any(not isinstance(b, list) for b in betas) ): raise ValueError("betas must be a list of lists") diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py index b112e3f..c3bcb89 100644 --- a/gen_surv/visualization.py +++ b/gen_surv/visualization.py @@ -1,17 +1,18 @@ """ Visualization utilities for survival data. -This module provides functions to visualize survival data generated by +This module provides functions to visualize survival data generated by gen_surv, -including Kaplan-Meier survival curves and other commonly used plots in +including Kaplan-Meier survival curves and other commonly used plots in survival analysis. """ from typing import Dict, Optional, Tuple -import pandas as pd + import matplotlib.pyplot as plt -from matplotlib.figure import Figure +import pandas as pd from matplotlib.axes import Axes +from matplotlib.figure import Figure def plot_survival_curve( @@ -218,7 +219,7 @@ def plot_covariate_effect( ci_alpha: float = 0.2, ) -> Tuple[Figure, Axes]: """ - Visualize the effect of a continuous covariate on survival by discretizing + Visualize the effect of a continuous covariate on survival by discretizing it. Parameters diff --git a/tests/test_aft.py b/tests/test_aft.py index e3e9a6e..0cf4a18 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -4,12 +4,14 @@ import os import sys + import pandas as pd import pytest -from hypothesis import given, strategies as st +from hypothesis import given +from hypothesis import strategies as st sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic +from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull def test_gen_aft_log_logistic_runs(): diff --git a/tests/test_aft_property.py b/tests/test_aft_property.py index 18a6412..73a4061 100644 --- a/tests/test_aft_property.py +++ b/tests/test_aft_property.py @@ -1,11 +1,18 @@ -from hypothesis import given, strategies as st +from hypothesis import given +from hypothesis import strategies as st + from gen_surv.aft import gen_aft_log_normal + @given( n=st.integers(min_value=1, max_value=20), - sigma=st.floats(min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False), - cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), - seed=st.integers(min_value=0, max_value=1000) + sigma=st.floats( + min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False + ), + cens_par=st.floats( + min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False + ), + seed=st.integers(min_value=0, max_value=1000), ) def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): df = gen_aft_log_normal( @@ -14,7 +21,7 @@ def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): sigma=sigma, model_cens="uniform", cens_par=cens_par, - seed=seed + seed=seed, ) assert df.shape[0] == n assert set(df["status"].unique()).issubset({0, 1}) diff --git a/tests/test_bivariate.py b/tests/test_bivariate.py index d1ae7b6..403cdbb 100644 --- a/tests/test_bivariate.py +++ b/tests/test_bivariate.py @@ -1,11 +1,13 @@ import os import sys + import numpy as np sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from gen_surv.bivariate import sample_bivariate_distribution import pytest +from gen_surv.bivariate import sample_bivariate_distribution + def test_sample_bivariate_exponential_shape(): """Exponential distribution should return an array of shape (n, 2).""" @@ -19,11 +21,13 @@ def test_sample_bivariate_invalid_dist(): with pytest.raises(ValueError): sample_bivariate_distribution(10, "invalid", 0.0, [1, 1]) + def test_sample_bivariate_exponential_param_length_error(): """Exponential distribution with wrong param length should raise ValueError.""" with pytest.raises(ValueError): sample_bivariate_distribution(5, "exponential", 0.0, [1.0]) + def test_sample_bivariate_weibull_param_length_error(): """Weibull distribution with wrong param length should raise ValueError.""" with pytest.raises(ValueError): diff --git a/tests/test_censoring.py b/tests/test_censoring.py index 7e65e23..8db38c7 100644 --- a/tests/test_censoring.py +++ b/tests/test_censoring.py @@ -1,5 +1,6 @@ import numpy as np -from gen_surv.censoring import runifcens, rexpocens + +from gen_surv.censoring import rexpocens, runifcens def test_runifcens_range(): diff --git a/tests/test_cli.py b/tests/test_cli.py index cce5815..4b819c8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,6 @@ -import sys import os import runpy +import sys import pandas as pd @@ -40,6 +40,7 @@ def fake_app(): runpy.run_module("gen_surv.__main__", run_name="__main__") assert called + def test_cli_dataset_file_output(monkeypatch, tmp_path): """Dataset command writes CSV to file when output path is provided.""" diff --git a/tests/test_cmm.py b/tests/test_cmm.py index bec5c42..9783b25 100644 --- a/tests/test_cmm.py +++ b/tests/test_cmm.py @@ -1,8 +1,10 @@ -import sys import os +import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.cmm import gen_cmm + def test_gen_cmm_shape(): df = gen_cmm( n=50, diff --git a/tests/test_competing_risks.py b/tests/test_competing_risks.py index 0431f47..181aa21 100644 --- a/tests/test_competing_risks.py +++ b/tests/test_competing_risks.py @@ -2,15 +2,16 @@ Tests for Competing Risks models. """ -import pytest import numpy as np import pandas as pd -from hypothesis import given, strategies as st +import pytest +from hypothesis import given +from hypothesis import strategies as st from gen_surv.competing_risks import ( + cause_specific_cumulative_incidence, gen_competing_risks, gen_competing_risks_weibull, - cause_specific_cumulative_incidence ) @@ -23,7 +24,7 @@ def test_gen_competing_risks_basic(): betas=[[0.8, -0.5], [0.2, 0.7]], model_cens="uniform", cens_par=2.0, - seed=42 + seed=42, ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -44,7 +45,7 @@ def test_gen_competing_risks_weibull_basic(): betas=[[0.8, -0.5], [0.2, 0.7]], model_cens="uniform", cens_par=2.0, - seed=42 + seed=42, ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -63,26 +64,23 @@ def test_competing_risks_parameters(): n=10, n_risks=3, baseline_hazards=[0.5, 0.3], # Only 2 provided, but 3 risks - seed=42 + seed=42, ) - + # Test with invalid number of beta coefficient sets with pytest.raises(ValueError, match="Expected 2 sets of coefficients"): gen_competing_risks( n=10, n_risks=2, betas=[[0.8, -0.5]], # Only 1 set provided, but 2 risks - seed=42 + seed=42, ) - + # Test with invalid censoring model - with pytest.raises(ValueError, match="model_cens must be 'uniform' or 'exponential'"): - gen_competing_risks( - n=10, - n_risks=2, - model_cens="invalid", - seed=42 - ) + with pytest.raises( + ValueError, match="model_cens must be 'uniform' or 'exponential'" + ): + gen_competing_risks(n=10, n_risks=2, model_cens="invalid", seed=42) def test_competing_risks_weibull_parameters(): @@ -93,33 +91,28 @@ def test_competing_risks_weibull_parameters(): n=10, n_risks=3, shape_params=[0.8, 1.5], # Only 2 provided, but 3 risks - seed=42 + seed=42, ) - + # Test with invalid number of scale parameters with pytest.raises(ValueError, match="Expected 3 scale parameters"): gen_competing_risks_weibull( n=10, n_risks=3, scale_params=[2.0, 3.0], # Only 2 provided, but 3 risks - seed=42 + seed=42, ) def test_cause_specific_cumulative_incidence(): """Test the cause-specific cumulative incidence function.""" # Generate some data - df = gen_competing_risks( - n=50, - n_risks=2, - baseline_hazards=[0.5, 0.3], - seed=42 - ) - + df = gen_competing_risks(n=50, n_risks=2, baseline_hazards=[0.5, 0.3], seed=42) + # Calculate CIF for cause 1 time_points = np.linspace(0, 5, 10) cif = cause_specific_cumulative_incidence(df, time_points, cause=1) - + assert isinstance(cif, pd.DataFrame) assert len(cif) == len(time_points) assert "time" in cif.columns @@ -127,7 +120,7 @@ def test_cause_specific_cumulative_incidence(): assert (cif["incidence"] >= 0).all() assert (cif["incidence"] <= 1).all() assert cif["incidence"].is_monotonic_increasing - + # Test with invalid cause with pytest.raises(ValueError, match="Cause 3 not found in the data"): cause_specific_cumulative_incidence(df, time_points, cause=3) @@ -136,22 +129,18 @@ def test_cause_specific_cumulative_incidence(): @given( n=st.integers(min_value=5, max_value=50), n_risks=st.integers(min_value=2, max_value=4), - seed=st.integers(min_value=0, max_value=1000) + seed=st.integers(min_value=0, max_value=1000), ) def test_competing_risks_properties(n, n_risks, seed): """Property-based tests for the competing risks model.""" - df = gen_competing_risks( - n=n, - n_risks=n_risks, - seed=seed - ) - + df = gen_competing_risks(n=n, n_risks=n_risks, seed=seed) + # Check basic properties assert df.shape[0] == n assert all(col in df.columns for col in ["id", "time", "status"]) assert (df["time"] >= 0).all() assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks - + # Count of each status status_counts = df["status"].value_counts() # There should be at least one of each status (including censoring) @@ -163,22 +152,18 @@ def test_competing_risks_properties(n, n_risks, seed): @given( n=st.integers(min_value=5, max_value=50), n_risks=st.integers(min_value=2, max_value=4), - seed=st.integers(min_value=0, max_value=1000) + seed=st.integers(min_value=0, max_value=1000), ) def test_competing_risks_weibull_properties(n, n_risks, seed): """Property-based tests for the Weibull competing risks model.""" - df = gen_competing_risks_weibull( - n=n, - n_risks=n_risks, - seed=seed - ) - + df = gen_competing_risks_weibull(n=n, n_risks=n_risks, seed=seed) + # Check basic properties assert df.shape[0] == n assert all(col in df.columns for col in ["id", "time", "status"]) assert (df["time"] >= 0).all() assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks - + # Count of each status status_counts = df["status"].value_counts() # There should be at least 2 different status values @@ -187,26 +172,14 @@ def test_competing_risks_weibull_properties(n, n_risks, seed): def test_reproducibility(): """Test that results are reproducible with the same seed.""" - df1 = gen_competing_risks( - n=20, - n_risks=2, - seed=42 - ) - - df2 = gen_competing_risks( - n=20, - n_risks=2, - seed=42 - ) - + df1 = gen_competing_risks(n=20, n_risks=2, seed=42) + + df2 = gen_competing_risks(n=20, n_risks=2, seed=42) + pd.testing.assert_frame_equal(df1, df2) - + # Different seeds should produce different results - df3 = gen_competing_risks( - n=20, - n_risks=2, - seed=43 - ) - + df3 = gen_competing_risks(n=20, n_risks=2, seed=43) + with pytest.raises(AssertionError): pd.testing.assert_frame_equal(df1, df3) diff --git a/tests/test_cphm.py b/tests/test_cphm.py index 46c976a..f71d04c 100644 --- a/tests/test_cphm.py +++ b/tests/test_cphm.py @@ -2,48 +2,65 @@ Tests for the Cox Proportional Hazards Model (CPHM) generator. """ -import pytest import pandas as pd +import pytest + from gen_surv.cphm import gen_cphm def test_gen_cphm_output_shape(): """Test that the output DataFrame has the expected shape and columns.""" - df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + df = gen_cphm( + n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0 + ) assert df.shape == (50, 3) assert list(df.columns) == ["time", "status", "X0"] def test_gen_cphm_status_range(): """Test that status values are binary (0 or 1).""" - df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covariate_range=1.5) + df = gen_cphm( + n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covariate_range=1.5 + ) assert df["status"].isin([0, 1]).all() def test_gen_cphm_time_positive(): """Test that all time values are positive.""" - df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + df = gen_cphm( + n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0 + ) assert (df["time"] > 0).all() def test_gen_cphm_covariate_range(): """Test that covariate values are within the specified range.""" covar_max = 2.5 - df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=covar_max) + df = gen_cphm( + n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=covar_max + ) assert (df["X0"] >= 0).all() assert (df["X0"] <= covar_max).all() def test_gen_cphm_seed_reproducibility(): """Test that setting the same seed produces identical results.""" - df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42) - df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42) + df1 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 + ) + df2 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 + ) pd.testing.assert_frame_equal(df1, df2) def test_gen_cphm_different_seeds(): """Test that different seeds produce different results.""" - df1 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42) - df2 = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=43) + df1 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 + ) + df2 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=43 + ) with pytest.raises(AssertionError): pd.testing.assert_frame_equal(df1, df2) diff --git a/tests/test_export.py b/tests/test_export.py index ad8e1dc..49dc704 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -1,10 +1,19 @@ import os + import pandas as pd -from gen_surv import generate, export_dataset + +from gen_surv import export_dataset, generate def test_export_dataset_csv(tmp_path): - df = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + df = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) out_file = tmp_path / "data.csv" export_dataset(df, str(out_file)) assert out_file.exists() @@ -13,10 +22,16 @@ def test_export_dataset_csv(tmp_path): def test_export_dataset_json(tmp_path): - df = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + df = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) out_file = tmp_path / "data.json" export_dataset(df, str(out_file)) assert out_file.exists() loaded = pd.read_json(out_file, orient="table") pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) - diff --git a/tests/test_interface.py b/tests/test_interface.py index ad2dc3c..8be76b3 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,6 +1,7 @@ -from gen_surv import generate import pytest +from gen_surv import generate + def test_generate_tdcm_runs(): df = generate( diff --git a/tests/test_mixture.py b/tests/test_mixture.py index 39655c0..0c14b2d 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -1,5 +1,6 @@ import pandas as pd -from gen_surv.mixture import gen_mixture_cure, cure_fraction_estimate + +from gen_surv.mixture import cure_fraction_estimate, gen_mixture_cure def test_gen_mixture_cure_runs(): diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py index f7f3217..9e44860 100644 --- a/tests/test_piecewise.py +++ b/tests/test_piecewise.py @@ -1,14 +1,12 @@ import pandas as pd import pytest + from gen_surv.piecewise import gen_piecewise_exponential def test_gen_piecewise_exponential_runs(): df = gen_piecewise_exponential( - n=10, - breakpoints=[1.0], - hazard_rates=[0.5, 1.0], - seed=42 + n=10, breakpoints=[1.0], hazard_rates=[0.5, 1.0], seed=42 ) assert isinstance(df, pd.DataFrame) assert len(df) == 10 @@ -18,8 +16,5 @@ def test_gen_piecewise_exponential_runs(): def test_piecewise_invalid_lengths(): with pytest.raises(ValueError): gen_piecewise_exponential( - n=5, - breakpoints=[1.0, 2.0], - hazard_rates=[0.5], - seed=42 + n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42 ) diff --git a/tests/test_summary.py b/tests/test_summary.py index 7e18bac..cf63caf 100644 --- a/tests/test_summary.py +++ b/tests/test_summary.py @@ -3,7 +3,14 @@ def test_summarize_survival_dataset_basic(): - df = generate(model="cphm", n=20, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + df = generate( + model="cphm", + n=20, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + ) summary = summarize_survival_dataset(df, verbose=False) assert isinstance(summary, dict) assert "dataset_info" in summary diff --git a/tests/test_tdcm.py b/tests/test_tdcm.py index 507b51f..d609236 100644 --- a/tests/test_tdcm.py +++ b/tests/test_tdcm.py @@ -1,10 +1,20 @@ -import sys import os +import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.tdcm import gen_tdcm + def test_gen_tdcm_shape(): - df = gen_tdcm(n=50, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2], - model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0) + df = gen_tdcm( + n=50, + dist="weibull", + corr=0.5, + dist_par=[1, 2, 1, 2], + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + ) assert df.shape[1] == 6 assert "tdcov" in df.columns diff --git a/tests/test_thmm.py b/tests/test_thmm.py index 5f616d7..907beff 100644 --- a/tests/test_thmm.py +++ b/tests/test_thmm.py @@ -1,9 +1,11 @@ -import sys import os +import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.thmm import gen_thmm + def test_gen_thmm_shape(): df = gen_thmm( n=50, diff --git a/tests/test_validate.py b/tests/test_validate.py index 54f10a8..5252556 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1,4 +1,5 @@ import pytest + import gen_surv.validate as v diff --git a/tests/test_version.py b/tests/test_version.py index 6fb4a02..aff52a9 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,4 +1,5 @@ from importlib.metadata import version + from gen_surv import __version__ diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 89d7c3e..ad7c8be 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -3,7 +3,14 @@ def test_plot_survival_curve_runs(): - df = generate(model="cphm", n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + df = generate( + model="cphm", + n=10, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + ) fig, ax = plot_survival_curve(df) assert fig is not None assert ax is not None From 1671a02eef7908c94f5b6d8cd8979b39407c1177 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 06:20:36 +0100 Subject: [PATCH 11/19] Add extensive summary tests (#45) --- tests/test_export.py | 20 +++++++++++ tests/test_piecewise.py | 16 +++++++++ tests/test_piecewise_functions.py | 51 +++++++++++++++++++++++++++ tests/test_summary_extra.py | 51 +++++++++++++++++++++++++++ tests/test_summary_more.py | 58 +++++++++++++++++++++++++++++++ tests/test_validate.py | 13 +++++++ 6 files changed, 209 insertions(+) create mode 100644 tests/test_piecewise_functions.py create mode 100644 tests/test_summary_extra.py create mode 100644 tests/test_summary_more.py diff --git a/tests/test_export.py b/tests/test_export.py index 49dc704..f89e94f 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -1,6 +1,7 @@ import os import pandas as pd +import pytest from gen_surv import export_dataset, generate @@ -35,3 +36,22 @@ def test_export_dataset_json(tmp_path): assert out_file.exists() loaded = pd.read_json(out_file, orient="table") pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) + + +def test_export_dataset_feather_and_invalid(tmp_path): + df = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) + feather_file = tmp_path / "data.feather" + export_dataset(df, str(feather_file)) + assert feather_file.exists() + loaded = pd.read_feather(feather_file) + pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) + + with pytest.raises(ValueError): + export_dataset(df, str(tmp_path / "data.txt"), fmt="txt") diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py index 9e44860..e587af2 100644 --- a/tests/test_piecewise.py +++ b/tests/test_piecewise.py @@ -18,3 +18,19 @@ def test_piecewise_invalid_lengths(): gen_piecewise_exponential( n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42 ) + +def test_piecewise_invalid_hazard_and_breakpoints(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[2.0, 1.0], + hazard_rates=[0.5, 1.0, 1.5], + seed=42, + ) + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.5, -1.0], + seed=42, + ) diff --git a/tests/test_piecewise_functions.py b/tests/test_piecewise_functions.py new file mode 100644 index 0000000..648e416 --- /dev/null +++ b/tests/test_piecewise_functions.py @@ -0,0 +1,51 @@ +import numpy as np +from gen_surv.piecewise import piecewise_hazard_function, piecewise_survival_function + + +def test_piecewise_hazard_function_scalar_and_array(): + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + # Scalar values + assert piecewise_hazard_function(0.5, breakpoints, hazard_rates) == 0.5 + assert piecewise_hazard_function(1.5, breakpoints, hazard_rates) == 1.0 + assert piecewise_hazard_function(3.0, breakpoints, hazard_rates) == 1.5 + # Array values + arr = np.array([0.5, 1.5, 3.0]) + np.testing.assert_allclose( + piecewise_hazard_function(arr, breakpoints, hazard_rates), + np.array([0.5, 1.0, 1.5]), + ) + + +def test_piecewise_hazard_function_negative_time(): + """Hazard should be zero for negative times.""" + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + assert piecewise_hazard_function(-1.0, breakpoints, hazard_rates) == 0 + np.testing.assert_array_equal( + piecewise_hazard_function(np.array([-0.5, -2.0]), breakpoints, hazard_rates), + np.array([0.0, 0.0]), + ) + + +def test_piecewise_survival_function(): + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + # Known survival probabilities + expected = np.exp(-np.array([0.0, 0.25, 1.0, 3.0])) + times = np.array([0.0, 0.5, 1.5, 3.0]) + np.testing.assert_allclose( + piecewise_survival_function(times, breakpoints, hazard_rates), + expected, + ) + + +def test_piecewise_survival_function_scalar_and_negative(): + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + # Scalar output should be a float + val = piecewise_survival_function(1.5, breakpoints, hazard_rates) + assert isinstance(val, float) + assert np.isclose(val, np.exp(-1.0)) + # Negative times return survival of 1 + assert piecewise_survival_function(-2.0, breakpoints, hazard_rates) == 1 diff --git a/tests/test_summary_extra.py b/tests/test_summary_extra.py new file mode 100644 index 0000000..a908a5b --- /dev/null +++ b/tests/test_summary_extra.py @@ -0,0 +1,51 @@ +import pandas as pd +import pytest +from gen_surv.summary import check_survival_data_quality, compare_survival_datasets +from gen_surv import generate + + +def test_check_survival_data_quality_fix_issues(): + df = pd.DataFrame( + { + "time": [1.0, -0.5, None, 1.0], + "status": [1, 2, 0, 1], + "id": [1, 2, 3, 1], + } + ) + fixed, issues = check_survival_data_quality( + df, + id_col="id", + max_time=2.0, + fix_issues=True, + ) + assert issues["modifications"]["rows_dropped"] == 2 + assert issues["modifications"]["values_fixed"] == 1 + assert len(fixed) == 2 + + +def test_check_survival_data_quality_no_fix(): + """Issues should be reported but data left unchanged when fix_issues=False.""" + df = pd.DataFrame({"time": [-1.0, 2.0], "status": [3, 1]}) + checked, issues = check_survival_data_quality(df, max_time=1.0, fix_issues=False) + # Data is returned unmodified + pd.testing.assert_frame_equal(df, checked) + assert issues["invalid_values"]["negative_time"] == 1 + assert issues["invalid_values"]["excessive_time"] == 1 + assert issues["invalid_values"]["invalid_status"] == 1 + + +def test_compare_survival_datasets_basic(): + ds1 = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + ds2 = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=1.0, covariate_range=1.0) + comparison = compare_survival_datasets({"A": ds1, "B": ds2}) + assert set(["A", "B"]).issubset(comparison.columns) + assert "n_subjects" in comparison.index + + +def test_compare_survival_datasets_with_covariates_and_empty_error(): + ds = generate(model="cphm", n=3, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + comparison = compare_survival_datasets({"only": ds}, covariate_cols=["X0"]) + assert "only" in comparison.columns + assert "X0_mean" in comparison.index + with pytest.raises(ValueError): + compare_survival_datasets({}) diff --git a/tests/test_summary_more.py b/tests/test_summary_more.py new file mode 100644 index 0000000..be730da --- /dev/null +++ b/tests/test_summary_more.py @@ -0,0 +1,58 @@ +import pandas as pd +import pytest +from gen_surv.summary import ( + summarize_survival_dataset, + check_survival_data_quality, + _print_summary, +) + + +def test_summarize_survival_dataset_errors(): + df = pd.DataFrame({"time": [1, 2], "status": [1, 0]}) + # Missing time column + with pytest.raises(ValueError): + summarize_survival_dataset(df.drop(columns=["time"])) + # Missing ID column when specified + with pytest.raises(ValueError): + summarize_survival_dataset(df, id_col="id") + # Missing covariate columns + with pytest.raises(ValueError): + summarize_survival_dataset(df, covariate_cols=["bad"]) + + +def test_summarize_survival_dataset_verbose_output(capsys): + df = pd.DataFrame( + { + "time": [1.0, 2.0, 3.0], + "status": [1, 0, 1], + "id": [1, 2, 3], + "age": [30, 40, 50], + "group": ["A", "B", "A"], + } + ) + summary = summarize_survival_dataset( + df, id_col="id", covariate_cols=["age", "group"] + ) + _print_summary(summary, "time", "status", "id", ["age", "group"]) + captured = capsys.readouterr().out + assert "SURVIVAL DATASET SUMMARY" in captured + assert "age:" in captured + assert "Categorical" in captured + + +def test_check_survival_data_quality_duplicates_and_fix(): + df = pd.DataFrame( + { + "time": [1.0, -1.0, 2.0, 1.0], + "status": [1, 1, 0, 1], + "id": [1, 1, 2, 1], + } + ) + checked, issues = check_survival_data_quality(df, id_col="id", fix_issues=False) + assert issues["duplicates"]["duplicate_rows"] == 1 + assert issues["duplicates"]["duplicate_ids"] == 2 + fixed, issues_fixed = check_survival_data_quality( + df, id_col="id", max_time=2.0, fix_issues=True + ) + assert len(fixed) < len(df) + assert issues_fixed["modifications"]["rows_dropped"] > 0 diff --git a/tests/test_validate.py b/tests/test_validate.py index 5252556..a9cd10e 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -71,3 +71,16 @@ def test_validate_gen_aft_log_normal_inputs_valid(): def test_validate_dg_biv_inputs_valid_weibull(): """Valid parameters for a Weibull distribution should pass.""" v.validate_dg_biv_inputs(5, "weibull", 0.1, [1.0, 1.0, 1.0, 1.0]) + + +def test_validate_gen_aft_weibull_inputs_and_log_logistic(): + with pytest.raises(ValueError): + v.validate_gen_aft_weibull_inputs(0, [0.1], 1.0, 1.0, "uniform", 1.0) + with pytest.raises(ValueError): + v.validate_gen_aft_log_logistic_inputs(1, [0.1], -1.0, 1.0, "uniform", 1.0) + + +def test_validate_competing_risks_inputs(): + with pytest.raises(ValueError): + v.validate_competing_risks_inputs(1, 2, [0.1], None, "uniform", 1.0) + v.validate_competing_risks_inputs(1, 1, [0.5], [[0.1]], "uniform", 0.5) From 903668656eb2f4fd6fc4cf76a8db57d7c5cd5fbc Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 08:59:19 +0100 Subject: [PATCH 12/19] Add visualization CLI tests (#47) --- pyproject.toml | 1 + tests/test_visualization.py | 106 +++++++++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b30ffd6..6622899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ pandas = "^2.2.3" typer = "^0.12.3" matplotlib = "^3.10" lifelines = "^0.30" +pyarrow = "^14" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" diff --git a/tests/test_visualization.py b/tests/test_visualization.py index ad7c8be..59b2d04 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,5 +1,15 @@ +import pandas as pd +import pytest +import typer + from gen_surv import generate -from gen_surv.visualization import plot_survival_curve +from gen_surv.cli import visualize +from gen_surv.visualization import ( + describe_survival, + plot_covariate_effect, + plot_hazard_comparison, + plot_survival_curve, +) def test_plot_survival_curve_runs(): @@ -14,3 +24,97 @@ def test_plot_survival_curve_runs(): fig, ax = plot_survival_curve(df) assert fig is not None assert ax is not None + + +def test_plot_hazard_comparison_runs(): + df1 = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) + df2 = generate( + model="aft_weibull", + n=5, + beta=[0.5], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=1.0, + ) + models = {"cphm": df1, "aft_weibull": df2} + fig, ax = plot_hazard_comparison(models) + assert fig is not None + assert ax is not None + + +def test_plot_covariate_effect_runs(): + df = generate( + model="cphm", + n=10, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + ) + fig, ax = plot_covariate_effect(df, covariate_col="X0", n_groups=2) + assert fig is not None + assert ax is not None + + +def test_describe_survival_summary(): + df = generate( + model="cphm", + n=10, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + ) + summary = describe_survival(df) + expected_metrics = [ + "Total Observations", + "Number of Events", + "Number Censored", + "Event Rate", + "Median Survival Time", + "Min Time", + "Max Time", + "Mean Time", + ] + assert list(summary["Metric"]) == expected_metrics + assert summary.shape[0] == len(expected_metrics) + + +def test_cli_visualize(tmp_path, capsys): + df = pd.DataFrame({"time": [1, 2, 3], "status": [1, 0, 1]}) + csv_path = tmp_path / "d.csv" + df.to_csv(csv_path, index=False) + out_file = tmp_path / "out.png" + visualize( + str(csv_path), + time_col="time", + status_col="status", + group_col=None, + output=str(out_file), + ) + assert out_file.exists() + captured = capsys.readouterr() + assert "Plot saved to" in captured.out + + +def test_cli_visualize_missing_column(tmp_path, capsys): + df = pd.DataFrame({"time": [1, 2], "event": [1, 0]}) + csv_path = tmp_path / "bad.csv" + df.to_csv(csv_path, index=False) + with pytest.raises(typer.Exit): + visualize( + str(csv_path), + time_col="time", + status_col="status", + group_col=None, + ) + captured = capsys.readouterr() + assert "Status column 'status' not found in data" in captured.out From 533f118e39a00a59388a369a7403578ad9635d51 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 10:05:59 +0100 Subject: [PATCH 13/19] Expand validate tests (#49) --- tests/test_cli.py | 93 ++++++++++++++++++++++++++ tests/test_validate.py | 128 ++++++++++++++++++++++++++++++++++++ tests/test_visualization.py | 52 +++++++++++++++ 3 files changed, 273 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4b819c8..bb4537f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -53,3 +53,96 @@ def fake_generate(model: str, n: int): assert out_file.exists() content = out_file.read_text() assert "time,status,X0,X1" in content + + +def test_dataset_fallback(monkeypatch): + """If generate fails with additional kwargs, dataset retries with minimal args.""" + calls = [] + + def fake_generate(**kwargs): + calls.append(kwargs) + if len(calls) == 1: + raise TypeError("bad args") + return pd.DataFrame({"time": [0], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset(model="cphm", n=2, output=None) + # first call has many parameters, second only model and n + assert calls[-1] == {"model": "cphm", "n": 2} + assert len(calls) == 2 + + +def test_dataset_weibull_parameters(monkeypatch): + """Parameters for aft_weibull model are forwarded correctly.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [0]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset(model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None) + assert captured["model"] == "aft_weibull" + assert captured["beta"] == [0.1, 0.2] + assert captured["shape"] == 1.1 + assert captured["scale"] == 2.2 + + +def test_dataset_aft_ln(monkeypatch): + """aft_ln model should forward beta list and sigma.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset(model="aft_ln", n=1, beta=[0.3, 0.4], sigma=1.2, output=None) + assert captured["beta"] == [0.3, 0.4] + assert captured["sigma"] == 1.2 + + +def test_dataset_competing_risks(monkeypatch): + """competing_risks expands betas and passes hazards.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="competing_risks", + n=1, + n_risks=2, + baseline_hazards=[0.1, 0.2], + beta=0.5, + output=None, + ) + assert captured["n_risks"] == 2 + assert captured["baseline_hazards"] == [0.1, 0.2] + assert captured["betas"] == [0.5, 0.5] + + +def test_dataset_mixture_cure(monkeypatch): + """mixture_cure passes cure and baseline parameters.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="mixture_cure", + n=1, + cure_fraction=0.2, + baseline_hazard=0.1, + beta=[0.4], + output=None, + ) + assert captured["cure_fraction"] == 0.2 + assert captured["baseline_hazard"] == 0.1 + assert captured["betas_survival"] == [0.4] + assert captured["betas_cure"] == [0.4] + diff --git a/tests/test_validate.py b/tests/test_validate.py index a9cd10e..8fa6e9b 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -42,6 +42,31 @@ def test_validate_gen_cmm_inputs_invalid_beta_length(): ) +@pytest.mark.parametrize( + "n, model_cens, cens_par, cov_range, rate", + [ + (0, "uniform", 0.5, 1.0, [0.1] * 6), + (1, "bad", 0.5, 1.0, [0.1] * 6), + (1, "uniform", 0.0, 1.0, [0.1] * 6), + (1, "uniform", 0.5, 0.0, [0.1] * 6), + (1, "uniform", 0.5, 1.0, [0.1] * 3), + ], +) +def test_validate_gen_cmm_inputs_other_invalid( + n, model_cens, cens_par, cov_range, rate +): + with pytest.raises(ValueError): + v.validate_gen_cmm_inputs( + n, model_cens, cens_par, [0.1, 0.2, 0.3], cov_range, rate + ) + + +def test_validate_gen_cmm_inputs_valid(): + v.validate_gen_cmm_inputs( + 1, "uniform", 1.0, [0.1, 0.2, 0.3], covariate_range=1.0, rate=[0.1] * 6 + ) + + def test_validate_gen_tdcm_inputs_invalid_lambda(): """Lambda <= 0 should raise a ValueError.""" with pytest.raises(ValueError): @@ -57,6 +82,44 @@ def test_validate_gen_tdcm_inputs_invalid_lambda(): ) +@pytest.mark.parametrize( + "dist,corr,dist_par", + [ + ("bad", 0.5, [1, 2]), + ("weibull", 0.0, [1, 2, 3, 4]), + ("weibull", 0.5, [1, 2, -1, 2]), + ("weibull", 0.5, [1, 2, 3]), + ("exponential", 2.0, [1, 1]), + ("exponential", 0.5, [1]), + ], +) +def test_validate_gen_tdcm_inputs_invalid_dist(dist, corr, dist_par): + with pytest.raises(ValueError): + v.validate_gen_tdcm_inputs( + 1, + dist, + corr, + dist_par, + "uniform", + 1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + ) + + +def test_validate_gen_tdcm_inputs_valid(): + v.validate_gen_tdcm_inputs( + 1, + "weibull", + 0.5, + [1, 1, 1, 1], + "uniform", + 1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + ) + + def test_validate_gen_aft_log_normal_inputs_valid(): """Valid parameters should not raise an error for AFT log-normal.""" v.validate_gen_aft_log_normal_inputs( @@ -68,11 +131,37 @@ def test_validate_gen_aft_log_normal_inputs_valid(): ) +@pytest.mark.parametrize( + "n,beta,sigma,model_cens,cens_par", + [ + (0, [0.1], 1.0, "uniform", 1.0), + (1, "bad", 1.0, "uniform", 1.0), + (1, [0.1], 0.0, "uniform", 1.0), + (1, [0.1], 1.0, "bad", 1.0), + (1, [0.1], 1.0, "uniform", 0.0), + ], +) +def test_validate_gen_aft_log_normal_inputs_invalid( + n, beta, sigma, model_cens, cens_par +): + with pytest.raises(ValueError): + v.validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par) + + def test_validate_dg_biv_inputs_valid_weibull(): """Valid parameters for a Weibull distribution should pass.""" v.validate_dg_biv_inputs(5, "weibull", 0.1, [1.0, 1.0, 1.0, 1.0]) +def test_validate_dg_biv_inputs_invalid_corr_and_params(): + with pytest.raises(ValueError): + v.validate_dg_biv_inputs(1, "exponential", -2.0, [1.0, 1.0]) + with pytest.raises(ValueError): + v.validate_dg_biv_inputs(1, "exponential", 0.5, [1.0]) + with pytest.raises(ValueError): + v.validate_dg_biv_inputs(1, "weibull", 0.5, [1.0, 1.0]) + + def test_validate_gen_aft_weibull_inputs_and_log_logistic(): with pytest.raises(ValueError): v.validate_gen_aft_weibull_inputs(0, [0.1], 1.0, 1.0, "uniform", 1.0) @@ -80,7 +169,46 @@ def test_validate_gen_aft_weibull_inputs_and_log_logistic(): v.validate_gen_aft_log_logistic_inputs(1, [0.1], -1.0, 1.0, "uniform", 1.0) +@pytest.mark.parametrize( + "shape,scale", + [(-1.0, 1.0), (1.0, -1.0)], +) +def test_validate_gen_aft_weibull_invalid_params(shape, scale): + with pytest.raises(ValueError): + v.validate_gen_aft_weibull_inputs(1, [0.1], shape, scale, "uniform", 1.0) + + +def test_validate_gen_aft_weibull_valid(): + v.validate_gen_aft_weibull_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0) + + +def test_validate_gen_aft_log_logistic_valid(): + v.validate_gen_aft_log_logistic_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0) + + def test_validate_competing_risks_inputs(): with pytest.raises(ValueError): v.validate_competing_risks_inputs(1, 2, [0.1], None, "uniform", 1.0) v.validate_competing_risks_inputs(1, 1, [0.5], [[0.1]], "uniform", 0.5) + + +@pytest.mark.parametrize( + "n,model_cens,cens_par,beta,cov_range,rate", + [ + (0, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]), + (1, "bad", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]), + (1, "uniform", 0.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]), + (1, "uniform", 1.0, [0.1, 0.2], 1.0, [0.1, 0.2, 0.3]), + (1, "uniform", 1.0, [0.1, 0.2, 0.3], 0.0, [0.1, 0.2, 0.3]), + (1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1]), + ], +) +def test_validate_gen_thmm_inputs_invalid( + n, model_cens, cens_par, beta, cov_range, rate +): + with pytest.raises(ValueError): + v.validate_gen_thmm_inputs(n, model_cens, cens_par, beta, cov_range, rate) + + +def test_validate_gen_thmm_inputs_valid(): + v.validate_gen_thmm_inputs(1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 59b2d04..f077e27 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -118,3 +118,55 @@ def test_cli_visualize_missing_column(tmp_path, capsys): ) captured = capsys.readouterr() assert "Status column 'status' not found in data" in captured.out + + +def test_cli_visualize_missing_time(tmp_path, capsys): + df = pd.DataFrame({"t": [1, 2], "status": [1, 0]}) + path = tmp_path / "d.csv" + df.to_csv(path, index=False) + with pytest.raises(typer.Exit): + visualize(str(path), time_col="time", status_col="status") + captured = capsys.readouterr() + assert "Time column 'time' not found in data" in captured.out + + +def test_cli_visualize_missing_group(tmp_path, capsys): + df = pd.DataFrame({"time": [1], "status": [1], "x": [0]}) + path = tmp_path / "d2.csv" + df.to_csv(path, index=False) + with pytest.raises(typer.Exit): + visualize(str(path), time_col="time", status_col="status", group_col="group") + captured = capsys.readouterr() + assert "Group column 'group' not found in data" in captured.out + + +def test_cli_visualize_import_error(monkeypatch, tmp_path, capsys): + """visualize exits when matplotlib is missing.""" + import builtins + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name.startswith("matplotlib"): # simulate missing dependency + raise ImportError("no matplot") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + csv_path = tmp_path / "d.csv" + pd.DataFrame({"time": [1], "status": [1]}).to_csv(csv_path, index=False) + with pytest.raises(typer.Exit): + visualize(str(csv_path)) + captured = capsys.readouterr() + assert "Visualization requires matplotlib" in captured.out + + +def test_cli_visualize_read_error(monkeypatch, tmp_path, capsys): + """visualize handles CSV read failures gracefully.""" + monkeypatch.setattr("pandas.read_csv", lambda *a, **k: (_ for _ in ()).throw(Exception("boom"))) + csv_path = tmp_path / "x.csv" + csv_path.write_text("time,status\n1,1\n") + with pytest.raises(typer.Exit): + visualize(str(csv_path)) + captured = capsys.readouterr() + assert "Error loading CSV file" in captured.out + From 9106058eb08145958fb9f3c1c2bef6870fe8192e Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 13:24:05 +0100 Subject: [PATCH 14/19] docs: rewrite README (#51) --- .github/workflows/docs.yml | 27 ++++ .readthedocs.yml | 14 +- README.md | 204 ++++++++------------------- docs/requirements.txt | 8 +- docs/source/_static/custom.css | 56 ++++++++ docs/source/algorithms.md | 44 ++++++ docs/source/api/index.md | 88 ++++++++++++ docs/source/bibliography.md | 22 +++ docs/source/changelog.md | 5 + docs/source/conf.py | 92 +++++++++--- docs/source/contributing.md | 5 + docs/source/examples/index.md | 8 ++ docs/source/getting_started.md | 77 ++++++++++ docs/source/index.md | 175 ++++++++++++++--------- docs/source/rtd.md | 16 +++ docs/source/theory.md | 39 +++-- docs/source/tutorials/basic_usage.md | 110 +++++++++++++++ docs/source/tutorials/index.md | 9 ++ docs/source/usage.md | 12 ++ pyproject.toml | 7 +- tests/test_mixture.py | 67 +++++++++ 21 files changed, 831 insertions(+), 254 deletions(-) create mode 100644 .github/workflows/docs.yml create mode 100644 docs/source/_static/custom.css create mode 100644 docs/source/algorithms.md create mode 100644 docs/source/api/index.md create mode 100644 docs/source/bibliography.md create mode 100644 docs/source/changelog.md create mode 100644 docs/source/contributing.md create mode 100644 docs/source/examples/index.md create mode 100644 docs/source/getting_started.md create mode 100644 docs/source/rtd.md create mode 100644 docs/source/tutorials/basic_usage.md create mode 100644 docs/source/tutorials/index.md diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..a7a2501 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,27 @@ +name: Documentation Build + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install Poetry + run: pip install poetry + - name: Install dependencies + run: poetry install --with docs + - name: Build documentation + run: poetry run sphinx-build -W -b html docs/source docs/build + - name: Upload artifacts + uses: actions/upload-artifact@v3 + with: + name: documentation + path: docs/build/ diff --git a/.readthedocs.yml b/.readthedocs.yml index a811ac3..a1f8a6e 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,10 +4,22 @@ build: os: ubuntu-22.04 tools: python: "3.11" + jobs: + post_create_environment: + - pip install poetry + post_install: + - poetry config virtualenvs.create false + - poetry install --with docs python: install: - - requirements: docs/requirements.txt + - method: pip + path: . sphinx: configuration: docs/source/conf.py + fail_on_warning: false + +formats: + - pdf + - epub diff --git a/README.md b/README.md index 3a5cb69..4ebbdb0 100644 --- a/README.md +++ b/README.md @@ -1,180 +1,97 @@ # gen_surv -![Coverage](https://codecov.io/gh/DiogoRibeiro7/genSurvPy/branch/main/graph/badge.svg) -[![Docs](https://readthedocs.org/projects/gensurvpy/badge/?version=stable)](https://gensurvpy.readthedocs.io/en/stable/) -![PyPI](https://img.shields.io/pypi/v/gen_surv) -![Tests](https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/test.yml/badge.svg) -![Python](https://img.shields.io/pypi/pyversions/gen_surv) +[![Coverage](https://codecov.io/gh/DiogoRibeiro7/genSurvPy/branch/main/graph/badge.svg)](https://app.codecov.io/gh/DiogoRibeiro7/genSurvPy) +[![Docs](https://readthedocs.org/projects/gensurvpy/badge/?version=latest)](https://gensurvpy.readthedocs.io/en/latest/) +[![PyPI](https://img.shields.io/pypi/v/gen_surv)](https://pypi.org/project/gen-surv/) +[![Tests](https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/test.yml/badge.svg)](https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/test.yml) +[![Python](https://img.shields.io/pypi/pyversions/gen_surv)](https://pypi.org/project/gen-surv/) - -**gen_surv** is a Python package for simulating survival data under a variety of models, inspired by the R package [`genSurv`](https://cran.r-project.org/package=genSurv). It supports data generation for: - -- Cox Proportional Hazards Models (CPHM) -- Continuous-Time Markov Models (CMM) -- Time-Dependent Covariate Models (TDCM) -- Time-Homogeneous Hidden Markov Models (THMM) +**gen_surv** is a Python package for simulating survival data under a variety of statistical models. It is inspired by the R package [genSurv](https://cran.r-project.org/package=genSurv) and provides a unified interface for generating realistic survival datasets. --- -## 📦 Installation +## Features -```bash -poetry install -``` -This package requires **Python 3.10** or later. -## ✨ Features - -- Consistent interface across models -- Censoring support (`uniform` or `exponential`) -- Easy integration with `pandas` and `NumPy` -- Suitable for benchmarking survival algorithms and teaching -- Accelerated Failure Time (Log-Normal) model generator +- Cox proportional hazards model (CPHM) +- Accelerated failure time models (log-normal, log-logistic) +- Continuous-time multi-state Markov model (CMM) +- Time-dependent covariate model (TDCM) +- Time-homogeneous hidden Markov model (THMM) - Mixture cure and piecewise exponential models - Competing risks generators (constant and Weibull hazards) -- Command-line interface powered by `Typer` -- Export utilities for CSV, JSON, and Feather formats +- Command-line interface and export utilities -## 🧪 Example +## Installation -```python -from gen_surv import generate +Install the latest release from PyPI: -# CPHM -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) - -# AFT Log-Normal -generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) - -# CMM -generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, - qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) - -# TDCM -generate(model="tdcm", n=100, dist="weibull", corr=0.5, - dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, - beta=[0.1, 0.2, 0.3], lam=1.0) - -# THMM -generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], - emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, - p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) - -# Mixture Cure -generate(model="mixture_cure", n=100, cure_fraction=0.3, seed=42) - -# Piecewise Exponential -generate( - model="piecewise_exponential", - n=100, - breakpoints=[1.0, 3.0], - hazard_rates=[0.2, 0.5, 1.0], - seed=0 -) +```bash +pip install gen-surv ``` -## ⌨️ Command-Line Usage - -Install the package and use ``python -m gen_surv`` to generate datasets without -writing Python code: +To develop locally with all extras: ```bash -python -m gen_surv dataset aft_ln --n 100 > data.csv +git clone https://github.com/DiogoRibeiro7/genSurvPy.git +cd genSurvPy +poetry install ``` -## 🔧 API Overview - -| Function | Description | -|----------|-------------| -| `generate()` | Unified interface that calls any generator | -| `gen_cphm()` | Cox Proportional Hazards Model | -| `gen_cmm()` | Continuous-Time Multi-State Markov Model | -| `gen_tdcm()` | Time-Dependent Covariate Model | -| `gen_thmm()` | Time-Homogeneous Markov Model | -| `gen_aft_log_normal()` | Accelerated Failure Time Log-Normal | -| `gen_aft_log_logistic()` | AFT model with log-logistic baseline | -| `gen_competing_risks()` | Competing risks with constant hazards | -| `gen_competing_risks_weibull()` | Competing risks with Weibull hazards | -| `gen_mixture_cure()` | Mixture cure model | -| `cure_fraction_estimate()` | Estimate cure fraction | -| `gen_piecewise_exponential()` | Piecewise exponential model | -| `sample_bivariate_distribution()` | Sample correlated Weibull or exponential times | -| `runifcens()` | Generate uniform censoring times | -| `rexpocens()` | Generate exponential censoring times | -| `export_dataset()` | Save a dataset to CSV, JSON or Feather | - - -```text -genSurvPy/ -├── gen_surv/ # Pacote principal -│ ├── __main__.py # Interface CLI via python -m -│ ├── cphm.py -│ ├── cmm.py -│ ├── tdcm.py -│ ├── thmm.py -│ ├── censoring.py -│ ├── bivariate.py -│ ├── validate.py -│ └── interface.py -├── tests/ # Testes automatizados -│ ├── test_cphm.py -│ ├── test_cmm.py -│ ├── test_tdcm.py -│ ├── test_thmm.py -├── examples/ # Exemplos de uso -│ ├── run_aft.py -│ ├── run_cmm.py -│ ├── run_cphm.py -│ ├── run_tdcm.py -│ └── run_thmm.py -├── docs/ # Documentação Sphinx -│ ├── source/ -│ └── ... -├── scripts/ # Utilidades diversas -│ └── check_version_match.py -├── tasks.py # Tarefas automatizadas com Invoke -├── TODO.md # Roadmap de desenvolvimento -├── pyproject.toml # Configurado com Poetry -├── README.md -├── LICENCE -└── .gitignore -``` +## Quick Example -## 🧠 License +```python +from gen_surv import generate -MIT License. See [LICENCE](LICENCE) for details. +# basic Cox proportional hazards data +sim = generate(model="cphm", n=100, beta=0.5, covar=2.0, + model_cens="uniform", cens_par=1.0) +``` +See the [usage guide](docs/source/getting_started.md) for more examples. -## 🔖 Release Process +## Supported Models -This project uses Git tags to manage releases. A GitHub Actions workflow -(`version-check.yml`) verifies that the version declared in `pyproject.toml` -matches the latest Git tag. If they diverge, the workflow fails and prompts a -correction before merging. Run `python scripts/check_version_match.py` locally -before creating a tag to catch issues early. +| Model | Description | +|----------------------|-----------------------------------------| +| **CPHM** | Cox proportional hazards | +| **AFT** | Accelerated failure time (log-normal, log-logistic) | +| **CMM** | Continuous-time multi-state Markov | +| **TDCM** | Time-dependent covariates | +| **THMM** | Time-homogeneous hidden Markov | +| **Competing Risks** | Multiple event types with cause-specific hazards | +| **Mixture Cure** | Models long-term survivors | +| **Piecewise Exponential** | Flexible baseline hazard via intervals | -## 🌟 Code of Conduct +More details on each algorithm are available in the [Algorithms](docs/source/algorithms.md) page and the [theory guide](docs/source/theory.md). -Please read our [Code of Conduct](CODE_OF_CONDUCT.md) to learn about the -expectations for participants in this project. +## Command-Line Usage -## 🤝 Contributing +Datasets can be generated without writing Python code: + +```bash +python -m gen_surv dataset cphm --n 1000 -o survival.csv +``` -Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on setting up your environment, running tests, and submitting pull requests. +## Documentation -## 🔧 Development Tasks +Full documentation is hosted on [Read the Docs](https://gensurvpy.readthedocs.io/en/latest/). It includes installation instructions, tutorials, API references and a bibliography. -Common project commands are defined in [`tasks.py`](tasks.py) and can be executed with [Invoke](https://www.pyinvoke.org/): +To build the docs locally: ```bash -poetry run inv -l # list available tasks -poetry run inv test # run the test suite +cd docs +make html ``` -## 📑 Citation +Open `build/html/index.html` in your browser to view the result. + +## License + +This project is licensed under the MIT License. See [LICENCE](LICENCE) for details. -If you use **gen_surv** in your work, please cite it using the metadata in -[`CITATION.cff`](CITATION.cff). Many reference managers can import this file -directly. +## Citation + +If you use **gen_surv** in your research, please cite the project using the metadata in [CITATION.cff](CITATION.cff). ## Author @@ -183,3 +100,4 @@ directly. - ORCID: - Professional email: - Personal email: + diff --git a/docs/requirements.txt b/docs/requirements.txt index 98a3c62..5cf2520 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,6 @@ -sphinx -myst-parser +sphinx>=6.0 +myst-parser>=1.0.0,<4.0.0 +sphinx-rtd-theme +sphinx-autodoc-typehints +sphinx-copybutton +sphinx-design diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css new file mode 100644 index 0000000..22f503a --- /dev/null +++ b/docs/source/_static/custom.css @@ -0,0 +1,56 @@ +/* Custom styling for gen_surv documentation */ + +/* Improve code block appearance */ +.highlight { + background: #f8f9fa; + border: 1px solid #e9ecef; + border-radius: 4px; + padding: 10px; + margin: 10px 0; +} + +/* Style admonitions */ +.admonition { + border-radius: 6px; + margin: 1em 0; + padding: 1em; +} + +.admonition.tip { + border-left: 4px solid #28a745; + background-color: #d4edda; +} + +.admonition.note { + border-left: 4px solid #007bff; + background-color: #cce5ff; +} + +.admonition.warning { + border-left: 4px solid #ffc107; + background-color: #fff3cd; +} + +/* Table styling */ +table.docutils { + border-collapse: collapse; + margin: 1em 0; + width: 100%; +} + +table.docutils th, +table.docutils td { + border: 1px solid #ddd; + padding: 8px; + text-align: left; +} + +table.docutils th { + background-color: #f8f9fa; + font-weight: bold; +} + +/* Math styling */ +.math { + font-size: 1.1em; +} diff --git a/docs/source/algorithms.md b/docs/source/algorithms.md new file mode 100644 index 0000000..30ed1fb --- /dev/null +++ b/docs/source/algorithms.md @@ -0,0 +1,44 @@ +# Algorithm Overview + +This page provides a short description of each model implemented in **gen_surv**. For mathematical details see {doc}`theory`. + +## Cox Proportional Hazards Model (CPHM) +The hazard at time $t$ is proportional to a baseline hazard multiplied by the exponential of covariate effects. +It is widely used for modelling relative risks under the proportional hazards assumption. +See {ref}`Cox1972` in the {doc}`bibliography` for the seminal paper. + +## Accelerated Failure Time Models (AFT) +These parametric models directly relate covariates to survival time. +gen_surv includes log-normal, log-logistic and Weibull variants allowing different baseline distributions. +They are convenient when the effect of covariates accelerates or decelerates event times. + +## Continuous-Time Multi-State Markov Model (CMM) +Transitions between states are governed by a generator matrix. +This model is suited for illness-death and other multi-state processes where state occupancy changes continuously over time. +The mathematical formulation follows the counting-process approach of Andersen et al. {ref}`Andersen1993`. + +## Time-Dependent Covariate Model (TDCM) +Extends the Cox model to covariates that vary during follow-up. +Covariates are simulated in a piecewise fashion with optional correlation across segments. + +## Time-Homogeneous Hidden Markov Model (THMM) +Handles processes with unobserved states that emit observable values. +The latent transitions follow a homogeneous Markov chain while emissions are Gaussian. +For background on these models see Zucchini et al. {ref}`Zucchini2017`. + +## Competing Risks +Allows multiple failure types with cause-specific hazards. +gen_surv supports constant and Weibull hazards for each cause. +The subdistribution approach of Fine and Gray {ref}`FineGray1999` is commonly used for analysis. + +## Mixture Cure Model +Assumes a proportion of individuals will never experience the event. +A logistic component determines who is cured, while uncured subjects follow an exponential failure distribution. +Mixture cure models were introduced by Farewell {ref}`Farewell1982`. + +## Piecewise Exponential Model +Approximates complex hazard shapes by dividing follow-up time into intervals with constant hazard within each interval. +This yields a flexible baseline hazard while remaining computationally simple. + +For additional reading on these methods please see the {doc}`bibliography`. + diff --git a/docs/source/api/index.md b/docs/source/api/index.md new file mode 100644 index 0000000..0a26c5a --- /dev/null +++ b/docs/source/api/index.md @@ -0,0 +1,88 @@ +# API Reference + +Complete documentation for all gen_surv functions and classes. + +## Core Interface + +```{eval-rst} +.. automodule:: gen_surv.interface + :members: + :undoc-members: + :show-inheritance: +``` + +## Model Generators + +### Cox Proportional Hazards Model +```{eval-rst} +.. automodule:: gen_surv.cphm + :members: + :undoc-members: + :show-inheritance: +``` + +### Accelerated Failure Time Models +```{eval-rst} +.. automodule:: gen_surv.aft + :members: + :undoc-members: + :show-inheritance: +``` + +### Continuous-Time Markov Models +```{eval-rst} +.. automodule:: gen_surv.cmm + :members: + :undoc-members: + :show-inheritance: +``` + +### Time-Dependent Covariate Models +```{eval-rst} +.. automodule:: gen_surv.tdcm + :members: + :undoc-members: + :show-inheritance: +``` + +### Time-Homogeneous Markov Models +```{eval-rst} +.. automodule:: gen_surv.thmm + :members: + :undoc-members: + :show-inheritance: +``` + +## Utility Functions + +### Censoring Functions +```{eval-rst} +.. automodule:: gen_surv.censoring + :members: + :undoc-members: + :show-inheritance: +``` + +### Bivariate Distributions +```{eval-rst} +.. automodule:: gen_surv.bivariate + :members: + :undoc-members: + :show-inheritance: +``` + +### Validation Functions +```{eval-rst} +.. automodule:: gen_surv.validate + :members: + :undoc-members: + :show-inheritance: +``` + +### Command Line Interface +```{eval-rst} +.. automodule:: gen_surv.cli + :members: + :undoc-members: + :show-inheritance: +``` diff --git a/docs/source/bibliography.md b/docs/source/bibliography.md new file mode 100644 index 0000000..c3ae2c9 --- /dev/null +++ b/docs/source/bibliography.md @@ -0,0 +1,22 @@ +# References + +Below is a selection of references covering the statistical models implemented in **gen_surv**. + +.. _Cox1972: +Cox, D. R. (1972). Regression Models and Life-Tables. *Journal of the Royal Statistical Society: Series B*, 34(2), 187-220. + +.. _Farewell1982: +Farewell, V.T. (1982). The Use of Mixture Models for the Analysis of Survival Data with Long-Term Survivors. *Biometrics*, 38(4), 1041-1046. + +.. _FineGray1999: +Fine, J.P., & Gray, R.J. (1999). A Proportional Hazards Model for the Subdistribution of a Competing Risk. *Journal of the American Statistical Association*, 94(446), 496-509. + +.. _Andersen1993: +Andersen, P.K., Borgan, Ø., Gill, R.D., & Keiding, N. (1993). *Statistical Models Based on Counting Processes*. Springer. + +.. _Zucchini2017: +Zucchini, W., MacDonald, I.L., & Langrock, R. (2017). *Hidden Markov Models for Time Series*. Chapman and Hall/CRC. + +- Klein, J.P., & Moeschberger, M.L. (2003). *Survival Analysis: Techniques for Censored and Truncated Data*. Springer. +- Kalbfleisch, J.D., & Prentice, R.L. (2002). *The Statistical Analysis of Failure Time Data*. Wiley. +- Cook, R.J., & Lawless, J.F. (2007). *The Statistical Analysis of Recurrent Events*. Springer. diff --git a/docs/source/changelog.md b/docs/source/changelog.md new file mode 100644 index 0000000..8019063 --- /dev/null +++ b/docs/source/changelog.md @@ -0,0 +1,5 @@ +# Changelog + +```{include} ../../CHANGELOG.md +:relative: true +``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 39fa69c..22a56b1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,42 +1,92 @@ -# Configuration file for the Sphinx documentation builder. -# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Project information ----------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os import sys -sys.path.insert(0, os.path.abspath('../../gen_surv')) +from pathlib import Path + +# Add the package to the Python path +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root / "gen_surv")) +# Project information project = 'gen_surv' copyright = '2025, Diogo Ribeiro' author = 'Diogo Ribeiro' release = '1.0.8' +version = '1.0.8' -# -- General configuration --------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration - +# General configuration extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", - "myst_parser", "sphinx.ext.viewcode", - "sphinx.ext.autosectionlabel", + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.githubpages", + "myst_parser", + "sphinx_copybutton", + "sphinx_design", +] + +# MyST Parser configuration +myst_enable_extensions = [ + "colon_fence", + "deflist", + "html_admonition", + "html_image", + "linkify", + "replacements", + "smartquotes", + "substitution", + "tasklist", ] -autosectionlabel_prefix_document = True +# Autodoc configuration +autodoc_default_options = { + 'members': True, + 'member-order': 'bysource', + 'special-members': '__init__', + 'undoc-members': True, + 'exclude-members': '__weakref__' +} + +# Autosummary +autosummary_generate = True -# Point to index.md or index.rst as the root document -master_doc = "index" +# Napoleon settings +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_include_private_with_doc = False -templates_path = ['_templates'] -exclude_patterns = [] +# Intersphinx mapping +intersphinx_mapping = { + 'python': ('https://docs.python.org/3/', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'pandas': ('https://pandas.pydata.org/docs/', None), +} -# -- Options for HTML output ------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output +# HTML theme options +html_theme = 'sphinx_rtd_theme' +html_theme_options = { + 'canonical_url': 'https://gensurvpy.readthedocs.io/', + 'analytics_id': '', + 'logo_only': False, + 'display_version': True, + 'prev_next_buttons_location': 'bottom', + 'style_external_links': False, + 'style_nav_header_background': '#2980B9', + 'collapse_navigation': True, + 'sticky_navigation': True, + 'navigation_depth': 4, + 'includehidden': True, + 'titles_only': False +} -html_theme = 'alabaster' html_static_path = ['_static'] +html_css_files = ['custom.css'] +# Output file base name for HTML help builder +htmlhelp_basename = 'gensurvdoc' +# Copy button configuration +copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " +copybutton_prompt_is_regexp = True diff --git a/docs/source/contributing.md b/docs/source/contributing.md new file mode 100644 index 0000000..37fab23 --- /dev/null +++ b/docs/source/contributing.md @@ -0,0 +1,5 @@ +# Contributing + +```{include} ../../CONTRIBUTING.md +:relative: true +``` diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md new file mode 100644 index 0000000..0f58536 --- /dev/null +++ b/docs/source/examples/index.md @@ -0,0 +1,8 @@ +# Examples + +Real-world examples and use cases for gen_surv. + +```{toctree} +:maxdepth: 2 + +``` diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md new file mode 100644 index 0000000..b94f4b4 --- /dev/null +++ b/docs/source/getting_started.md @@ -0,0 +1,77 @@ +# Getting Started + +This guide will help you install gen_surv and generate your first survival dataset. + +## Installation + +### From PyPI (Recommended) + +```bash +pip install gen-surv +``` + +### From Source + +```bash +git clone https://github.com/DiogoRibeiro7/genSurvPy.git +cd genSurvPy +poetry install +``` + +## Basic Usage + +The main entry point is the `generate()` function: + +```python +from gen_surv import generate + +# Generate Cox proportional hazards data + df = generate( + model="cphm", # Model type + n=100, # Sample size + beta=0.5, # Covariate effect + covar=2.0, # Covariate range + model_cens="uniform", # Censoring type + cens_par=3.0 # Censoring parameter + ) + +print(df.head()) +``` + +## Understanding the Output + +All models return a pandas DataFrame with at least these columns: + +- `time`: Observed event or censoring time +- `status`: Event indicator (1 = event, 0 = censored) +- Additional columns depend on the specific model + +## Command Line Usage + +Generate datasets directly from the terminal: + +```bash +# Generate CPHM data and save to CSV +python -m gen_surv dataset cphm --n 1000 -o survival_data.csv + +# Print AFT data to stdout +python -m gen_surv dataset aft_ln --n 500 +``` + +## Next Steps + +- Explore the {doc}`tutorials/index` for detailed examples +- Check the {doc}`api/index` for complete function documentation +- Read about the {doc}`theory` behind each model + +## Building the Documentation + +To preview the documentation locally run: + +```bash +cd docs +make html +``` + +More details about our Read the Docs configuration can be found in {doc}`rtd`. + diff --git a/docs/source/index.md b/docs/source/index.md index 5756e5b..9d52cb7 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,100 +1,135 @@ -# gen_surv +# gen_surv: Survival Data Simulation in Python -**gen_surv** is a Python package for simulating survival data under various models, inspired by the R package `genSurv`. +[![Documentation Status](https://readthedocs.org/projects/gensurvpy/badge/?version=latest)](https://gensurvpy.readthedocs.io/en/latest/?badge=latest) +[![PyPI version](https://badge.fury.io/py/gen-surv.svg)](https://badge.fury.io/py/gen-surv) +[![Python versions](https://img.shields.io/pypi/pyversions/gen-surv.svg)](https://pypi.org/project/gen-surv/) -It includes generators for: +**gen_surv** is a comprehensive Python package for simulating survival data under various statistical models, inspired by the R package `genSurv`. It provides a unified interface for generating synthetic survival datasets that are essential for: -- **Cox Proportional Hazards Models (CPHM)** -- **Continuous-Time Markov Models (CMM)** -- **Time-Dependent Covariate Models (TDCM)** -- **Time-Homogeneous Hidden Markov Models (THMM)** -- **Accelerated Failure Time (AFT) Log-Normal Models** +- **Research**: Testing new survival analysis methods +- **Education**: Teaching survival analysis concepts +- **Benchmarking**: Comparing different survival models +- **Validation**: Testing statistical software implementations -Key functions include `generate()`, `gen_cphm()`, `gen_cmm()`, `gen_tdcm()`, -`gen_thmm()`, `gen_aft_log_normal()`, `sample_bivariate_distribution()`, -`runifcens()`, and `rexpocens()`. +```{admonition} Quick Start +:class: tip ---- - -See the [Getting Started](usage) guide for installation instructions. - -## 📚 Modules - -```{toctree} -:maxdepth: 2 -:caption: Contents - -usage -modules -theory +Install with pip: +```bash +pip install gen-surv ``` - -# 🚀 Usage Example - +Generate your first dataset: ```python from gen_surv import generate +df = generate(model="cphm", n=100, beta=0.5, covar=2.0) +``` +``` -# CPHM -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) +## Supported Models -# AFT Log-Normal -generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) +| Model | Description | Use Case | +|-------|-------------|----------| +| **CPHM** | Cox Proportional Hazards | Standard survival regression | +| **AFT** | Accelerated Failure Time | Non-proportional hazards | +| **CMM** | Continuous-Time Markov | Multi-state processes | +| **TDCM** | Time-Dependent Covariates | Dynamic risk factors | +| **THMM** | Time-Homogeneous Markov | Hidden state processes | +| **Competing Risks** | Multiple event types | Cause-specific hazards | +| **Mixture Cure** | Long-term survivors | Logistic cure fraction | +| **Piecewise Exponential** | Piecewise constant hazard | Flexible baseline | -# CMM -generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, - qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) +## Algorithm Descriptions -# TDCM -generate(model="tdcm", n=100, dist="weibull", corr=0.5, - dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, - beta=[0.1, 0.2, 0.3], lam=1.0) +For a brief summary of each statistical model see {doc}`algorithms`. Mathematical +details and notation are provided on the {doc}`theory` page. -# THMM -generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], - emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, - p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) -``` +## Documentation Contents -## ⌨️ Command-Line Usage +```{toctree} +:maxdepth: 2 -Generate datasets directly from the terminal: +getting_started +tutorials/index +api/index +theory +algorithms +examples/index +rtd +contributing +changelog +bibliography +``` -```bash -python -m gen_surv dataset aft_ln --n 100 > data.csv +## Quick Examples + +### Cox Proportional Hazards Model +```python +import gen_surv as gs + +# Basic CPHM with uniform censoring +df = gs.generate( + model="cphm", + n=500, + beta=0.5, + covar=2.0, + model_cens="uniform", + cens_par=3.0 +) ``` -## Repository Layout - -```text -genSurvPy/ -├── gen_surv/ -│ └── ... -├── tests/ -├── examples/ -├── docs/ -├── scripts/ -├── tasks.py -└── TODO.md +### Accelerated Failure Time Model +```python +# AFT with log-normal distribution +df = gs.generate( + model="aft_ln", + n=200, + beta=[0.5, -0.3, 0.2], + sigma=1.0, + model_cens="exponential", + cens_par=2.0 +) ``` -## 🔗 Project Links +### Multi-State Markov Model +```python +# Three-state illness-death model +df = gs.generate( + model="cmm", + n=300, + qmat=[[0, 0.1], [0.05, 0]], + p0=[1.0, 0.0], + model_cens="uniform", + cens_par=5.0 +) +``` -- [Source Code](https://github.com/DiogoRibeiro7/genSurvPy) -- [License](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENCE) -- [Code of Conduct](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/CODE_OF_CONDUCT.md) +## Key Features +- **Unified Interface**: Single `generate()` function for all models +- **Flexible Censoring**: Support for uniform and exponential censoring +- **Rich Parameterization**: Extensive customization options +- **Command-Line Interface**: Generate datasets from terminal +- **Comprehensive Validation**: Input parameter checking +- **Educational Focus**: Clear mathematical documentation ## Citation -If you use **gen_surv** in your work, please cite it using the metadata in -[CITATION.cff](../../CITATION.cff). +If you use gen_surv in your research, please cite: -## Author +```bibtex +@software{ribeiro2025gensurvpy, + title = {gen_surv: Survival Data Simulation in Python}, + author = {Diogo Ribeiro}, + year = {2025}, + url = {https://github.com/DiogoRibeiro7/genSurvPy}, + version = {1.0.8} +} +``` -**Diogo Ribeiro** — [ESMAD - Instituto Politécnico do Porto](https://esmad.ipp.pt) +## License -- ORCID: -- Professional email: -- Personal email: +MIT License - see [LICENSE](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENCE) for details. +For foundational papers related to these models see the {doc}`bibliography`. +Information on building the docs is provided in the {doc}`rtd` page. diff --git a/docs/source/rtd.md b/docs/source/rtd.md new file mode 100644 index 0000000..dc77463 --- /dev/null +++ b/docs/source/rtd.md @@ -0,0 +1,16 @@ +# Read the Docs + +This project uses [Read the Docs](https://readthedocs.org/) to host its documentation. The site is automatically rebuilt whenever changes are pushed to the repository. + +Our build configuration is defined in `.readthedocs.yml`. It installs the package with the `docs` dependency group and builds the Sphinx docs using Python 3.11. + +## Building Locally + +To preview the documentation on your machine, run: + +```bash +cd docs +make html +``` + +Open `build/html/index.html` in your browser to view the result. diff --git a/docs/source/theory.md b/docs/source/theory.md index 70a402f..5d58bb1 100644 --- a/docs/source/theory.md +++ b/docs/source/theory.md @@ -6,7 +6,9 @@ This page presents the mathematical formulation behind the survival models imple ## 1. Cox Proportional Hazards Model (CPHM) -The hazard function conditioned on covariates is: +This semi-parametric approach models the hazard as a baseline component +multiplied by an exponential term involving the covariates. The hazard +function conditioned on covariates is: $$ h(t \mid X) = h_0(t) \exp(X \\beta) @@ -39,7 +41,8 @@ $$ ## 2. Time-Dependent Covariate Model (TDCM) -A generalization of CPHM where covariates change over time: +This extension of the Cox model allows covariate values to vary during +follow-up, accommodating exposures or treatments that change over time: $$ h(t \mid Z(t)) = h_0(t) \\exp(Z(t) \\beta) @@ -51,7 +54,8 @@ In this package, piecewise covariate values are simulated with dependence across ## 3. Continuous-Time Multi-State Markov Model (CMM) -Markov model with generator matrix \( Q \). The transition probability matrix is given by: +This framework captures transitions between a finite set of states where +waiting times are exponentially distributed. With generator matrix \( Q \), the transition probability matrix is given by: $$ P(t) = \\exp(Qt) @@ -66,7 +70,9 @@ Where: ## 4. Time-Homogeneous Hidden Markov Model (THMM) -This model simulates observed states with unobserved latent state transitions. +This model handles situations where the process evolves through unobserved +states that generate the observed responses. It simulates observed states with +latent transitions. Let: @@ -84,7 +90,9 @@ $$ ## 5. Accelerated Failure Time (AFT) Models -AFT models assume that the effect of covariates accelerates or decelerates time to event directly, rather than the hazard. +These fully parametric models relate covariates to the logarithm of the +survival time. They assume the effect of a covariate speeds up or slows down the +event time directly, rather than acting on the hazard. ### Log-Normal AFT @@ -121,19 +129,20 @@ All models support censoring: ## 6. Competing Risks Models -These models simulate multiple mutually exclusive event types. Each cause has its -own hazard function, and the observed status indicates which event occurred -(1, 2, ...). The package includes constant-hazard and Weibull-hazard versions. +These models handle scenarios where several distinct failure types can occur. +Each cause has its own hazard function, and the observed status indicates which +event occurred (1, 2, ...). The package includes constant-hazard and +Weibull-hazard versions. ## 7. Mixture Cure Models -Mixture cure models assume a proportion of subjects are immune to the event. The -generator mixes a logistic cure component with an exponential hazard for the -uncured, returning a ``cured`` indicator column alongside the usual time and -status. +These models posit that a subset of the population is cured and will never +experience the event of interest. The generator mixes a logistic cure component +with an exponential hazard for the uncured, returning a ``cured`` indicator +column alongside the usual time and status. ## 8. Piecewise Exponential Model -This model divides the time axis into intervals defined by user-supplied -breakpoints. Each interval has its own hazard rate, allowing flexible hazard -shapes over time. +Here the baseline hazard is assumed constant within each of several +user-specified intervals. This allows flexible hazard shapes over time while +remaining easy to simulate. diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md new file mode 100644 index 0000000..8b17acd --- /dev/null +++ b/docs/source/tutorials/basic_usage.md @@ -0,0 +1,110 @@ +# Basic Usage Tutorial + +This tutorial covers the fundamentals of generating survival data with gen_surv. + +## Your First Dataset + +Let's start with the simplest case - generating data from a Cox proportional hazards model: + +```python +from gen_surv import generate +import pandas as pd + +# Generate basic CPHM data + df = generate( + model="cphm", + n=200, + beta=0.7, + covar=1.5, + model_cens="exponential", + cens_par=2.0, + seed=42 # For reproducibility + ) + +print(f"Dataset shape: {df.shape}") +print(f"Event rate: {df['status'].mean():.2%}") +print("\nFirst 5 rows:") +print(df.head()) +``` + +## Understanding Parameters + +### Common Parameters + +All models share these parameters: + +- `n`: Sample size (number of individuals) +- `model_cens`: Censoring type ("uniform" or "exponential") +- `cens_par`: Censoring distribution parameter +- `seed`: Random seed for reproducibility + +### Model-Specific Parameters + +Each model has unique parameters. For CPHM: + +- `beta`: Covariate effect (hazard ratio = exp(beta)) +- `covar`: Range for uniform covariate generation [0, covar] + +## Censoring Mechanisms + +gen_surv supports two censoring types: + +### Uniform Censoring +```python +# Censoring times uniformly distributed on [0, cens_par] +df_uniform = generate( + model="cphm", + n=100, + beta=0.5, + covar=2.0, + model_cens="uniform", + cens_par=3.0 +) +``` + +### Exponential Censoring +```python +# Censoring times exponentially distributed with mean cens_par +df_exponential = generate( + model="cphm", + n=100, + beta=0.5, + covar=2.0, + model_cens="exponential", + cens_par=2.0 +) +``` + +## Exploring Your Data + +Basic data exploration: + +```python +import matplotlib.pyplot as plt +import seaborn as sns + +# Event rate by covariate level +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + +# Histogram of survival times +ax1.hist(df['time'], bins=20, alpha=0.7, edgecolor='black') +ax1.set_xlabel('Time') +ax1.set_ylabel('Frequency') +ax1.set_title('Distribution of Observed Times') + +# Event rate vs covariate +df['covar_bin'] = pd.cut(df['covariate'], bins=5) +event_rate = df.groupby('covar_bin')['status'].mean() +event_rate.plot(kind='bar', ax=ax2, rot=45) +ax2.set_ylabel('Event Rate') +ax2.set_title('Event Rate by Covariate Level') + +plt.tight_layout() +plt.show() +``` + +## Next Steps + +- Try different models: {doc}`model_comparison` +- Learn advanced features: {doc}`advanced_features` +- See integration examples: {doc}`integration_examples` diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md new file mode 100644 index 0000000..1594339 --- /dev/null +++ b/docs/source/tutorials/index.md @@ -0,0 +1,9 @@ +# Tutorials + +Step-by-step guides for using gen_surv effectively. + +```{toctree} +:maxdepth: 2 + +basic_usage +``` diff --git a/docs/source/usage.md b/docs/source/usage.md index 240a96e..cae0656 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -31,3 +31,15 @@ python -m gen_surv dataset aft_ln --n 100 > data.csv For a full description of available models and parameters, see the API reference. + +## Building the Documentation + +Documentation is written using [Sphinx](https://www.sphinx-doc.org). To build the HTML pages locally run: + +```bash +cd docs +make html +``` + +The generated files will be available under `docs/build/html`. + diff --git a/pyproject.toml b/pyproject.toml index 6622899..991f439 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,12 @@ isort = "^5.13.2" flake8 = "^6.1.0" [tool.poetry.group.docs.dependencies] -sphinx = "^7.2.6" -myst-parser = "<4.0.0" +sphinx = ">=6.0" sphinx-rtd-theme = "^1.3.0" +myst-parser = ">=1.0.0,<4.0.0" +sphinx-autodoc-typehints = "^1.24.0" +sphinx-copybutton = "^0.5.2" +sphinx-design = "^0.5.0" [tool.poetry.scripts] gen_surv = "gen_surv.cli:app" diff --git a/tests/test_mixture.py b/tests/test_mixture.py index 0c14b2d..2b1a4f9 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -1,4 +1,5 @@ import pandas as pd +import pytest from gen_surv.mixture import cure_fraction_estimate, gen_mixture_cure @@ -14,3 +15,69 @@ def test_cure_fraction_estimate_range(): df = gen_mixture_cure(n=50, cure_fraction=0.3, seed=0) est = cure_fraction_estimate(df) assert 0 <= est <= 1 + + +def test_cure_fraction_estimate_empty_returns_zero(): + df = pd.DataFrame(columns=["time", "status"]) + est = cure_fraction_estimate(df) + assert est == 0.0 + + +def test_gen_mixture_cure_invalid_inputs(): + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=1.5) + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=0.2, baseline_hazard=0) + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=0.2, covariate_dist="bad") + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=0.2, model_cens="bad") + with pytest.raises(ValueError): + gen_mixture_cure( + n=5, + cure_fraction=0.2, + betas_survival=[0.1, 0.2], + betas_cure=[0.1], + ) + + +def test_gen_mixture_cure_max_time_cap(): + df = gen_mixture_cure( + n=50, + cure_fraction=0.3, + max_time=5.0, + model_cens="exponential", + cens_par=1.0, + seed=123, + ) + assert (df["time"] <= 5.0).all() + + +def test_cure_fraction_estimate_close_to_true(): + df = gen_mixture_cure(n=200, cure_fraction=0.4, seed=1) + est = cure_fraction_estimate(df) + assert pytest.approx(0.4, abs=0.15) == est + + +def test_gen_mixture_cure_covariate_distributions(): + for dist in ["uniform", "binary"]: + df = gen_mixture_cure(n=20, cure_fraction=0.3, covariate_dist=dist, seed=2) + assert {"time", "status", "cured", "X0", "X1"}.issubset(df.columns) + + +def test_gen_mixture_cure_no_max_time_allows_long_times(): + df = gen_mixture_cure( + n=100, + cure_fraction=0.5, + max_time=None, + model_cens="uniform", + cens_par=20.0, + seed=3, + ) + assert df["time"].max() > 10 + + +def test_cure_fraction_estimate_small_sample(): + df = gen_mixture_cure(n=3, cure_fraction=0.2, seed=4) + est = cure_fraction_estimate(df) + assert 0 <= est <= 1 From 74b402151f5b3cfe36d3505a70f99a77069eb613 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 13:46:56 +0100 Subject: [PATCH 15/19] Increase test coverage (#52) --- .github/workflows/docs.yml | 2 +- tests/test_cli.py | 113 +++++++++++++++++++++++++++++++++- tests/test_competing_risks.py | 31 ++++++++++ tests/test_piecewise.py | 36 +++++++++++ tests/test_summary_extra.py | 60 ++++++++++++++++-- tests/test_visualization.py | 5 +- 6 files changed, 238 insertions(+), 9 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a7a2501..7f2cb1f 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -21,7 +21,7 @@ jobs: - name: Build documentation run: poetry run sphinx-build -W -b html docs/source docs/build - name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: documentation path: docs/build/ diff --git a/tests/test_cli.py b/tests/test_cli.py index bb4537f..d34c83a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,11 +1,12 @@ import os import runpy import sys +import pytest import pandas as pd sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from gen_surv.cli import dataset +from gen_surv.cli import dataset, visualize def test_cli_dataset_stdout(monkeypatch, capsys): @@ -81,7 +82,9 @@ def fake_generate(**kwargs): return pd.DataFrame({"time": [1], "status": [0]}) monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset(model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None) + dataset( + model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None + ) assert captured["model"] == "aft_weibull" assert captured["beta"] == [0.1, 0.2] assert captured["shape"] == 1.1 @@ -146,3 +149,109 @@ def fake_generate(**kwargs): assert captured["betas_survival"] == [0.4] assert captured["betas_cure"] == [0.4] + +def test_dataset_invalid_model(monkeypatch): + def fake_generate(**kwargs): + raise ValueError("bad model") + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + with pytest.raises(ValueError): + dataset(model="nope", n=1, output=None) + + +def test_cli_visualize_basic(monkeypatch, tmp_path): + csv = tmp_path / "data.csv" + pd.DataFrame({"time": [1, 2], "status": [1, 0]}).to_csv(csv, index=False) + + def fake_plot_survival_curve(**kwargs): + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot([0, 1], [1, 0]) + return fig, ax + + monkeypatch.setattr( + "gen_surv.visualization.plot_survival_curve", fake_plot_survival_curve + ) + + saved = [] + + def fake_savefig(path, *args, **kwargs): + saved.append(path) + + monkeypatch.setattr("matplotlib.pyplot.savefig", fake_savefig) + + visualize( + str(csv), + time_col="time", + status_col="status", + group_col=None, + output=str(tmp_path / "plot.png"), + ) + assert saved and saved[0].endswith("plot.png") + + +def test_dataset_aft_log_logistic(monkeypatch): + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="aft_log_logistic", + n=1, + beta=[0.1], + shape=1.2, + scale=2.3, + output=None, + ) + assert captured["model"] == "aft_log_logistic" + assert captured["beta"] == [0.1] + assert captured["shape"] == 1.2 + assert captured["scale"] == 2.3 + + +def test_dataset_competing_risks_weibull(monkeypatch): + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="competing_risks_weibull", + n=1, + n_risks=2, + shape_params=[0.7, 1.2], + scale_params=[2.0, 2.0], + beta=0.3, + output=None, + ) + assert captured["n_risks"] == 2 + assert captured["shape_params"] == [0.7, 1.2] + assert captured["scale_params"] == [2.0, 2.0] + assert captured["betas"] == [0.3, 0.3] + + +def test_dataset_piecewise(monkeypatch): + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="piecewise_exponential", + n=1, + breakpoints=[1.0], + hazard_rates=[0.2, 0.3], + beta=[0.4], + output=None, + ) + assert captured["breakpoints"] == [1.0] + assert captured["hazard_rates"] == [0.2, 0.3] + assert captured["betas"] == [0.4] diff --git a/tests/test_competing_risks.py b/tests/test_competing_risks.py index 181aa21..10c0574 100644 --- a/tests/test_competing_risks.py +++ b/tests/test_competing_risks.py @@ -8,6 +8,7 @@ from hypothesis import given from hypothesis import strategies as st +import gen_surv.competing_risks as cr from gen_surv.competing_risks import ( cause_specific_cumulative_incidence, gen_competing_risks, @@ -183,3 +184,33 @@ def test_reproducibility(): with pytest.raises(AssertionError): pd.testing.assert_frame_equal(df1, df3) + + +def test_competing_risks_summary_basic(): + df = gen_competing_risks(n=10, n_risks=2, seed=1) + summary = cr.competing_risks_summary(df) + assert summary["n_subjects"] == 10 + assert summary["n_causes"] == 2 + assert set(summary["events_by_cause"]) <= {1, 2} + assert "time_stats" in summary + + +def test_competing_risks_summary_with_categorical(): + df = gen_competing_risks(n=8, n_risks=2, seed=2) + df["group"] = ["A", "B"] * 4 + summary = cr.competing_risks_summary(df, covariate_cols=["X0", "group"]) + assert summary["covariate_stats"]["group"]["categories"] == 2 + assert "distribution" in summary["covariate_stats"]["group"] + + +import matplotlib + +matplotlib.use("Agg") + + +def test_plot_cause_specific_hazards_runs(): + df = gen_competing_risks(n=30, n_risks=2, seed=3) + fig, ax = cr.plot_cause_specific_hazards(df, time_points=np.linspace(0, 5, 5)) + assert hasattr(fig, "savefig") + assert len(ax.get_lines()) >= 1 + matplotlib.pyplot.close(fig) diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py index e587af2..b06903d 100644 --- a/tests/test_piecewise.py +++ b/tests/test_piecewise.py @@ -19,6 +19,7 @@ def test_piecewise_invalid_lengths(): n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42 ) + def test_piecewise_invalid_hazard_and_breakpoints(): with pytest.raises(ValueError): gen_piecewise_exponential( @@ -34,3 +35,38 @@ def test_piecewise_invalid_hazard_and_breakpoints(): hazard_rates=[0.5, -1.0], seed=42, ) + + +def test_piecewise_covariate_distributions(): + for dist, params in [ + ("uniform", {"low": 0.0, "high": 1.0}), + ("binary", {"p": 0.7}), + ]: + df = gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.2, 0.4], + covariate_dist=dist, + covariate_params=params, + seed=1, + ) + assert len(df) == 5 + assert {"X0", "X1"}.issubset(df.columns) + + +def test_piecewise_custom_betas_reproducible(): + df1 = gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.1, 0.2], + betas=[0.5, -0.2], + seed=2, + ) + df2 = gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.1, 0.2], + betas=[0.5, -0.2], + seed=2, + ) + pd.testing.assert_frame_equal(df1, df2) diff --git a/tests/test_summary_extra.py b/tests/test_summary_extra.py index a908a5b..4c78c5e 100644 --- a/tests/test_summary_extra.py +++ b/tests/test_summary_extra.py @@ -1,6 +1,10 @@ import pandas as pd import pytest -from gen_surv.summary import check_survival_data_quality, compare_survival_datasets +from gen_surv.summary import ( + check_survival_data_quality, + compare_survival_datasets, + _print_summary, +) from gen_surv import generate @@ -35,17 +39,65 @@ def test_check_survival_data_quality_no_fix(): def test_compare_survival_datasets_basic(): - ds1 = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) - ds2 = generate(model="cphm", n=5, model_cens="uniform", cens_par=1.0, beta=1.0, covariate_range=1.0) + ds1 = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) + ds2 = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=1.0, + covariate_range=1.0, + ) comparison = compare_survival_datasets({"A": ds1, "B": ds2}) assert set(["A", "B"]).issubset(comparison.columns) assert "n_subjects" in comparison.index def test_compare_survival_datasets_with_covariates_and_empty_error(): - ds = generate(model="cphm", n=3, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=1.0) + ds = generate( + model="cphm", + n=3, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) comparison = compare_survival_datasets({"only": ds}, covariate_cols=["X0"]) assert "only" in comparison.columns assert "X0_mean" in comparison.index with pytest.raises(ValueError): compare_survival_datasets({}) + + +def test_check_survival_data_quality_min_and_max(): + df = pd.DataFrame({"time": [-1.0, 3.0], "status": [1, 1]}) + fixed, issues = check_survival_data_quality( + df, min_time=0.0, max_time=2.0, fix_issues=True + ) + assert (fixed["time"] <= 2.0).all() + assert issues["modifications"]["values_fixed"] > 0 + + +def test_print_summary_with_issues(capsys): + summary = { + "dataset_info": {"n_subjects": 2, "n_unique_ids": 2, "n_covariates": 0}, + "event_info": {"n_events": 1, "n_censored": 1, "event_rate": 0.5}, + "time_info": {"min": 0.0, "max": 2.0, "mean": 1.0, "median": 1.0}, + "data_quality": { + "missing_time": 0, + "missing_status": 0, + "negative_time": 1, + "invalid_status": 0, + }, + "covariates": {}, + } + _print_summary(summary, "time", "status", None, []) + out = capsys.readouterr().out + assert "Issues detected" in out diff --git a/tests/test_visualization.py b/tests/test_visualization.py index f077e27..b3ac616 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -162,11 +162,12 @@ def fake_import(name, *args, **kwargs): def test_cli_visualize_read_error(monkeypatch, tmp_path, capsys): """visualize handles CSV read failures gracefully.""" - monkeypatch.setattr("pandas.read_csv", lambda *a, **k: (_ for _ in ()).throw(Exception("boom"))) + monkeypatch.setattr( + "pandas.read_csv", lambda *a, **k: (_ for _ in ()).throw(Exception("boom")) + ) csv_path = tmp_path / "x.csv" csv_path.write_text("time,status\n1,1\n") with pytest.raises(typer.Exit): visualize(str(csv_path)) captured = capsys.readouterr() assert "Error loading CSV file" in captured.out - From 908e6d18a8ea6c94c9a5f4542551c16c97432ab1 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 14:10:38 +0100 Subject: [PATCH 16/19] Fix docs build and sort imports (#54) --- docs/requirements.txt | 1 + docs/source/algorithms.md | 4 ++++ docs/source/api/index.md | 4 ++++ docs/source/bibliography.md | 4 ++++ docs/source/changelog.md | 6 +++++- docs/source/conf.py | 4 ++++ docs/source/contributing.md | 6 +++++- docs/source/examples/index.md | 4 ++++ docs/source/getting_started.md | 4 ++++ docs/source/modules.md | 4 ++++ docs/source/rtd.md | 4 ++++ docs/source/theory.md | 4 ++++ docs/source/tutorials/index.md | 4 ++++ docs/source/usage.md | 4 ++++ gen_surv/interface.py | 3 ++- gen_surv/tdcm.py | 19 ++++++++++--------- pyproject.toml | 1 + tests/test_cli.py | 2 +- tests/test_piecewise_functions.py | 1 + tests/test_summary_extra.py | 5 +++-- tests/test_summary_more.py | 5 +++-- 21 files changed, 76 insertions(+), 17 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 5cf2520..11d4769 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,3 +4,4 @@ sphinx-rtd-theme sphinx-autodoc-typehints sphinx-copybutton sphinx-design +linkify-it-py>=2.0 diff --git a/docs/source/algorithms.md b/docs/source/algorithms.md index 30ed1fb..d5bcb70 100644 --- a/docs/source/algorithms.md +++ b/docs/source/algorithms.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Algorithm Overview This page provides a short description of each model implemented in **gen_surv**. For mathematical details see {doc}`theory`. diff --git a/docs/source/api/index.md b/docs/source/api/index.md index 0a26c5a..a15a79e 100644 --- a/docs/source/api/index.md +++ b/docs/source/api/index.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # API Reference Complete documentation for all gen_surv functions and classes. diff --git a/docs/source/bibliography.md b/docs/source/bibliography.md index c3ae2c9..33dad70 100644 --- a/docs/source/bibliography.md +++ b/docs/source/bibliography.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # References Below is a selection of references covering the statistical models implemented in **gen_surv**. diff --git a/docs/source/changelog.md b/docs/source/changelog.md index 8019063..011713d 100644 --- a/docs/source/changelog.md +++ b/docs/source/changelog.md @@ -1,5 +1,9 @@ +--- +orphan: true +--- + # Changelog ```{include} ../../CHANGELOG.md -:relative: true +:relative-docs: true ``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 22a56b1..6eb716c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -64,6 +64,10 @@ 'pandas': ('https://pandas.pydata.org/docs/', None), } +# Disable fetching remote inventories when network access is unavailable +if os.environ.get("SKIP_INTERSPHINX", "1") == "1": + intersphinx_mapping = {} + # HTML theme options html_theme = 'sphinx_rtd_theme' html_theme_options = { diff --git a/docs/source/contributing.md b/docs/source/contributing.md index 37fab23..82121b3 100644 --- a/docs/source/contributing.md +++ b/docs/source/contributing.md @@ -1,5 +1,9 @@ +--- +orphan: true +--- + # Contributing ```{include} ../../CONTRIBUTING.md -:relative: true +:relative-docs: true ``` diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md index 0f58536..6e05f8f 100644 --- a/docs/source/examples/index.md +++ b/docs/source/examples/index.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Examples Real-world examples and use cases for gen_surv. diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index b94f4b4..292d0a0 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Getting Started This guide will help you install gen_surv and generate your first survival dataset. diff --git a/docs/source/modules.md b/docs/source/modules.md index 114a344..73ff5ad 100644 --- a/docs/source/modules.md +++ b/docs/source/modules.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # API Reference ::: gen_surv.cphm diff --git a/docs/source/rtd.md b/docs/source/rtd.md index dc77463..7eac430 100644 --- a/docs/source/rtd.md +++ b/docs/source/rtd.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Read the Docs This project uses [Read the Docs](https://readthedocs.org/) to host its documentation. The site is automatically rebuilt whenever changes are pushed to the repository. diff --git a/docs/source/theory.md b/docs/source/theory.md index 5d58bb1..1101034 100644 --- a/docs/source/theory.md +++ b/docs/source/theory.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # 📘 Mathematical Foundations of `gen_surv` This page presents the mathematical formulation behind the survival models implemented in the `gen_surv` package. diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index 1594339..bc3af37 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Tutorials Step-by-step guides for using gen_surv effectively. diff --git a/docs/source/usage.md b/docs/source/usage.md index cae0656..50bd5b5 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Getting Started This page offers a quick introduction to installing and using **gen_surv**. diff --git a/gen_surv/interface.py b/gen_surv/interface.py index 8a8943e..ca87292 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -58,7 +58,8 @@ def generate(model: str, **kwargs: Any) -> pd.DataFrame: ``tdcm``, ``thmm``, ``aft_ln``, ``aft_weibull``, ``aft_log_logistic``, ``competing_risks``, ``competing_risks_weibull``, ``mixture_cure``, or ``piecewise_exponential``. - **kwargs: Arguments forwarded to the chosen generator. These vary by model: + **kwargs: Arguments forwarded to the chosen generator. These vary by model. + - cphm: n, model_cens, cens_par, beta, covariate_range - cmm: n, model_cens, cens_par, beta, covariate_range, rate - tdcm: n, dist, corr, dist_par, model_cens, cens_par, beta, lam diff --git a/gen_surv/tdcm.py b/gen_surv/tdcm.py index 9ffa081..3ea538c 100644 --- a/gen_surv/tdcm.py +++ b/gen_surv/tdcm.py @@ -11,17 +11,18 @@ def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, Generate censored TDCM observations. Parameters: - - n (int): Number of individuals - - dist_par (list): Not directly used here (kept for API compatibility) - - model_cens (str): "uniform" or "exponential" - - cens_par (float): Parameter for the censoring model - - beta (list): Length-2 list of regression coefficients - - lam (float): Rate parameter - - b (np.ndarray): Covariate matrix with 2 columns [., z1] + + - n (int): Number of individuals + - dist_par (list): Not directly used here (kept for API compatibility) + - model_cens (str): "uniform" or "exponential" + - cens_par (float): Parameter for the censoring model + - beta (list): Length-2 list of regression coefficients + - lam (float): Rate parameter + - b (np.ndarray): Covariate matrix with 2 columns [., z1] Returns: - - np.ndarray: Shape (n, 6) with columns: - [id, start, stop, status, covariate1 (z1), covariate2 (z2)] + - np.ndarray: Shape (n, 6) with columns: + [id, start, stop, status, covariate1 (z1), covariate2 (z2)] """ rfunc = runifcens if model_cens == "uniform" else rexpocens observations = np.zeros((n, 6)) diff --git a/pyproject.toml b/pyproject.toml index 991f439..cf38e65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ myst-parser = ">=1.0.0,<4.0.0" sphinx-autodoc-typehints = "^1.24.0" sphinx-copybutton = "^0.5.2" sphinx-design = "^0.5.0" +linkify-it-py = ">=2.0" [tool.poetry.scripts] gen_surv = "gen_surv.cli:app" diff --git a/tests/test_cli.py b/tests/test_cli.py index d34c83a..1ae5ea5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,9 +1,9 @@ import os import runpy import sys -import pytest import pandas as pd +import pytest sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.cli import dataset, visualize diff --git a/tests/test_piecewise_functions.py b/tests/test_piecewise_functions.py index 648e416..b1de35a 100644 --- a/tests/test_piecewise_functions.py +++ b/tests/test_piecewise_functions.py @@ -1,4 +1,5 @@ import numpy as np + from gen_surv.piecewise import piecewise_hazard_function, piecewise_survival_function diff --git a/tests/test_summary_extra.py b/tests/test_summary_extra.py index 4c78c5e..380b818 100644 --- a/tests/test_summary_extra.py +++ b/tests/test_summary_extra.py @@ -1,11 +1,12 @@ import pandas as pd import pytest + +from gen_surv import generate from gen_surv.summary import ( + _print_summary, check_survival_data_quality, compare_survival_datasets, - _print_summary, ) -from gen_surv import generate def test_check_survival_data_quality_fix_issues(): diff --git a/tests/test_summary_more.py b/tests/test_summary_more.py index be730da..e039eae 100644 --- a/tests/test_summary_more.py +++ b/tests/test_summary_more.py @@ -1,9 +1,10 @@ import pandas as pd import pytest + from gen_surv.summary import ( - summarize_survival_dataset, - check_survival_data_quality, _print_summary, + check_survival_data_quality, + summarize_survival_dataset, ) From c10e62f9935d7587bcfda50e2f43e96b0d6b39ef Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Wed, 30 Jul 2025 22:06:35 +0100 Subject: [PATCH 17/19] Add scikit-survival integration (#56) --- .github/workflows/test.yml | 5 ++- CHANGELOG.md | 31 ++++++++++++++ README.md | 27 ++++++++++-- TODO.md | 36 ++++++++-------- docs/source/bibliography.md | 49 ++++++++++++++++++---- docs/source/conf.py | 5 +-- docs/source/getting_started.md | 4 +- docs/source/index.md | 4 +- docs/source/tutorials/basic_usage.md | 18 ++++---- docs/source/usage.md | 38 ++++++++++++++++- fix_recommendations.md | 46 -------------------- gen_surv/__init__.py | 4 ++ gen_surv/export.py | 5 ++- gen_surv/integration.py | 36 ++++++++++++++++ gen_surv/sklearn_adapter.py | 34 +++++++++++++++ pyproject.toml | 1 + tests/test_cli_integration.py | 27 ++++++++++++ tests/test_export.py | 63 +++++----------------------- tests/test_integration_sksurv.py | 10 +++++ tests/test_piecewise.py | 34 +++++++++++++++ tests/test_sklearn_adapter.py | 31 ++++++++++++++ 21 files changed, 360 insertions(+), 148 deletions(-) delete mode 100644 fix_recommendations.md create mode 100644 gen_surv/integration.py create mode 100644 gen_surv/sklearn_adapter.py create mode 100644 tests/test_cli_integration.py create mode 100644 tests/test_integration_sksurv.py create mode 100644 tests/test_sklearn_adapter.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bcd9349..a2c554e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,9 @@ on: jobs: test: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] steps: - name: Checkout code @@ -17,7 +20,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.9" + python-version: ${{ matrix.python-version }} - name: Install Poetry run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 500ee3d..1605e6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,36 @@ # CHANGELOG +## v1.0.9 (Unreleased) + +### Features +- export datasets to RDS files +- test workflow runs on a Python version matrix +- scikit-learn compatible data generator +- compatibility helpers for lifelines and scikit-survival + +### Documentation +- updated usage examples and tutorials + +### Misc +- README quick example uses `covariate_range` + +## v1.0.8 (2025-07-30) + +### Documentation +- ensure absolute path resolution in `conf.py` +- drop unsupported theme option +- define bibliography anchors and headings +- fix tutorial links to non-existing docs +- add additional references to the bibliography + +### Testing +- add CLI integration test +- expand piecewise generator test coverage + +### Misc +- remove fix_recommendations.md + + ## v1.0.0 (2025-06-06) diff --git a/README.md b/README.md index 4ebbdb0..8e93f36 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ - Mixture cure and piecewise exponential models - Competing risks generators (constant and Weibull hazards) - Command-line interface and export utilities +- Scikit-learn compatible data generator +- Conversion helper for scikit-survival and lifelines ## Installation @@ -40,11 +42,30 @@ poetry install ## Quick Example ```python -from gen_surv import generate +from gen_surv import export_dataset, generate # basic Cox proportional hazards data -sim = generate(model="cphm", n=100, beta=0.5, covar=2.0, - model_cens="uniform", cens_par=1.0) +sim = generate( + model="cphm", + n=100, + beta=0.5, + covariate_range=2.0, + model_cens="uniform", + cens_par=1.0, +) + +# save to an RDS file +export_dataset(sim, "survival_data.rds") +``` + +You can also convert the resulting DataFrame for use with +[scikit-survival](https://scikit-survival.readthedocs.io) or +[lifelines](https://lifelines.readthedocs.io): + +```python +from gen_surv import to_sksurv + +sks_dataset = to_sksurv(sim) ``` See the [usage guide](docs/source/getting_started.md) for more examples. diff --git a/TODO.md b/TODO.md index 91b4fe4..49b7074 100644 --- a/TODO.md +++ b/TODO.md @@ -9,9 +9,9 @@ This document outlines future enhancements, features, and ideas for improving th - [✅] Add property-based tests using Hypothesis to cover edge cases - [✅] Build a CLI for generating datasets from the terminal - [ ] Expand documentation with multilingual support and more usage examples -- [ ] Implement Weibull and log-logistic AFT models and add visualization utilities +- [✅] Implement Weibull and log-logistic AFT models and add visualization utilities - [✅] Provide CITATION metadata for proper referencing -- [ ] Ensure all functions include Google-style docstrings with inline comments +- [✅] Ensure all functions include Google-style docstrings with inline comments --- @@ -37,17 +37,17 @@ This document outlines future enhancements, features, and ideas for improving th - [✅] Add tests for each model (e.g., `test_tdcm.py`, `test_thmm.py`, `test_aft.py`) - [✅] Add property-based tests with `hypothesis` -- [ ] Cover edge cases (e.g., invalid parameters, n=0, negative censoring) -- [ ] Run tests on multiple Python versions (CI matrix) +- [✅] Cover edge cases (e.g., invalid parameters, n=0, negative censoring) +- [✅] Run tests on multiple Python versions (CI matrix) --- ## 🧠 4. Advanced Models -- [ ] Add Piecewise Exponential Model support -- [ ] Add competing risks / multi-event simulation +- [✅] Add Piecewise Exponential Model support +- [✅] Add competing risks / multi-event simulation - [✅] Implement parametric AFT models (log-normal) -- [ ] Implement parametric AFT models (log-logistic, weibull) +- [✅] Implement parametric AFT models (log-logistic, weibull) - [ ] Simulate time-varying hazards - [ ] Add informative or covariate-dependent censoring @@ -55,17 +55,17 @@ This document outlines future enhancements, features, and ideas for improving th ## 📊 5. Visualization and Analysis -- [ ] Create `plot_survival(df, model=...)` utilities -- [ ] Create `describe_survival(df)` summary helpers -- [ ] Export data to CSV / JSON / Feather +- [✅] Create `plot_survival(df, model=...)` utilities +- [✅] Create `describe_survival(df)` summary helpers +- [✅] Export data to CSV / JSON / Feather --- ## 🌍 6. Ecosystem Integration -- [ ] Add a `GenSurvDataGenerator` compatible with `sklearn` -- [ ] Enable use with `lifelines`, `scikit-survival`, `sksurv` -- [ ] Export in R-compatible formats (.csv, .rds) +- [✅] Add a `GenSurvDataGenerator` compatible with `sklearn` +- [✅] Enable use with `lifelines`, `scikit-survival`, `sksurv` +- [✅] Export in R-compatible formats (.csv, .rds) --- @@ -80,12 +80,12 @@ This document outlines future enhancements, features, and ideas for improving th ## 🧠 8. New Survival Models to Implement - [✅] Log-Normal AFT -- [ ] Log-Logistic AFT -- [ ] Weibull AFT -- [ ] Piecewise Exponential -- [ ] Competing Risks +- [✅] Log-Logistic AFT +- [✅] Weibull AFT +- [✅] Piecewise Exponential +- [✅] Competing Risks - [ ] Recurrent Events -- [ ] Mixture Cure Model +- [✅] Mixture Cure Model --- diff --git a/docs/source/bibliography.md b/docs/source/bibliography.md index 33dad70..5a5285a 100644 --- a/docs/source/bibliography.md +++ b/docs/source/bibliography.md @@ -6,21 +6,54 @@ orphan: true Below is a selection of references covering the statistical models implemented in **gen_surv**. -.. _Cox1972: +(Cox1972)= +## Cox (1972) Cox, D. R. (1972). Regression Models and Life-Tables. *Journal of the Royal Statistical Society: Series B*, 34(2), 187-220. -.. _Farewell1982: +(Farewell1982)= +## Farewell (1982) Farewell, V.T. (1982). The Use of Mixture Models for the Analysis of Survival Data with Long-Term Survivors. *Biometrics*, 38(4), 1041-1046. -.. _FineGray1999: +(FineGray1999)= +## Fine and Gray (1999) Fine, J.P., & Gray, R.J. (1999). A Proportional Hazards Model for the Subdistribution of a Competing Risk. *Journal of the American Statistical Association*, 94(446), 496-509. -.. _Andersen1993: +(Andersen1993)= +## Andersen et al. (1993) Andersen, P.K., Borgan, Ø., Gill, R.D., & Keiding, N. (1993). *Statistical Models Based on Counting Processes*. Springer. -.. _Zucchini2017: +(Zucchini2017)= +## Zucchini et al. (2017) Zucchini, W., MacDonald, I.L., & Langrock, R. (2017). *Hidden Markov Models for Time Series*. Chapman and Hall/CRC. -- Klein, J.P., & Moeschberger, M.L. (2003). *Survival Analysis: Techniques for Censored and Truncated Data*. Springer. -- Kalbfleisch, J.D., & Prentice, R.L. (2002). *The Statistical Analysis of Failure Time Data*. Wiley. -- Cook, R.J., & Lawless, J.F. (2007). *The Statistical Analysis of Recurrent Events*. Springer. +(KleinMoeschberger2003)= +## Klein and Moeschberger (2003) +Klein, J.P., & Moeschberger, M.L. (2003). *Survival Analysis: Techniques for Censored and Truncated Data*. Springer. + +(KalbfleischPrentice2002)= +## Kalbfleisch and Prentice (2002) +Kalbfleisch, J.D., & Prentice, R.L. (2002). *The Statistical Analysis of Failure Time Data*. Wiley. + +(CookLawless2007)= +## Cook and Lawless (2007) +Cook, R.J., & Lawless, J.F. (2007). *The Statistical Analysis of Recurrent Events*. Springer. + +(KaplanMeier1958)= +## Kaplan and Meier (1958) +Kaplan, E.L., & Meier, P. (1958). Nonparametric Estimation from Incomplete Observations. *Journal of the American Statistical Association*, 53(282), 457-481. +(TherneauGrambsch2000)= +## Therneau and Grambsch (2000) +Therneau, T.M., & Grambsch, P.M. (2000). *Modeling Survival Data: Extending the Cox Model*. Springer. + +(FlemingHarrington1991)= +## Fleming and Harrington (1991) +Fleming, T.R., & Harrington, D.P. (1991). *Counting Processes and Survival Analysis*. Wiley. + +(Collett2015)= +## Collett (2015) +Collett, D. (2015). *Modelling Survival Data in Medical Research*. CRC Press. + +(KleinbaumKlein2012)= +## Kleinbaum and Klein (2012) +Kleinbaum, D.G., & Klein, M. (2012). *Survival Analysis: A Self-Learning Text*. Springer. + diff --git a/docs/source/conf.py b/docs/source/conf.py index 6eb716c..2b0d9fd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2,8 +2,8 @@ import sys from pathlib import Path -# Add the package to the Python path -project_root = Path(__file__).parent.parent.parent +# Add the package to the Python path using an absolute path +project_root = Path(__file__).resolve().parent.parent.parent sys.path.insert(0, str(project_root / "gen_surv")) # Project information @@ -74,7 +74,6 @@ 'canonical_url': 'https://gensurvpy.readthedocs.io/', 'analytics_id': '', 'logo_only': False, - 'display_version': True, 'prev_next_buttons_location': 'bottom', 'style_external_links': False, 'style_nav_header_background': '#2980B9', diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 292d0a0..2ec1a03 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -33,8 +33,8 @@ from gen_surv import generate df = generate( model="cphm", # Model type n=100, # Sample size - beta=0.5, # Covariate effect - covar=2.0, # Covariate range + beta=0.5, # Covariate effect + covariate_range=2.0, # Covariate range model_cens="uniform", # Censoring type cens_par=3.0 # Censoring parameter ) diff --git a/docs/source/index.md b/docs/source/index.md index 9d52cb7..d839e4e 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -22,7 +22,7 @@ pip install gen-surv Generate your first dataset: ```python from gen_surv import generate -df = generate(model="cphm", n=100, beta=0.5, covar=2.0) +df = generate(model="cphm", n=100, beta=0.5, covariate_range=2.0) ``` ``` @@ -72,7 +72,7 @@ df = gs.generate( model="cphm", n=500, beta=0.5, - covar=2.0, + covariate_range=2.0, model_cens="uniform", cens_par=3.0 ) diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md index 8b17acd..e99fe2d 100644 --- a/docs/source/tutorials/basic_usage.md +++ b/docs/source/tutorials/basic_usage.md @@ -15,7 +15,7 @@ import pandas as pd model="cphm", n=200, beta=0.7, - covar=1.5, + covariate_range=1.5, model_cens="exponential", cens_par=2.0, seed=42 # For reproducibility @@ -43,7 +43,7 @@ All models share these parameters: Each model has unique parameters. For CPHM: - `beta`: Covariate effect (hazard ratio = exp(beta)) -- `covar`: Range for uniform covariate generation [0, covar] +- `covariate_range`: Range for uniform covariate generation [0, covariate_range] ## Censoring Mechanisms @@ -56,7 +56,7 @@ df_uniform = generate( model="cphm", n=100, beta=0.5, - covar=2.0, + covariate_range=2.0, model_cens="uniform", cens_par=3.0 ) @@ -69,7 +69,7 @@ df_exponential = generate( model="cphm", n=100, beta=0.5, - covar=2.0, + covariate_range=2.0, model_cens="exponential", cens_par=2.0 ) @@ -93,8 +93,8 @@ ax1.set_ylabel('Frequency') ax1.set_title('Distribution of Observed Times') # Event rate vs covariate -df['covar_bin'] = pd.cut(df['covariate'], bins=5) -event_rate = df.groupby('covar_bin')['status'].mean() +df['covariate_bin'] = pd.cut(df['covariate'], bins=5) +event_rate = df.groupby('covariate_bin')['status'].mean() event_rate.plot(kind='bar', ax=ax2, rot=45) ax2.set_ylabel('Event Rate') ax2.set_title('Event Rate by Covariate Level') @@ -105,6 +105,6 @@ plt.show() ## Next Steps -- Try different models: {doc}`model_comparison` -- Learn advanced features: {doc}`advanced_features` -- See integration examples: {doc}`integration_examples` +- Try different models (model_comparison) +- Learn advanced features (advanced_features) +- See integration examples (integration_examples) diff --git a/docs/source/usage.md b/docs/source/usage.md index 50bd5b5..b567d8a 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -21,10 +21,20 @@ This will create a virtual environment and install all required packages. Generate datasets directly in Python: ```python -from gen_surv import generate +from gen_surv import export_dataset, generate # Cox Proportional Hazards example -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) +df = generate( + model="cphm", + n=100, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, +) + +# Save to RDS for use in R +export_dataset(df, "simulated_data.rds") ``` You can also generate data from the command line: @@ -47,3 +57,27 @@ make html The generated files will be available under `docs/build/html`. +## Scikit-learn Integration + +You can wrap the generator in a transformer compatible with scikit-learn: + +```python +from gen_surv import GenSurvDataGenerator + +est = GenSurvDataGenerator("cphm", n=10, beta=0.5, covariate_range=1.0) +df = est.fit_transform() +``` + +## Lifelines and scikit-survival + +Datasets generated with **gen_surv** can be directly used with +[lifelines](https://lifelines.readthedocs.io). For +[scikit-survival](https://scikit-survival.readthedocs.io) you can convert the +DataFrame using ``to_sksurv``: + +```python +from gen_surv import to_sksurv + +struct = to_sksurv(df) +``` + diff --git a/fix_recommendations.md b/fix_recommendations.md deleted file mode 100644 index 942ff54..0000000 --- a/fix_recommendations.md +++ /dev/null @@ -1,46 +0,0 @@ -# Fixing gen_surv Repository Issues - -## Priority 1: Critical Fixes - -- [x] **Fix `__init__.py` Import Issues** - - Ensure missing imports for new generators are added and exported via `__all__`. - -- [x] **Add Missing Validators** - - Create validation helpers for AFT Weibull, AFT log-logistic, and competing risks generators. - -- [x] **Update CLI Integration** - - Support competing risks, mixture cure, and piecewise exponential models. - -## Priority 2: Version Consistency - -- [x] **Update Version Numbers** - - `CITATION.cff` and `docs/source/conf.py` now reference version 1.0.8. - -## Priority 3: Testing and Documentation - -- [x] **Add Missing Tests** - - Added tests for censoring helpers, mixture cure, piecewise exponential, summary, and visualization. - -- [x] **Update Documentation** - - Documented competing risks, mixture cure, and piecewise exponential models. - -## Priority 4: Code Quality Improvements - -- [x] **Standardize Parameter Naming** - - Replaced the `covar` parameter with `covariate_range` and standardized return columns to `X0`. - -- [x] **Add Type Hints** - - Completed type hints for public functions in `mixture.py`, `piecewise.py`, and `summary.py`. - -## Verification Steps -- [x] `python -c "from gen_surv import gen_aft_log_logistic, gen_competing_risks"` -- [x] `python -m gen_surv dataset competing_risks --n 10` -- [x] `pytest -q` -- [x] `python scripts/check_version_match.py` -- [x] `sphinx-build docs/source docs/build` - -## Status - -All fix recommendations have been implemented in version 1.0.8. - -Verified as of commit `7daa3e1`. diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 47a59bf..cbaa5cf 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -16,6 +16,8 @@ # Individual generators from .cphm import gen_cphm from .export import export_dataset +from .integration import to_sksurv +from .sklearn_adapter import GenSurvDataGenerator # Main interface from .interface import generate @@ -64,6 +66,8 @@ "runifcens", "rexpocens", "export_dataset", + "to_sksurv", + "GenSurvDataGenerator", ] # Add visualization tools to __all__ if available diff --git a/gen_surv/export.py b/gen_surv/export.py index 6ad17f0..c3751bb 100644 --- a/gen_surv/export.py +++ b/gen_surv/export.py @@ -10,6 +10,7 @@ from typing import Optional import pandas as pd +import pyreadr def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> None: @@ -22,7 +23,7 @@ def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> No path : str File path to write to. The extension is used to infer the format when ``fmt`` is ``None``. - fmt : {"csv", "json", "feather"}, optional + fmt : {"csv", "json", "feather", "rds"}, optional Format to use. If omitted, inferred from ``path``. Raises @@ -39,5 +40,7 @@ def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> No df.to_json(path, orient="table") elif fmt in {"feather", "ft"}: df.reset_index(drop=True).to_feather(path) + elif fmt == "rds": + pyreadr.write_rds(path, df.reset_index(drop=True)) else: raise ValueError(f"Unsupported export format: {fmt}") diff --git a/gen_surv/integration.py b/gen_surv/integration.py new file mode 100644 index 0000000..0c7994c --- /dev/null +++ b/gen_surv/integration.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import pandas as pd + + +def to_sksurv(df: pd.DataFrame, time_col: str = "time", event_col: str = "status"): + """Convert a DataFrame to a scikit-survival structured array. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing survival data. + time_col : str, default "time" + Column storing durations. + event_col : str, default "status" + Column storing event indicators (1=event, 0=censored). + + Returns + ------- + numpy.ndarray + Structured array suitable for scikit-survival estimators. + + Notes + ----- + The ``sksurv`` package is imported lazily inside the function. It must be + installed separately, for instance with ``pip install scikit-survival``. + """ + + try: + from sksurv.util import Surv + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "scikit-survival is required for this feature." + ) from exc + + return Surv.from_dataframe(event_col, time_col, df) diff --git a/gen_surv/sklearn_adapter.py b/gen_surv/sklearn_adapter.py new file mode 100644 index 0000000..f8f8141 --- /dev/null +++ b/gen_surv/sklearn_adapter.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Any, Optional + +from .interface import generate + +try: # pragma: no cover - only imported if sklearn is installed + from sklearn.base import BaseEstimator +except Exception: # pragma: no cover - fallback when sklearn missing + class BaseEstimator: # type: ignore + """Minimal stub if scikit-learn is not installed.""" + + +class GenSurvDataGenerator(BaseEstimator): + """Scikit-learn compatible wrapper around :func:`gen_surv.generate`.""" + + def __init__(self, model: str, return_type: str = "df", **kwargs: Any) -> None: + self.model = model + self.return_type = return_type + self.kwargs = kwargs + + def fit(self, X: Optional[Any] = None, y: Optional[Any] = None) -> "GenSurvDataGenerator": + return self + + def transform(self, X: Optional[Any] = None) -> Any: + df = generate(self.model, **self.kwargs) + if self.return_type == "df": + return df + if self.return_type == "dict": + return df.to_dict(orient="list") + raise ValueError("return_type must be 'df' or 'dict'") + + def fit_transform(self, X: Optional[Any] = None, y: Optional[Any] = None, **fit_params: Any) -> Any: + return self.fit(X, y).transform(X) diff --git a/pyproject.toml b/pyproject.toml index cf38e65..89354e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ typer = "^0.12.3" matplotlib = "^3.10" lifelines = "^0.30" pyarrow = "^14" +pyreadr = "^0.5" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" diff --git a/tests/test_cli_integration.py b/tests/test_cli_integration.py new file mode 100644 index 0000000..7a457a8 --- /dev/null +++ b/tests/test_cli_integration.py @@ -0,0 +1,27 @@ +from typer.testing import CliRunner +import pandas as pd + +from gen_surv.cli import app + + +def test_dataset_cli_integration(tmp_path): + """Run dataset command end-to-end and verify CSV output.""" + runner = CliRunner() + out_file = tmp_path / "data.csv" + result = runner.invoke(app, [ + "dataset", + "cphm", + "--n", + "3", + "--beta", + "0.5", + "--covariate-range", + "1.0", + "-o", + str(out_file), + ]) + assert result.exit_code == 0 + assert out_file.exists() + df = pd.read_csv(out_file) + assert len(df) == 3 + assert {"time", "status"}.issubset(df.columns) diff --git a/tests/test_export.py b/tests/test_export.py index f89e94f..21db041 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -1,57 +1,14 @@ -import os - import pandas as pd -import pytest - -from gen_surv import export_dataset, generate - - -def test_export_dataset_csv(tmp_path): - df = generate( - model="cphm", - n=5, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=1.0, - ) - out_file = tmp_path / "data.csv" - export_dataset(df, str(out_file)) - assert out_file.exists() - loaded = pd.read_csv(out_file) - pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) - - -def test_export_dataset_json(tmp_path): - df = generate( - model="cphm", - n=5, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=1.0, - ) - out_file = tmp_path / "data.json" - export_dataset(df, str(out_file)) - assert out_file.exists() - loaded = pd.read_json(out_file, orient="table") - pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) +import pyreadr +from gen_surv.export import export_dataset -def test_export_dataset_feather_and_invalid(tmp_path): - df = generate( - model="cphm", - n=5, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=1.0, - ) - feather_file = tmp_path / "data.feather" - export_dataset(df, str(feather_file)) - assert feather_file.exists() - loaded = pd.read_feather(feather_file) - pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded) - with pytest.raises(ValueError): - export_dataset(df, str(tmp_path / "data.txt"), fmt="txt") +def test_export_dataset_rds(tmp_path): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + out = tmp_path / "data.rds" + export_dataset(df, out) + assert out.exists() + result = pyreadr.read_r(out)[None] + result = result.astype(df.dtypes.to_dict()) + pd.testing.assert_frame_equal(result.reset_index(drop=True), df) diff --git a/tests/test_integration_sksurv.py b/tests/test_integration_sksurv.py new file mode 100644 index 0000000..7f6c2e8 --- /dev/null +++ b/tests/test_integration_sksurv.py @@ -0,0 +1,10 @@ +import pandas as pd + +from gen_surv.integration import to_sksurv + + +def test_to_sksurv(): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + arr = to_sksurv(df) + assert arr.dtype.names == ("status", "time") + assert arr.shape[0] == 2 diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py index b06903d..7dba747 100644 --- a/tests/test_piecewise.py +++ b/tests/test_piecewise.py @@ -70,3 +70,37 @@ def test_piecewise_custom_betas_reproducible(): seed=2, ) pd.testing.assert_frame_equal(df1, df2) + + +def test_piecewise_invalid_covariate_dist(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.5, 1.0], + covariate_dist="unknown", + seed=1, + ) + + +def test_piecewise_invalid_censoring_model(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.5, 1.0], + model_cens="bad", + seed=1, + ) + + + +def test_piecewise_negative_breakpoint(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[-1.0], + hazard_rates=[0.5, 1.0], + seed=1, + ) + diff --git a/tests/test_sklearn_adapter.py b/tests/test_sklearn_adapter.py new file mode 100644 index 0000000..669cffd --- /dev/null +++ b/tests/test_sklearn_adapter.py @@ -0,0 +1,31 @@ +from gen_surv.sklearn_adapter import GenSurvDataGenerator + + +def test_sklearn_generator_dataframe(): + gen = GenSurvDataGenerator( + "cphm", + n=4, + beta=0.2, + covariate_range=1.0, + model_cens="uniform", + cens_par=1.0, + ) + df = gen.fit_transform() + assert len(df) == 4 + assert {"time", "status"}.issubset(df.columns) + + +def test_sklearn_generator_dict(): + gen = GenSurvDataGenerator( + "cphm", + return_type="dict", + n=3, + beta=0.5, + covariate_range=1.0, + model_cens="uniform", + cens_par=1.0, + ) + data = gen.transform() + assert isinstance(data, dict) + assert set(data.keys()) >= {"time", "status"} + assert len(data["time"]) == 3 From aa6ee6d453030b47318990e945a9e1b6700cbf87 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Sat, 2 Aug 2025 06:10:44 +0100 Subject: [PATCH 18/19] chore: finalize 1.0.9 release metadata (#58) --- .github/workflows/bump-version.yml | 31 ++------- .github/workflows/ci.yml | 17 ++--- .github/workflows/test.yml | 2 +- .pre-commit-config.yaml | 17 +++++ CHANGELOG.md | 6 +- CHECKLIST.md | 102 +++++++++++++++++++++++++++++ CITATION.cff | 4 +- CONTRIBUTING.md | 11 +++- LICENCE => LICENSE | 0 README.md | 20 +++++- docs/source/api/index.md | 6 ++ docs/source/conf.py | 56 ++++++++-------- docs/source/examples/index.md | 6 ++ docs/source/getting_started.md | 7 ++ docs/source/index.md | 10 ++- docs/source/usage.md | 6 ++ gen_surv/__init__.py | 2 +- gen_surv/integration.py | 4 +- gen_surv/sklearn_adapter.py | 9 ++- pyproject.toml | 3 +- scripts/check_version_match.py | 7 +- tests/test_cli_integration.py | 29 ++++---- tests/test_integration_sksurv.py | 3 + tests/test_piecewise.py | 2 - 24 files changed, 260 insertions(+), 100 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 CHECKLIST.md rename LICENCE => LICENSE (100%) diff --git a/.github/workflows/bump-version.yml b/.github/workflows/bump-version.yml index 841b320..f91dd32 100644 --- a/.github/workflows/bump-version.yml +++ b/.github/workflows/bump-version.yml @@ -1,45 +1,26 @@ # .github/workflows/bump-version.yml -name: Bump Version on Merge to Main - +name: Tag Version on Merge to Main + on: push: branches: - main jobs: - bump-version: + tag-version: runs-on: ubuntu-latest permissions: contents: write - id-token: write steps: - uses: actions/checkout@v4 with: - fetch-depth: 0 # Fetch all history for all branches and tags + fetch-depth: 0 - uses: actions/setup-python@v5 with: python-version: "3.11" - - name: Install Poetry - run: pip install poetry - - - name: Install python-semantic-release - run: pip install python-semantic-release - - - name: Configure Git - run: | - git config user.name "github-actions[bot]" - git config user.email "github-actions[bot]@users.noreply.github.com" - - - name: Run Semantic Release - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - # Run semantic-release to get the next version - semantic-release version + - name: Tag repository from pyproject + run: python scripts/check_version_match.py --fix - - name: Push changes - run: | - git push --follow-tags diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 996e268..e05dfe7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Install dependencies - run: poetry install + run: poetry install --with dev - name: Run tests with coverage run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term @@ -59,16 +59,7 @@ jobs: echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Install dependencies - run: poetry install + run: poetry install --with dev - - name: Run black - run: poetry run black --check gen_surv tests examples - - - name: Run isort - run: poetry run isort --check gen_surv tests examples - - - name: Run flake8 - run: poetry run flake8 gen_surv tests examples - - - name: Run mypy - run: poetry run mypy gen_surv + - name: Run pre-commit checks + run: poetry run pre-commit run --all-files diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a2c554e..1245225 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Install dependencies - run: poetry install + run: poetry install --with dev - name: Run tests run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..20fd750 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/psf/black + rev: 24.1.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/pycqa/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy diff --git a/CHANGELOG.md b/CHANGELOG.md index 1605e6b..a8d40fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # CHANGELOG -## v1.0.9 (Unreleased) +## v1.0.9 (2025-08-02) ### Features - export datasets to RDS files @@ -10,6 +10,10 @@ ### Documentation - updated usage examples and tutorials +- document optional scikit-survival dependency throughout the docs + +### Continuous Integration +- auto-tag releases using the version check script ### Misc - README quick example uses `covariate_range` diff --git a/CHECKLIST.md b/CHECKLIST.md new file mode 100644 index 0000000..0163f08 --- /dev/null +++ b/CHECKLIST.md @@ -0,0 +1,102 @@ +# ✅ Python Package Development Checklist + +A checklist to ensure quality, maintainability, and usability of your Python package. + +--- + +## 1. Purpose & Scope + +- [ ] Clear purpose and use cases defined +- [ ] Scoped to a specific problem/domain +- [ ] Project name is meaningful and not taken on PyPI + +--- + +## 2. Project Structure + +- [ ] Uses `src/` layout or appropriate flat structure +- [ ] All package folders contain `__init__.py` +- [ ] Main configuration handled via `pyproject.toml` +- [ ] Includes standard files: `README.md`, `LICENSE`, `.gitignore`, `CHANGELOG.md` + +--- + +## 3. Dependencies + +- [ ] All dependencies declared in `pyproject.toml` or `requirements.txt` +- [ ] Development dependencies separated from runtime dependencies +- [ ] Uses minimal, necessary dependencies only + +--- + +## 4. Code Quality + +- [ ] Follows PEP 8 formatting +- [ ] Imports sorted with `isort` or `ruff` +- [ ] No linter warnings (`ruff`, `flake8`, etc.) +- [ ] Fully typed using `typing` module +- [ ] No unresolved TODOs or FIXME comments + +--- + +## 5. Function & Module Design + +- [ ] Functions are small, pure, and single-responsibility +- [ ] Classes follow clear and simple roles +- [ ] Global state is avoided +- [ ] Public API defined explicitly (e.g. via `__all__`) + +--- + +## 6. Documentation + +- [ ] `README.md` includes overview, install, usage, contributing +- [ ] All functions/classes include docstrings (Google/NumPy style) +- [ ] API reference documentation auto-generated (e.g., Sphinx, MkDocs) +- [ ] Optional: `docs/` folder for additional documentation or site generator + +--- + +## 7. Testing + +- [ ] Unit and integration tests implemented +- [ ] Test coverage > 80% verified by `coverage` +- [ ] Tests are fast and deterministic +- [ ] Continuous Integration (CI) configured to run tests + +--- + +## 8. Versioning & Releases + +- [ ] Uses semantic versioning (MAJOR.MINOR.PATCH) +- [ ] Git tags created for releases +- [ ] `CHANGELOG.md` updated with each release +- [ ] Local build verified (`poetry build`, `hatch build`, or equivalent) +- [ ] Can be published to PyPI and/or TestPyPI + +--- + +## 9. CLI or Scripts (Optional) + +- [ ] CLI entrypoint works correctly (`__main__.py` or `entry_points`) +- [ ] CLI provides helpful messages (`--help`) and handles errors gracefully + +--- + +## 10. Examples / Tutorials + +- [ ] Usage examples included in `README.md` or `examples/` +- [ ] Optional: Jupyter notebooks with demonstrations +- [ ] Optional: Colab or Binder links for live usage + +--- + +## 11. Licensing & Attribution + +- [ ] LICENSE file included (MIT, Apache 2.0, GPL, etc.) +- [ ] Author and contributors credited in `README.md` +- [ ] Optional: `CITATION.cff` file for academic citation + +--- + +> You can duplicate this file for each new package or use it as a GitHub issue template for release checklists. diff --git a/CITATION.cff b/CITATION.cff index 08db27f..78d9c81 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,7 +5,7 @@ message: "If you use this software, please cite it using the metadata below." preferred-citation: type: software title: "gen_surv" - version: "1.0.8" + version: "1.0.9" url: "https://github.com/DiogoRibeiro7/genSurvPy" authors: - family-names: Ribeiro @@ -15,5 +15,5 @@ preferred-citation: affiliation: "ESMAD - Instituto Politécnico do Porto" email: "dfr@esmad.ipp.pt" license: "MIT" - date-released: "2024-01-01" + date-released: "2025-08-02" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 07dc4b6..aba48fc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,9 +5,14 @@ Thank you for taking the time to contribute to **gen_surv**! This document provi ## Getting Started 1. Fork the repository and create your feature branch from `main`. -2. Install dependencies with `poetry install`. -3. Ensure the test suite passes with `poetry run pytest`. -4. If you add a feature or fix a bug, update `CHANGELOG.md` accordingly. +2. Install dependencies with `poetry install --with dev`. + This installs all packages needed for development, including + the optional dependency `scikit-survival`. + On Debian/Ubuntu you may need `build-essential gfortran libopenblas-dev` + to build it. +3. Run `pre-commit install` to enable style checks and execute them with `pre-commit run --all-files`. +4. Ensure the test suite passes with `poetry run pytest`. +5. If you add a feature or fix a bug, update `CHANGELOG.md` accordingly. ## Version Consistency diff --git a/LICENCE b/LICENSE similarity index 100% rename from LICENCE rename to LICENSE diff --git a/README.md b/README.md index 8e93f36..f486fa2 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,23 @@ To develop locally with all extras: ```bash git clone https://github.com/DiogoRibeiro7/genSurvPy.git cd genSurvPy -poetry install +# Install runtime and development dependencies +# (scikit-survival is optional but required for integration tests). +# On Debian/Ubuntu you may need ``build-essential gfortran libopenblas-dev`` to +# build scikit-survival. +poetry install --with dev +``` + +Integration tests that rely on scikit-survival are automatically skipped if the +package is not installed. + +## Development Setup + +Before committing changes, install the pre-commit hooks: + +```bash +pre-commit install +pre-commit run --all-files ``` ## Quick Example @@ -108,7 +124,7 @@ Open `build/html/index.html` in your browser to view the result. ## License -This project is licensed under the MIT License. See [LICENCE](LICENCE) for details. +This project is licensed under the MIT License. See [LICENSE](LICENSE) for details. ## Citation diff --git a/docs/source/api/index.md b/docs/source/api/index.md index a15a79e..54b7979 100644 --- a/docs/source/api/index.md +++ b/docs/source/api/index.md @@ -6,6 +6,12 @@ orphan: true Complete documentation for all gen_surv functions and classes. +```{note} +The `to_sksurv` helper and related tests rely on the optional +dependency `scikit-survival`. Install it with `poetry install --with dev` +or `pip install scikit-survival` to enable this functionality. +``` + ## Core Interface ```{eval-rst} diff --git a/docs/source/conf.py b/docs/source/conf.py index 2b0d9fd..f9162b6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,11 +7,11 @@ sys.path.insert(0, str(project_root / "gen_surv")) # Project information -project = 'gen_surv' -copyright = '2025, Diogo Ribeiro' -author = 'Diogo Ribeiro' -release = '1.0.8' -version = '1.0.8' +project = "gen_surv" +copyright = "2025, Diogo Ribeiro" +author = "Diogo Ribeiro" +release = "1.0.9" +version = "1.0.9" # General configuration extensions = [ @@ -41,11 +41,11 @@ # Autodoc configuration autodoc_default_options = { - 'members': True, - 'member-order': 'bysource', - 'special-members': '__init__', - 'undoc-members': True, - 'exclude-members': '__weakref__' + "members": True, + "member-order": "bysource", + "special-members": "__init__", + "undoc-members": True, + "exclude-members": "__weakref__", } # Autosummary @@ -59,9 +59,9 @@ # Intersphinx mapping intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'pandas': ('https://pandas.pydata.org/docs/', None), + "python": ("https://docs.python.org/3/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), } # Disable fetching remote inventories when network access is unavailable @@ -69,26 +69,26 @@ intersphinx_mapping = {} # HTML theme options -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_options = { - 'canonical_url': 'https://gensurvpy.readthedocs.io/', - 'analytics_id': '', - 'logo_only': False, - 'prev_next_buttons_location': 'bottom', - 'style_external_links': False, - 'style_nav_header_background': '#2980B9', - 'collapse_navigation': True, - 'sticky_navigation': True, - 'navigation_depth': 4, - 'includehidden': True, - 'titles_only': False + "canonical_url": "https://gensurvpy.readthedocs.io/", + "analytics_id": "", + "logo_only": False, + "prev_next_buttons_location": "bottom", + "style_external_links": False, + "style_nav_header_background": "#2980B9", + "collapse_navigation": True, + "sticky_navigation": True, + "navigation_depth": 4, + "includehidden": True, + "titles_only": False, } -html_static_path = ['_static'] -html_css_files = ['custom.css'] +html_static_path = ["_static"] +html_css_files = ["custom.css"] # Output file base name for HTML help builder -htmlhelp_basename = 'gensurvdoc' +htmlhelp_basename = "gensurvdoc" # Copy button configuration copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md index 6e05f8f..79e5962 100644 --- a/docs/source/examples/index.md +++ b/docs/source/examples/index.md @@ -6,6 +6,12 @@ orphan: true Real-world examples and use cases for gen_surv. +```{note} +Some examples may require optional packages such as +`scikit-survival`. Install them with `poetry install --with dev` or +`pip install scikit-survival` before running these examples. +``` + ```{toctree} :maxdepth: 2 diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 2ec1a03..7788093 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -22,6 +22,13 @@ cd genSurvPy poetry install ``` +```{note} +Some features and tests rely on optional packages such as +`scikit-survival`. Install them with `poetry install --with dev` or +`pip install scikit-survival` (additional system libraries may be +required). +``` + ## Basic Usage The main entry point is the `generate()` function: diff --git a/docs/source/index.md b/docs/source/index.md index d839e4e..8bd885f 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -26,6 +26,12 @@ df = generate(model="cphm", n=100, beta=0.5, covariate_range=2.0) ``` ``` +```{note} +The `to_sksurv` helper and related tests require the optional +dependency `scikit-survival`. Install it with `poetry install --with dev` +or `pip install scikit-survival` if you need this functionality. +``` + ## Supported Models | Model | Description | Use Case | @@ -123,13 +129,13 @@ If you use gen_surv in your research, please cite: author = {Diogo Ribeiro}, year = {2025}, url = {https://github.com/DiogoRibeiro7/genSurvPy}, - version = {1.0.8} + version = {1.0.9} } ``` ## License -MIT License - see [LICENSE](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENCE) for details. +MIT License - see [LICENSE](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENSE) for details. For foundational papers related to these models see the {doc}`bibliography`. Information on building the docs is provided in the {doc}`rtd` page. diff --git a/docs/source/usage.md b/docs/source/usage.md index b567d8a..69afc64 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -75,6 +75,12 @@ Datasets generated with **gen_surv** can be directly used with [scikit-survival](https://scikit-survival.readthedocs.io) you can convert the DataFrame using ``to_sksurv``: +```{note} +The ``to_sksurv`` helper requires the optional dependency +``scikit-survival``. Install it with `poetry install --with dev` or +``pip install scikit-survival``. +``` + ```python from gen_surv import to_sksurv diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index cbaa5cf..72ef4e1 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -17,12 +17,12 @@ from .cphm import gen_cphm from .export import export_dataset from .integration import to_sksurv -from .sklearn_adapter import GenSurvDataGenerator # Main interface from .interface import generate from .mixture import cure_fraction_estimate, gen_mixture_cure from .piecewise import gen_piecewise_exponential +from .sklearn_adapter import GenSurvDataGenerator from .tdcm import gen_tdcm from .thmm import gen_thmm diff --git a/gen_surv/integration.py b/gen_surv/integration.py index 0c7994c..720dfe2 100644 --- a/gen_surv/integration.py +++ b/gen_surv/integration.py @@ -29,8 +29,6 @@ def to_sksurv(df: pd.DataFrame, time_col: str = "time", event_col: str = "status try: from sksurv.util import Surv except ImportError as exc: # pragma: no cover - optional dependency - raise ImportError( - "scikit-survival is required for this feature." - ) from exc + raise ImportError("scikit-survival is required for this feature.") from exc return Surv.from_dataframe(event_col, time_col, df) diff --git a/gen_surv/sklearn_adapter.py b/gen_surv/sklearn_adapter.py index f8f8141..ca2718c 100644 --- a/gen_surv/sklearn_adapter.py +++ b/gen_surv/sklearn_adapter.py @@ -7,6 +7,7 @@ try: # pragma: no cover - only imported if sklearn is installed from sklearn.base import BaseEstimator except Exception: # pragma: no cover - fallback when sklearn missing + class BaseEstimator: # type: ignore """Minimal stub if scikit-learn is not installed.""" @@ -19,7 +20,9 @@ def __init__(self, model: str, return_type: str = "df", **kwargs: Any) -> None: self.return_type = return_type self.kwargs = kwargs - def fit(self, X: Optional[Any] = None, y: Optional[Any] = None) -> "GenSurvDataGenerator": + def fit( + self, X: Optional[Any] = None, y: Optional[Any] = None + ) -> "GenSurvDataGenerator": return self def transform(self, X: Optional[Any] = None) -> Any: @@ -30,5 +33,7 @@ def transform(self, X: Optional[Any] = None) -> Any: return df.to_dict(orient="list") raise ValueError("return_type must be 'df' or 'dict'") - def fit_transform(self, X: Optional[Any] = None, y: Optional[Any] = None, **fit_params: Any) -> Any: + def fit_transform( + self, X: Optional[Any] = None, y: Optional[Any] = None, **fit_params: Any + ) -> Any: return self.fit(X, y).transform(X) diff --git a/pyproject.toml b/pyproject.toml index 89354e7..8a268db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gen_surv" -version = "1.0.8" +version = "1.0.9" description = "A Python package for simulating survival data, inspired by the R package genSurv" authors = ["Diogo Ribeiro "] license = "MIT" @@ -44,6 +44,7 @@ tomli = "^2.2.1" black = "^24.1.0" isort = "^5.13.2" flake8 = "^6.1.0" +scikit-survival = "^0.24.1" [tool.poetry.group.docs.dependencies] sphinx = ">=6.0" diff --git a/scripts/check_version_match.py b/scripts/check_version_match.py index 9c3b7dd..316c229 100755 --- a/scripts/check_version_match.py +++ b/scripts/check_version_match.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 """Check that pyproject version matches the latest git tag. Optionally fix it by tagging.""" -from pathlib import Path import subprocess import sys +from pathlib import Path if sys.version_info >= (3, 11): import tomllib as tomli @@ -11,12 +11,14 @@ ROOT = Path(__file__).resolve().parents[1] + def pyproject_version() -> str: pyproject_path = ROOT / "pyproject.toml" with pyproject_path.open("rb") as f: data = tomli.load(f) return data["tool"]["poetry"]["version"] + def latest_tag() -> str: try: tag = subprocess.check_output( @@ -26,12 +28,14 @@ def latest_tag() -> str: except subprocess.CalledProcessError: return "" + def create_tag(version: str) -> None: print(f"Tagging repository with version: v{version}") subprocess.run(["git", "tag", f"v{version}"], cwd=ROOT, check=True) subprocess.run(["git", "push", "origin", f"v{version}"], cwd=ROOT, check=True) print(f"✅ Git tag v{version} created and pushed.") + def main() -> int: fix = "--fix" in sys.argv version = pyproject_version() @@ -58,5 +62,6 @@ def main() -> int: print(f"✔️ Version matches latest tag: {version}") return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/tests/test_cli_integration.py b/tests/test_cli_integration.py index 7a457a8..5012f1e 100644 --- a/tests/test_cli_integration.py +++ b/tests/test_cli_integration.py @@ -1,5 +1,5 @@ -from typer.testing import CliRunner import pandas as pd +from typer.testing import CliRunner from gen_surv.cli import app @@ -8,18 +8,21 @@ def test_dataset_cli_integration(tmp_path): """Run dataset command end-to-end and verify CSV output.""" runner = CliRunner() out_file = tmp_path / "data.csv" - result = runner.invoke(app, [ - "dataset", - "cphm", - "--n", - "3", - "--beta", - "0.5", - "--covariate-range", - "1.0", - "-o", - str(out_file), - ]) + result = runner.invoke( + app, + [ + "dataset", + "cphm", + "--n", + "3", + "--beta", + "0.5", + "--covariate-range", + "1.0", + "-o", + str(out_file), + ], + ) assert result.exit_code == 0 assert out_file.exists() df = pd.read_csv(out_file) diff --git a/tests/test_integration_sksurv.py b/tests/test_integration_sksurv.py index 7f6c2e8..84697d7 100644 --- a/tests/test_integration_sksurv.py +++ b/tests/test_integration_sksurv.py @@ -1,9 +1,12 @@ import pandas as pd +import pytest from gen_surv.integration import to_sksurv def test_to_sksurv(): + # Optional integration test; skipped when scikit-survival is not installed. + pytest.importorskip("sksurv.util") df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) arr = to_sksurv(df) assert arr.dtype.names == ("status", "time") diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py index 7dba747..61b75d3 100644 --- a/tests/test_piecewise.py +++ b/tests/test_piecewise.py @@ -94,7 +94,6 @@ def test_piecewise_invalid_censoring_model(): ) - def test_piecewise_negative_breakpoint(): with pytest.raises(ValueError): gen_piecewise_exponential( @@ -103,4 +102,3 @@ def test_piecewise_negative_breakpoint(): hazard_rates=[0.5, 1.0], seed=1, ) - From f92a584e8aa05bc9aab684e423231dab45c434e2 Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Sat, 2 Aug 2025 16:38:59 +0100 Subject: [PATCH 19/19] refactor: unify error handling (#60) --- benchmarks/README.md | 13 + benchmarks/test_tdcm_benchmark.py | 19 + benchmarks/test_validation_benchmark.py | 11 + binder/requirements.txt | 2 + docs/requirements.txt | 1 + docs/source/conf.py | 1 + docs/source/examples/cmm.md | 28 ++ docs/source/examples/index.md | 7 + docs/source/examples/tdcm.md | 30 ++ docs/source/examples/thmm.md | 28 ++ examples/notebooks/cmm.ipynb | 45 +++ examples/notebooks/tdcm.ipynb | 47 +++ examples/notebooks/thmm.ipynb | 45 +++ gen_surv/__init__.py | 19 +- gen_surv/_validation.py | 123 ++++++ gen_surv/aft.py | 44 +- gen_surv/bivariate.py | 68 ++-- gen_surv/censoring.py | 77 +++- gen_surv/cli.py | 8 +- gen_surv/cmm.py | 61 ++- gen_surv/competing_risks.py | 103 +++-- gen_surv/cphm.py | 19 +- gen_surv/export.py | 8 +- gen_surv/integration.py | 6 +- gen_surv/interface.py | 20 +- gen_surv/mixture.py | 260 +++++++----- gen_surv/piecewise.py | 50 +-- gen_surv/sklearn_adapter.py | 12 +- gen_surv/summary.py | 48 ++- gen_surv/tdcm.py | 79 ++-- gen_surv/thmm.py | 29 +- gen_surv/validate.py | 515 +++++++++++++----------- gen_surv/visualization.py | 18 +- tests/test_aft.py | 13 +- tests/test_bivariate.py | 13 +- tests/test_censoring.py | 56 ++- tests/test_competing_risks.py | 15 +- tests/test_summary_more.py | 7 +- tests/test_validate.py | 31 ++ 39 files changed, 1346 insertions(+), 633 deletions(-) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/test_tdcm_benchmark.py create mode 100644 benchmarks/test_validation_benchmark.py create mode 100644 binder/requirements.txt create mode 100644 docs/source/examples/cmm.md create mode 100644 docs/source/examples/tdcm.md create mode 100644 docs/source/examples/thmm.md create mode 100644 examples/notebooks/cmm.ipynb create mode 100644 examples/notebooks/tdcm.ipynb create mode 100644 examples/notebooks/thmm.ipynb create mode 100644 gen_surv/_validation.py diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..37e9c19 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,13 @@ +# Benchmarks + +This directory contains performance benchmarks run with `pytest-benchmark`. +Run them with: + +```bash +pytest benchmarks -q --benchmark-only +``` + +## Available benchmarks + +- validation helpers +- TDCM generation diff --git a/benchmarks/test_tdcm_benchmark.py b/benchmarks/test_tdcm_benchmark.py new file mode 100644 index 0000000..9509548 --- /dev/null +++ b/benchmarks/test_tdcm_benchmark.py @@ -0,0 +1,19 @@ +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.tdcm import gen_tdcm + + +def test_tdcm_generation_benchmark(benchmark): + benchmark( + gen_tdcm, + n=1000, + dist="weibull", + corr=0.5, + dist_par=[1, 2, 1, 2], + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + ) diff --git a/benchmarks/test_validation_benchmark.py b/benchmarks/test_validation_benchmark.py new file mode 100644 index 0000000..ad04dfc --- /dev/null +++ b/benchmarks/test_validation_benchmark.py @@ -0,0 +1,11 @@ +import numpy as np +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv._validation import ensure_positive_sequence + + +def test_positive_sequence_benchmark(benchmark): + seq = np.random.rand(10000) + 1.0 + benchmark(ensure_positive_sequence, seq, "seq") diff --git a/binder/requirements.txt b/binder/requirements.txt new file mode 100644 index 0000000..2df98f0 --- /dev/null +++ b/binder/requirements.txt @@ -0,0 +1,2 @@ +-e . +jupyterlab diff --git a/docs/requirements.txt b/docs/requirements.txt index 11d4769..a0990cf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,3 +5,4 @@ sphinx-autodoc-typehints sphinx-copybutton sphinx-design linkify-it-py>=2.0 +matplotlib diff --git a/docs/source/conf.py b/docs/source/conf.py index f9162b6..0b09aed 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,6 +21,7 @@ "sphinx.ext.intersphinx", "sphinx.ext.autosummary", "sphinx.ext.githubpages", + "sphinx.ext.plot_directive", "myst_parser", "sphinx_copybutton", "sphinx_design", diff --git a/docs/source/examples/cmm.md b/docs/source/examples/cmm.md new file mode 100644 index 0000000..45593c7 --- /dev/null +++ b/docs/source/examples/cmm.md @@ -0,0 +1,28 @@ +# Continuous-Time Multi-State Markov Model (CMM) + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/cmm.ipynb) + +Visualize transition times from the CMM generator: + +```{plot} +import numpy as np +import matplotlib.pyplot as plt +from gen_surv import generate + +np.random.seed(0) + +df = generate( + model="cmm", + n=200, + model_cens="exponential", + cens_par=2.0, + beta=[0.1, 0.2, 0.3], + covariate_range=1.0, + rate=[0.1, 1.0, 0.2, 1.0, 0.1, 1.0], +) + +plt.hist(df["stop"], bins=20, color="#4C72B0") +plt.xlabel("Time") +plt.ylabel("Frequency") +plt.title("CMM Transition Times") +``` diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md index 79e5962..14003cf 100644 --- a/docs/source/examples/index.md +++ b/docs/source/examples/index.md @@ -12,7 +12,14 @@ Some examples may require optional packages such as `pip install scikit-survival` before running these examples. ``` +Each example page includes a [Binder](https://mybinder.org/) badge so you can +launch the corresponding notebook and experiment interactively in your +browser. + ```{toctree} :maxdepth: 2 +tdcm +cmm +thmm ``` diff --git a/docs/source/examples/tdcm.md b/docs/source/examples/tdcm.md new file mode 100644 index 0000000..01e3550 --- /dev/null +++ b/docs/source/examples/tdcm.md @@ -0,0 +1,30 @@ +# Time-Dependent Covariate Model (TDCM) + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/tdcm.ipynb) + +A basic visualization of event times produced by the TDCM generator: + +```{plot} +import numpy as np +import matplotlib.pyplot as plt +from gen_surv import generate + +np.random.seed(0) + +df = generate( + model="tdcm", + n=200, + dist="weibull", + corr=0.5, + dist_par=[1, 2, 1, 2], + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, +) + +plt.hist(df["stop"], bins=20, color="#4C72B0") +plt.xlabel("Time") +plt.ylabel("Frequency") +plt.title("TDCM Event Times") +``` diff --git a/docs/source/examples/thmm.md b/docs/source/examples/thmm.md new file mode 100644 index 0000000..c5d17d8 --- /dev/null +++ b/docs/source/examples/thmm.md @@ -0,0 +1,28 @@ +# Time-Homogeneous Hidden Markov Model (THMM) + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/thmm.ipynb) + +An example of event times generated by the THMM: + +```{plot} +import numpy as np +import matplotlib.pyplot as plt +from gen_surv import generate + +np.random.seed(0) + +df = generate( + model="thmm", + n=200, + model_cens="exponential", + cens_par=3.0, + beta=[0.1, 0.2, 0.3], + covariate_range=1.0, + rate=[0.2, 0.1, 0.3], +) + +plt.hist(df["time"], bins=20, color="#4C72B0") +plt.xlabel("Time") +plt.ylabel("Frequency") +plt.title("THMM Event Times") +``` diff --git a/examples/notebooks/cmm.ipynb b/examples/notebooks/cmm.ipynb new file mode 100644 index 0000000..303b3ff --- /dev/null +++ b/examples/notebooks/cmm.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from gen_surv import generate\n", + "\n", + "np.random.seed(0)\n", + "df = generate(\n", + " model=\"cmm\",\n", + " n=200,\n", + " model_cens=\"exponential\",\n", + " cens_par=2.0,\n", + " beta=[0.1, 0.2, 0.3],\n", + " covariate_range=1.0,\n", + " rate=[0.1, 1.0, 0.2, 1.0, 0.1, 1.0],\n", + ")\n", + "df.head()\n", + "plt.hist(df['stop'], bins=20, color='#4C72B0')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Frequency')\n", + "plt.title('CMM Transition Times')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/tdcm.ipynb b/examples/notebooks/tdcm.ipynb new file mode 100644 index 0000000..0f9a91f --- /dev/null +++ b/examples/notebooks/tdcm.ipynb @@ -0,0 +1,47 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from gen_surv import generate\n", + "\n", + "np.random.seed(0)\n", + "df = generate(\n", + " model=\"tdcm\",\n", + " n=200,\n", + " dist=\"weibull\",\n", + " corr=0.5,\n", + " dist_par=[1, 2, 1, 2],\n", + " model_cens=\"uniform\",\n", + " cens_par=1.0,\n", + " beta=[0.1, 0.2, 0.3],\n", + " lam=1.0,\n", + ")\n", + "df.head()\n", + "plt.hist(df['stop'], bins=20, color='#4C72B0')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Frequency')\n", + "plt.title('TDCM Event Times')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/thmm.ipynb b/examples/notebooks/thmm.ipynb new file mode 100644 index 0000000..c46463c --- /dev/null +++ b/examples/notebooks/thmm.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from gen_surv import generate\n", + "\n", + "np.random.seed(0)\n", + "df = generate(\n", + " model=\"thmm\",\n", + " n=200,\n", + " model_cens=\"exponential\",\n", + " cens_par=3.0,\n", + " beta=[0.1, 0.2, 0.3],\n", + " covariate_range=1.0,\n", + " rate=[0.2, 0.1, 0.3],\n", + ")\n", + "df.head()\n", + "plt.hist(df['time'], bins=20, color='#4C72B0')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Frequency')\n", + "plt.title('THMM Event Times')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 72ef4e1..55e08e8 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -9,7 +9,17 @@ # Helper functions from .bivariate import sample_bivariate_distribution -from .censoring import rexpocens, runifcens +from .censoring import ( + CensoringModel, + rexpocens, + runifcens, + rweibcens, + rlognormcens, + rgammacens, + WeibullCensoring, + LogNormalCensoring, + GammaCensoring, +) from .cmm import gen_cmm from .competing_risks import gen_competing_risks, gen_competing_risks_weibull @@ -65,6 +75,13 @@ "sample_bivariate_distribution", "runifcens", "rexpocens", + "rweibcens", + "rlognormcens", + "rgammacens", + "WeibullCensoring", + "LogNormalCensoring", + "GammaCensoring", + "CensoringModel", "export_dataset", "to_sksurv", "GenSurvDataGenerator", diff --git a/gen_surv/_validation.py b/gen_surv/_validation.py new file mode 100644 index 0000000..d3d924d --- /dev/null +++ b/gen_surv/_validation.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Iterable + +import numpy as np +from numpy.typing import NDArray + + +class ValidationError(ValueError): + """Base class for input validation errors.""" + + +class PositiveIntegerError(ValidationError): + """Raised when a value expected to be a positive integer is invalid.""" + + def __init__(self, name: str, value: Any) -> None: + super().__init__(f"Argument '{name}' must be a positive integer; got {value!r}") + + +class PositiveValueError(ValidationError): + """Raised when a value expected to be positive is invalid.""" + + def __init__(self, name: str, value: Any) -> None: + super().__init__(f"Argument '{name}' must be greater than 0; got {value!r}") + + +class ChoiceError(ValidationError): + """Raised when a value is not among an allowed set of choices.""" + + def __init__(self, name: str, value: Any, choices: Iterable[str]) -> None: + choices_str = "', '".join(sorted(choices)) + super().__init__( + f"Argument '{name}' must be one of '{choices_str}'; got {value!r}" + ) + + +class LengthError(ValidationError): + """Raised when a sequence does not have the expected length.""" + + def __init__(self, name: str, actual: int, expected: int) -> None: + super().__init__( + f"Argument '{name}' must be a sequence of length {expected}; got length {actual}" + ) + + +class NumericSequenceError(ValidationError): + """Raised when a sequence contains non-numeric elements.""" + + def __init__(self, name: str, seq: Sequence[Any]) -> None: + super().__init__(f"All elements in '{name}' must be numeric; got {seq!r}") + + +class PositiveSequenceError(ValidationError): + """Raised when a sequence contains non-positive elements.""" + + def __init__(self, name: str, seq: Sequence[Any]) -> None: + super().__init__(f"All elements in '{name}' must be greater than 0; got {seq!r}") + + +class ListOfListsError(ValidationError): + """Raised when a value is not a list of lists.""" + + def __init__(self, name: str, value: Any) -> None: + super().__init__(f"Argument '{name}' must be a list of lists; got {value!r}") + + +class ParameterError(ValidationError): + """Raised when a parameter falls outside its allowed range.""" + + def __init__(self, name: str, value: Any, constraint: str) -> None: + super().__init__( + f"Invalid value for '{name}': {value!r}. {constraint}" + ) + + +_ALLOWED_CENSORING = {"uniform", "exponential"} + + +def ensure_positive_int(value: int, name: str) -> None: + """Ensure ``value`` is a positive integer.""" + if not isinstance(value, int) or value <= 0: + raise PositiveIntegerError(name, value) + + +def ensure_positive(value: float | int, name: str) -> None: + """Ensure ``value`` is a positive number.""" + if not isinstance(value, (int, float)) or value <= 0: + raise PositiveValueError(name, value) + + +def ensure_in_choices(value: str, name: str, choices: Iterable[str]) -> None: + """Ensure ``value`` is one of the given ``choices``.""" + if value not in choices: + raise ChoiceError(name, value, choices) + + +def ensure_sequence_length(seq: Sequence[Any], length: int, name: str) -> None: + """Ensure ``seq`` has the specified ``length``.""" + if len(seq) != length: + raise LengthError(name, len(seq), length) + + +def _to_float_array(seq: Sequence[Any], name: str) -> NDArray[np.float64]: + """Convert ``seq`` to a NumPy float64 array or raise an error.""" + try: + return np.asarray(seq, dtype=float) + except (TypeError, ValueError) as exc: + raise NumericSequenceError(name, seq) from exc + +def ensure_numeric_sequence(seq: Sequence[Any], name: str) -> None: + """Ensure all elements of ``seq`` are numeric.""" + _to_float_array(seq, name) + +def ensure_positive_sequence(seq: Sequence[float], name: str) -> None: + """Ensure all elements of ``seq`` are positive.""" + arr = _to_float_array(seq, name) + if np.any(arr <= 0): + raise PositiveSequenceError(name, seq) + +def ensure_censoring_model(model_cens: str) -> None: + """Validate that the censoring model is supported.""" + ensure_in_choices(model_cens, "model_cens", _ALLOWED_CENSORING) diff --git a/gen_surv/aft.py b/gen_surv/aft.py index 0c958ad..879e892 100644 --- a/gen_surv/aft.py +++ b/gen_surv/aft.py @@ -7,6 +7,9 @@ import numpy as np import pandas as pd +from ._validation import ensure_censoring_model, ensure_positive +from .censoring import rexpocens, runifcens + def gen_aft_log_normal( n: int, @@ -48,12 +51,9 @@ def gen_aft_log_normal( log_T = X @ np.array(beta) + epsilon T = np.exp(log_T) - if model_cens == "uniform": - C = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - C = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") + ensure_censoring_model(model_cens) + rfunc = runifcens if model_cens == "uniform" else rexpocens + C = rfunc(n, cens_par) observed_time = np.minimum(T, C) status = (T <= C).astype(int) @@ -106,11 +106,8 @@ def gen_aft_weibull( if seed is not None: np.random.seed(seed) - if shape <= 0: - raise ValueError("shape parameter must be positive") - - if scale <= 0: - raise ValueError("scale parameter must be positive") + ensure_positive(shape, "shape") + ensure_positive(scale, "scale") p = len(beta) X = np.random.normal(size=(n, p)) @@ -123,12 +120,9 @@ def gen_aft_weibull( T = scale * (-np.log(U) * np.exp(-eta)) ** (1 / shape) # Generate censoring times - if model_cens == "uniform": - C = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - C = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") + ensure_censoring_model(model_cens) + rfunc = runifcens if model_cens == "uniform" else rexpocens + C = rfunc(n, cens_par) # Observed time is the minimum of event time and censoring time observed_time = np.minimum(T, C) @@ -184,11 +178,8 @@ def gen_aft_log_logistic( if seed is not None: np.random.seed(seed) - if shape <= 0: - raise ValueError("shape parameter must be positive") - - if scale <= 0: - raise ValueError("scale parameter must be positive") + ensure_positive(shape, "shape") + ensure_positive(scale, "scale") p = len(beta) X = np.random.normal(size=(n, p)) @@ -210,12 +201,9 @@ def gen_aft_log_logistic( T = scale * (U / (1 - U)) ** (1 / shape) * np.exp(-eta / shape) # Generate censoring times - if model_cens == "uniform": - C = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - C = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") + ensure_censoring_model(model_cens) + rfunc = runifcens if model_cens == "uniform" else rexpocens + C = rfunc(n, cens_par) # Observed time is the minimum of event time and censoring time observed_time = np.minimum(T, C) diff --git a/gen_surv/bivariate.py b/gen_surv/bivariate.py index 7bde484..7cef27b 100644 --- a/gen_surv/bivariate.py +++ b/gen_surv/bivariate.py @@ -1,45 +1,57 @@ import numpy as np - - -def sample_bivariate_distribution(n, dist, corr, dist_par): +from numpy.typing import NDArray +from typing import Sequence + +from .validate import validate_dg_biv_inputs + + +_CHI2_SCALE = 0.5 +_CLIP_EPS = 1e-10 + + +def sample_bivariate_distribution( + n: int, dist: str, corr: float, dist_par: Sequence[float] +) -> NDArray[np.float64]: + """Draw correlated samples from Weibull or exponential marginals. + + Parameters + ---------- + n : int + Number of samples to generate. + dist : {"weibull", "exponential"} + Type of marginal distributions. + corr : float + Correlation coefficient. + dist_par : Sequence[float] + Distribution parameters ``[a1, b1, a2, b2]`` for the Weibull case or + ``[lambda1, lambda2]`` for the exponential case. + + Returns + ------- + NDArray[np.float64] + Array of shape ``(n, 2)`` with the sampled pairs. + + Raises + ------ + ValidationError + If ``dist`` is unsupported or ``dist_par`` has an invalid length. """ - Generate samples from a bivariate distribution with specified correlation. - Parameters: - - n (int): Number of samples - - dist (str): 'weibull' or 'exponential' - - corr (float): Correlation coefficient between [-1, 1] - - dist_par (list): Parameters for the marginals - - Returns: - - np.ndarray of shape (n, 2) - """ - if dist not in {"weibull", "exponential"}: - raise ValueError( - "Only 'weibull' and 'exponential' distributions are supported." - ) + validate_dg_biv_inputs(n, dist, corr, dist_par) # Step 1: Generate correlated standard normals using Cholesky mean = [0, 0] cov = [[1, corr], [corr, 1]] z = np.random.multivariate_normal(mean, cov, size=n) - u = 1 - np.exp(-0.5 * z**2) # transform normals to uniform via chi-squared approx - u = np.clip(u, 1e-10, 1 - 1e-10) # avoid infs in tails + u = 1 - np.exp(-_CHI2_SCALE * z**2) # transform normals to uniform via chi-squared approx + u = np.clip(u, _CLIP_EPS, 1 - _CLIP_EPS) # avoid infs in tails # Step 2: Transform to marginals if dist == "exponential": - if len(dist_par) != 2: - raise ValueError( - "Exponential distribution requires 2 positive rate parameters." - ) x1 = -np.log(1 - u[:, 0]) / dist_par[0] x2 = -np.log(1 - u[:, 1]) / dist_par[1] - elif dist == "weibull": - if len(dist_par) != 4: - raise ValueError( - "Weibull distribution requires 4 positive parameters [a1, b1, a2, b2]." - ) + else: # dist == "weibull" a1, b1, a2, b2 = dist_par x1 = (-np.log(1 - u[:, 0]) / a1) ** (1 / b1) x2 = (-np.log(1 - u[:, 1]) / a2) ** (1 / b2) diff --git a/gen_surv/censoring.py b/gen_surv/censoring.py index 6d83067..efc17c8 100644 --- a/gen_surv/censoring.py +++ b/gen_surv/censoring.py @@ -1,7 +1,25 @@ import numpy as np +from numpy.typing import NDArray +from typing import Protocol -def runifcens(size: int, cens_par: float) -> np.ndarray: +class CensoringFunc(Protocol): + """Protocol for censoring time generators.""" + + def __call__(self, size: int, cens_par: float) -> NDArray[np.float64]: + """Generate ``size`` censoring times given ``cens_par``.""" + ... + + +class CensoringModel(Protocol): + """Protocol for class-based censoring generators.""" + + def __call__(self, size: int) -> NDArray[np.float64]: + """Generate ``size`` censoring times.""" + ... + + +def runifcens(size: int, cens_par: float) -> NDArray[np.float64]: """ Generate uniform censoring times. @@ -10,12 +28,12 @@ def runifcens(size: int, cens_par: float) -> np.ndarray: - cens_par (float): Upper bound for uniform distribution. Returns: - - np.ndarray of censoring times. + - NDArray of censoring times. """ return np.random.uniform(0, cens_par, size) -def rexpocens(size: int, cens_par: float) -> np.ndarray: +def rexpocens(size: int, cens_par: float) -> NDArray[np.float64]: """ Generate exponential censoring times. @@ -24,6 +42,57 @@ def rexpocens(size: int, cens_par: float) -> np.ndarray: - cens_par (float): Mean of exponential distribution. Returns: - - np.ndarray of censoring times. + - NDArray of censoring times. """ return np.random.exponential(scale=cens_par, size=size) + + +def rweibcens(size: int, scale: float, shape: float) -> NDArray[np.float64]: + """Generate Weibull-distributed censoring times.""" + return np.random.weibull(shape, size) * scale + + +def rlognormcens(size: int, mean: float, sigma: float) -> NDArray[np.float64]: + """Generate log-normal-distributed censoring times.""" + return np.random.lognormal(mean, sigma, size) + + +def rgammacens(size: int, shape: float, scale: float) -> NDArray[np.float64]: + """Generate Gamma-distributed censoring times.""" + return np.random.gamma(shape, scale, size) + + +class WeibullCensoring: + """Class-based generator for Weibull censoring times.""" + + def __init__(self, scale: float, shape: float) -> None: + self.scale = scale + self.shape = shape + + def __call__(self, size: int) -> NDArray[np.float64]: + """Generate ``size`` censoring times from a Weibull distribution.""" + return np.random.weibull(self.shape, size) * self.scale + + +class LogNormalCensoring: + """Class-based generator for log-normal censoring times.""" + + def __init__(self, mean: float, sigma: float) -> None: + self.mean = mean + self.sigma = sigma + + def __call__(self, size: int) -> NDArray[np.float64]: + """Generate ``size`` censoring times from a log-normal distribution.""" + return np.random.lognormal(self.mean, self.sigma, size) + + +class GammaCensoring: + """Class-based generator for Gamma censoring times.""" + + def __init__(self, shape: float, scale: float) -> None: + self.shape = shape + self.scale = scale + + def __call__(self, size: int) -> NDArray[np.float64]: + """Generate ``size`` censoring times from a Gamma distribution.""" + return np.random.gamma(self.shape, self.scale, size) diff --git a/gen_surv/cli.py b/gen_surv/cli.py index cd06d87..93262db 100644 --- a/gen_surv/cli.py +++ b/gen_surv/cli.py @@ -5,7 +5,7 @@ using the gen_surv package. """ -from typing import List, Optional, Tuple +from typing import List, Optional, TypeVar, cast import typer @@ -85,8 +85,10 @@ def dataset( # Helper to unwrap Typer Option defaults when function is called directly from typer.models import OptionInfo - def _val(v): - return v if not isinstance(v, OptionInfo) else v.default + T = TypeVar("T") + + def _val(v: T | OptionInfo) -> T: + return v if not isinstance(v, OptionInfo) else cast(T, v.default) # Prepare arguments based on the selected model model_str = _val(model) diff --git a/gen_surv/cmm.py b/gen_surv/cmm.py index 983a351..7c9828f 100644 --- a/gen_surv/cmm.py +++ b/gen_surv/cmm.py @@ -1,11 +1,20 @@ import numpy as np import pandas as pd +from typing import Sequence, TypedDict -from gen_surv.censoring import rexpocens, runifcens +from gen_surv.censoring import CensoringFunc, rexpocens, runifcens from gen_surv.validate import validate_gen_cmm_inputs -def generate_event_times(z1: float, beta: list, rate: list) -> dict: +class EventTimes(TypedDict): + t12: float + t13: float + t23: float + + +def generate_event_times( + z1: float, beta: Sequence[float], rate: Sequence[float] +) -> EventTimes: """ Generate event times for a continuous-time multi-state Markov model. @@ -29,7 +38,14 @@ def generate_event_times(z1: float, beta: list, rate: list) -> dict: return {"t12": t12, "t13": t13, "t23": t23} -def gen_cmm(n, model_cens, cens_par, beta, covariate_range, rate): +def gen_cmm( + n: int, + model_cens: str, + cens_par: float, + beta: Sequence[float], + covariate_range: float, + rate: Sequence[float], +) -> pd.DataFrame: """ Generate survival data using a continuous-time Markov model (CMM). @@ -46,28 +62,29 @@ def gen_cmm(n, model_cens, cens_par, beta, covariate_range, rate): """ validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) - rfunc = runifcens if model_cens == "uniform" else rexpocens - rows = [] + rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens + + z1 = np.random.uniform(0, covariate_range, size=n) + c = rfunc(n, cens_par) - for k in range(n): - z1 = np.random.uniform(0, covariate_range) - c = rfunc(1, cens_par)[0] - events = generate_event_times(z1, beta, rate) + u = np.random.uniform(size=(3, n)) + t12 = (-np.log(1 - u[0]) / (rate[0] * np.exp(beta[0] * z1))) ** (1 / rate[1]) + t13 = (-np.log(1 - u[1]) / (rate[2] * np.exp(beta[1] * z1))) ** (1 / rate[3]) - t12, t13, t23 = events["t12"], events["t13"], events["t23"] - min_event_time = min(t12, t13, c) + first_event = np.minimum(t12, t13) + censored = first_event >= c - if min_event_time < c: - if t12 <= t13: - transition = 1 # 1 -> 2 - rows.append([k + 1, 0, t12, 1, z1, transition]) - else: - transition = 2 # 1 -> 3 - rows.append([k + 1, 0, t13, 1, z1, transition]) - else: - # Censored before any event - rows.append([k + 1, 0, c, 0, z1, np.nan]) + status = (~censored).astype(int) + transition = np.where(censored, np.nan, np.where(t12 <= t13, 1, 2)) + stop = np.where(censored, c, first_event) return pd.DataFrame( - rows, columns=["id", "start", "stop", "status", "X0", "transition"] + { + "id": np.arange(1, n + 1), + "start": np.zeros(n), + "stop": stop, + "status": status, + "X0": z1, + "transition": transition, + } ) diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py index d49a308..3c73276 100644 --- a/gen_surv/competing_risks.py +++ b/gen_surv/competing_risks.py @@ -10,6 +10,15 @@ import numpy as np import pandas as pd +from ._validation import ( + ensure_censoring_model, + ensure_in_choices, + ensure_positive_sequence, + ensure_sequence_length, + ParameterError, +) +from .censoring import rexpocens, runifcens + if TYPE_CHECKING: # pragma: no cover - used only for type hints from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -92,16 +101,17 @@ def gen_competing_risks( if baseline_hazards is None: baseline_hazards = np.array([0.5 / (i + 1) for i in range(n_risks)]) else: - baseline_hazards = np.array(baseline_hazards) - if len(baseline_hazards) != n_risks: - raise ValueError( - f"Expected {n_risks} baseline hazards, got {len(baseline_hazards)}" - ) + baseline_hazards = np.asarray(baseline_hazards, dtype=float) + ensure_sequence_length(baseline_hazards, n_risks, "baseline_hazards") + ensure_positive_sequence(baseline_hazards, "baseline_hazards") # Set default number of covariates and their parameters n_covariates = 2 # Default number of covariates # Set default covariate parameters if not provided + ensure_in_choices( + covariate_dist, "covariate_dist", {"normal", "uniform", "binary"} + ) if covariate_params is None: if covariate_dist == "normal": covariate_params = {"mean": 0.0, "std": 1.0} @@ -109,18 +119,13 @@ def gen_competing_risks( covariate_params = {"low": 0.0, "high": 1.0} elif covariate_dist == "binary": covariate_params = {"p": 0.5} - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") # Set default betas if not provided if betas is None: betas = np.random.normal(0, 0.5, size=(n_risks, n_covariates)) else: - betas = np.array(betas) - if betas.shape[0] != n_risks: - raise ValueError( - f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}" - ) + betas = np.asarray(betas, dtype=float) + ensure_sequence_length(betas, n_risks, "betas") n_covariates = betas.shape[1] # Generate covariates @@ -140,8 +145,12 @@ def gen_competing_risks( X = np.random.binomial( 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + else: # pragma: no cover - validated above + raise ParameterError( + "covariate_dist", + covariate_dist, + "must be one of {'normal', 'uniform', 'binary'}", + ) # Calculate linear predictors for each risk linear_predictors = np.zeros((n, n_risks)) @@ -160,12 +169,9 @@ def gen_competing_risks( event_times[:, j] = np.random.exponential(1 / hazard_rates[:, j]) # Generate censoring times - if model_cens == "uniform": - cens_times = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - cens_times = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") + ensure_censoring_model(model_cens) + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(n, cens_par) # Find the minimum time for each subject (first event or censoring) min_event_times = np.min(event_times, axis=1) @@ -290,25 +296,22 @@ def gen_competing_risks_weibull( if shape_params is None: shape_params = np.array([1.2 if i % 2 == 0 else 0.8 for i in range(n_risks)]) else: - shape_params = np.array(shape_params) - if len(shape_params) != n_risks: - raise ValueError( - f"Expected {n_risks} shape parameters, got {len(shape_params)}" - ) + shape_params = np.asarray(shape_params, dtype=float) + ensure_sequence_length(shape_params, n_risks, "shape_params") if scale_params is None: scale_params = np.array([2.0 + i for i in range(n_risks)]) else: - scale_params = np.array(scale_params) - if len(scale_params) != n_risks: - raise ValueError( - f"Expected {n_risks} scale parameters, got {len(scale_params)}" - ) + scale_params = np.asarray(scale_params, dtype=float) + ensure_sequence_length(scale_params, n_risks, "scale_params") # Set default number of covariates and their parameters n_covariates = 2 # Default number of covariates # Set default covariate parameters if not provided + ensure_in_choices( + covariate_dist, "covariate_dist", {"normal", "uniform", "binary"} + ) if covariate_params is None: if covariate_dist == "normal": covariate_params = {"mean": 0.0, "std": 1.0} @@ -316,18 +319,13 @@ def gen_competing_risks_weibull( covariate_params = {"low": 0.0, "high": 1.0} elif covariate_dist == "binary": covariate_params = {"p": 0.5} - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") # Set default betas if not provided if betas is None: betas = np.random.normal(0, 0.5, size=(n_risks, n_covariates)) else: - betas = np.array(betas) - if betas.shape[0] != n_risks: - raise ValueError( - f"Expected {n_risks} sets of coefficients, got {betas.shape[0]}" - ) + betas = np.asarray(betas, dtype=float) + ensure_sequence_length(betas, n_risks, "betas") n_covariates = betas.shape[1] # Generate covariates @@ -347,8 +345,12 @@ def gen_competing_risks_weibull( X = np.random.binomial( 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + else: # pragma: no cover - validated above + raise ParameterError( + "covariate_dist", + covariate_dist, + "must be one of {'normal', 'uniform', 'binary'}", + ) # Calculate linear predictors for each risk linear_predictors = np.zeros((n, n_risks)) @@ -370,12 +372,9 @@ def gen_competing_risks_weibull( event_times[:, j] = adjusted_scale * (-np.log(1 - u)) ** (1 / shape_params[j]) # Generate censoring times - if model_cens == "uniform": - cens_times = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - cens_times = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") + ensure_censoring_model(model_cens) + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(n, cens_par) # Find the minimum time for each subject (first event or censoring) min_event_times = np.min(event_times, axis=1) @@ -447,8 +446,8 @@ def cause_specific_cumulative_incidence( # Validate the cause value unique_causes = set(data[status_col].unique()) - {0} # Exclude censoring if cause not in unique_causes: - raise ValueError( - f"Cause {cause} not found in the data. Available causes: {unique_causes}" + raise ParameterError( + "cause", cause, f"not found in the data. Available causes: {unique_causes}" ) # Sort data by time @@ -499,8 +498,8 @@ def competing_risks_summary( data: pd.DataFrame, time_col: str = "time", status_col: str = "status", - covariate_cols: Optional[List[str]] = None, -) -> Dict[str, Any]: + covariate_cols: list[str] | None = None, +) -> dict[str, Any]: """ Provide a summary of a competing risks dataset. @@ -612,12 +611,12 @@ def competing_risks_summary( def plot_cause_specific_hazards( data: pd.DataFrame, - time_points: Optional[np.ndarray] = None, + time_points: np.ndarray | None = None, time_col: str = "time", status_col: str = "status", bandwidth: float = 0.5, - figsize: Tuple[float, float] = (10, 6), -) -> Tuple["Figure", "Axes"]: + figsize: tuple[float, float] = (10, 6), +) -> tuple["Figure", "Axes"]: """ Plot cause-specific hazard functions. diff --git a/gen_surv/cphm.py b/gen_surv/cphm.py index ea2d492..4d03185 100644 --- a/gen_surv/cphm.py +++ b/gen_surv/cphm.py @@ -5,23 +5,24 @@ Cox Proportional Hazards Model with various censoring mechanisms. """ -from typing import Callable, Literal, Optional +from typing import Literal import numpy as np import pandas as pd +from numpy.typing import NDArray -from gen_surv.censoring import rexpocens, runifcens +from gen_surv.censoring import CensoringFunc, rexpocens, runifcens from gen_surv.validate import validate_gen_cphm_inputs def generate_cphm_data( n: int, - rfunc: Callable[[int, float], np.ndarray], + rfunc: CensoringFunc, cens_par: float, beta: float, covariate_range: float, - seed: Optional[int] = None, -) -> np.ndarray: + seed: int | None = None, +) -> NDArray[np.float64]: """ Generate data from a Cox Proportional Hazards Model (CPHM). @@ -42,13 +43,13 @@ def generate_cphm_data( Returns ------- - np.ndarray - Array with shape (n, 3): [time, status, X0] + NDArray[np.float64] + Array with shape ``(n, 3)``: ``[time, status, X0]`` """ if seed is not None: np.random.seed(seed) - data = np.zeros((n, 3)) + data: NDArray[np.float64] = np.zeros((n, 3), dtype=float) for k in range(n): z = np.random.uniform(0, covariate_range) @@ -69,7 +70,7 @@ def gen_cphm( cens_par: float, beta: float, covariate_range: float, - seed: Optional[int] = None, + seed: int | None = None, ) -> pd.DataFrame: """ Generate survival data following a Cox Proportional Hazards Model. diff --git a/gen_surv/export.py b/gen_surv/export.py index c3751bb..56cfca4 100644 --- a/gen_surv/export.py +++ b/gen_surv/export.py @@ -12,6 +12,8 @@ import pandas as pd import pyreadr +from ._validation import ensure_in_choices + def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> None: """Save a DataFrame to disk. @@ -28,12 +30,14 @@ def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> No Raises ------ - ValueError + ChoiceError If the format is not one of the supported types. """ if fmt is None: fmt = os.path.splitext(path)[1].lstrip(".").lower() + ensure_in_choices(fmt, "fmt", {"csv", "json", "feather", "ft", "rds"}) + if fmt == "csv": df.to_csv(path, index=False) elif fmt == "json": @@ -42,5 +46,3 @@ def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> No df.reset_index(drop=True).to_feather(path) elif fmt == "rds": pyreadr.write_rds(path, df.reset_index(drop=True)) - else: - raise ValueError(f"Unsupported export format: {fmt}") diff --git a/gen_surv/integration.py b/gen_surv/integration.py index 720dfe2..6967086 100644 --- a/gen_surv/integration.py +++ b/gen_surv/integration.py @@ -1,9 +1,13 @@ from __future__ import annotations +import numpy as np import pandas as pd +from numpy.typing import NDArray -def to_sksurv(df: pd.DataFrame, time_col: str = "time", event_col: str = "status"): +def to_sksurv( + df: pd.DataFrame, time_col: str = "time", event_col: str = "status" +) -> NDArray[np.void]: """Convert a DataFrame to a scikit-survival structured array. Parameters diff --git a/gen_surv/interface.py b/gen_surv/interface.py index ca87292..0c3cd79 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -6,7 +6,7 @@ >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) """ -from typing import Any, Literal +from typing import Any, Literal, Protocol, Dict import pandas as pd @@ -18,6 +18,7 @@ from gen_surv.piecewise import gen_piecewise_exponential from gen_surv.tdcm import gen_tdcm from gen_surv.thmm import gen_thmm +from ._validation import ensure_in_choices # Type definitions for model names ModelType = Literal[ @@ -34,8 +35,13 @@ "piecewise_exponential", ] +# Interface for generator callables +class DataGenerator(Protocol): + def __call__(self, **kwargs: Any) -> pd.DataFrame: ... + + # Map model names to their generator functions -_model_map = { +_model_map: Dict[ModelType, DataGenerator] = { "cphm": gen_cphm, "cmm": gen_cmm, "tdcm": gen_tdcm, @@ -50,7 +56,7 @@ } -def generate(model: str, **kwargs: Any) -> pd.DataFrame: +def generate(model: ModelType, **kwargs: Any) -> pd.DataFrame: """Generate survival data from a specific model. Args: @@ -78,12 +84,8 @@ def generate(model: str, **kwargs: Any) -> pd.DataFrame: All models include time/duration and status columns. Raises: - ValueError: If an unknown model name is provided. + ChoiceError: If an unknown model name is provided. """ model_lower = model.lower() - if model_lower not in _model_map: - valid_models = list(_model_map.keys()) - raise ValueError(f"Unknown model '{model}'. Choose from {valid_models}.") - - # Call the appropriate generator function with the provided kwargs + ensure_in_choices(model_lower, "model", _model_map.keys()) return _model_map[model_lower](**kwargs) diff --git a/gen_surv/mixture.py b/gen_surv/mixture.py index ab6770a..913912e 100644 --- a/gen_surv/mixture.py +++ b/gen_surv/mixture.py @@ -5,25 +5,148 @@ i.e., a proportion of subjects who are immune to the event of interest. """ -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Literal import numpy as np import pandas as pd +from numpy.typing import NDArray + + +_TAIL_FRACTION = 0.1 +_SMOOTH_MIN_TAIL = 3 + +from ._validation import ( + ensure_censoring_model, + ensure_in_choices, + ensure_positive, + LengthError, + ParameterError, +) +from .censoring import rexpocens, runifcens + + +def _set_covariate_params( + covariate_dist: str, + covariate_params: dict[str, float | tuple[float, float]] | None, +) -> dict[str, float | tuple[float, float]]: + if covariate_params is not None: + return covariate_params + if covariate_dist == "normal": + return {"mean": 0.0, "std": 1.0} + if covariate_dist == "uniform": + return {"low": 0.0, "high": 1.0} + if covariate_dist == "binary": + return {"p": 0.5} + raise ParameterError( + "covariate_dist", covariate_dist, "must be one of {'normal','uniform','binary'}" + ) + + +def _prepare_betas( + betas_survival: list[float] | None, + betas_cure: list[float] | None, + n_covariates: int, +) -> tuple[NDArray[np.float64], NDArray[np.float64], int]: + if betas_survival is None: + betas_survival_arr = np.random.normal(0, 0.5, size=n_covariates) + else: + betas_survival_arr = np.asarray(betas_survival, dtype=float) + n_covariates = len(betas_survival_arr) + + if betas_cure is None: + betas_cure_arr = np.random.normal(0, 0.5, size=n_covariates) + else: + betas_cure_arr = np.asarray(betas_cure, dtype=float) + if len(betas_cure_arr) != n_covariates: + raise LengthError("betas_cure", len(betas_cure_arr), n_covariates) + + return betas_survival_arr, betas_cure_arr, n_covariates + + +def _generate_covariates( + n: int, + n_covariates: int, + covariate_dist: str, + covariate_params: dict[str, float | tuple[float, float]], +) -> NDArray[np.float64]: + if covariate_dist == "normal": + return np.random.normal( + covariate_params.get("mean", 0.0), + covariate_params.get("std", 1.0), + size=(n, n_covariates), + ) + if covariate_dist == "uniform": + return np.random.uniform( + covariate_params.get("low", 0.0), + covariate_params.get("high", 1.0), + size=(n, n_covariates), + ) + if covariate_dist == "binary": + return np.random.binomial( + 1, covariate_params.get("p", 0.5), size=(n, n_covariates) + ).astype(float) + raise ParameterError( + "covariate_dist", covariate_dist, "must be one of {'normal','uniform','binary'}" + ) + + +def _cure_status( + lp_cure: NDArray[np.float64], cure_fraction: float +) -> NDArray[np.int64]: + cure_probs = 1 / ( + 1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure)) + ) + return np.random.binomial(1, cure_probs).astype(np.int64) + + +def _survival_times( + cured: NDArray[np.int64], + lp_survival: NDArray[np.float64], + baseline_hazard: float, + max_time: float | None, +) -> NDArray[np.float64]: + n = cured.size + times = np.zeros(n, dtype=float) + non_cured = cured == 0 + adjusted_hazard = baseline_hazard * np.exp(lp_survival[non_cured]) + times[non_cured] = np.random.exponential(scale=1 / adjusted_hazard) + if max_time is not None: + times[~non_cured] = max_time * 100 + else: + times[~non_cured] = np.inf + return times + + +def _apply_censoring( + survival_times: NDArray[np.float64], + model_cens: str, + cens_par: float, + max_time: float | None, +) -> tuple[NDArray[np.float64], NDArray[np.int64]]: + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(len(survival_times), cens_par) + observed = np.minimum(survival_times, cens_times) + status = (survival_times <= cens_times).astype(int) + if max_time is not None: + over_max = observed > max_time + observed[over_max] = max_time + status[over_max] = 0 + return observed, status def gen_mixture_cure( n: int, cure_fraction: float, baseline_hazard: float = 0.5, - betas_survival: Optional[List[float]] = None, - betas_cure: Optional[List[float]] = None, + betas_survival: list[float] | None = None, + betas_cure: list[float] | None = None, n_covariates: int = 2, covariate_dist: Literal["normal", "uniform", "binary"] = "normal", - covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None, + covariate_params: dict[str, float | tuple[float, float]] | None = None, model_cens: Literal["uniform", "exponential"] = "uniform", cens_par: float = 5.0, - max_time: Optional[float] = 10.0, - seed: Optional[int] = None, + max_time: float | None = 10.0, + seed: int | None = None, ) -> pd.DataFrame: """ Generate survival data with a cure fraction using a mixture cure model. @@ -91,115 +214,34 @@ def gen_mixture_cure( if seed is not None: np.random.seed(seed) - # Validate inputs if not 0 <= cure_fraction <= 1: - raise ValueError("cure_fraction must be between 0 and 1") - - if baseline_hazard <= 0: - raise ValueError("baseline_hazard must be positive") - - # Set default covariate parameters if not provided - if covariate_params is None: - if covariate_dist == "normal": - covariate_params = {"mean": 0.0, "std": 1.0} - elif covariate_dist == "uniform": - covariate_params = {"low": 0.0, "high": 1.0} - elif covariate_dist == "binary": - covariate_params = {"p": 0.5} - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") - - # Set default betas if not provided - if betas_survival is None: - betas_survival = np.random.normal(0, 0.5, size=n_covariates) - else: - betas_survival = np.array(betas_survival) - n_covariates = len(betas_survival) - - if betas_cure is None: - betas_cure = np.random.normal(0, 0.5, size=n_covariates) - else: - betas_cure = np.array(betas_cure) - if len(betas_cure) != n_covariates: - raise ValueError( - f"betas_cure must have the same length as betas_survival, " - f"got {len(betas_cure)} vs {n_covariates}" - ) - - # Generate covariates - if covariate_dist == "normal": - X = np.random.normal( - covariate_params.get("mean", 0.0), - covariate_params.get("std", 1.0), - size=(n, n_covariates), + raise ParameterError( + "cure_fraction", cure_fraction, "must be between 0 and 1" ) - elif covariate_dist == "uniform": - X = np.random.uniform( - covariate_params.get("low", 0.0), - covariate_params.get("high", 1.0), - size=(n, n_covariates), - ) - elif covariate_dist == "binary": - X = np.random.binomial( - 1, covariate_params.get("p", 0.5), size=(n, n_covariates) - ) - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + ensure_positive(baseline_hazard, "baseline_hazard") - # Calculate linear predictors - lp_survival = X @ betas_survival - lp_cure = X @ betas_cure - - # Determine cure status (logistic model) - cure_probs = 1 / ( - 1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure)) + ensure_in_choices( + covariate_dist, "covariate_dist", {"normal", "uniform", "binary"} + ) + covariate_params = _set_covariate_params(covariate_dist, covariate_params) + betas_survival_arr, betas_cure_arr, n_covariates = _prepare_betas( + betas_survival, betas_cure, n_covariates + ) + X = _generate_covariates(n, n_covariates, covariate_dist, covariate_params) + lp_survival = X @ betas_survival_arr + lp_cure = X @ betas_cure_arr + cured = _cure_status(lp_cure, cure_fraction) + survival_times = _survival_times(cured, lp_survival, baseline_hazard, max_time) + + ensure_censoring_model(model_cens) + observed_times, status = _apply_censoring( + survival_times, model_cens, cens_par, max_time ) - cured = np.random.binomial(1, cure_probs) - - # Generate survival times - survival_times = np.zeros(n) - - # For non-cured subjects, generate event times - non_cured_indices = np.where(cured == 0)[0] - - for i in non_cured_indices: - # Adjust hazard rate by covariate effect - adjusted_hazard = baseline_hazard * np.exp(lp_survival[i]) - - # Generate exponential survival time - survival_times[i] = np.random.exponential(scale=1 / adjusted_hazard) - - # For cured subjects, set "infinite" survival time - cured_indices = np.where(cured == 1)[0] - if max_time is not None: - survival_times[cured_indices] = max_time * 100 # Effectively infinite - else: - survival_times[cured_indices] = np.inf # Actually infinite - - # Generate censoring times - if model_cens == "uniform": - cens_times = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - cens_times = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") - - # Determine observed time and status - observed_times = np.minimum(survival_times, cens_times) - status = (survival_times <= cens_times).astype(int) - - # Cap times at max_time if specified - if max_time is not None: - over_max = observed_times > max_time - observed_times[over_max] = max_time - status[over_max] = 0 # Censored if beyond max_time - # Create DataFrame data = pd.DataFrame( {"id": np.arange(n), "time": observed_times, "status": status, "cured": cured} ) - # Add covariates for j in range(n_covariates): data[f"X{j}"] = X[:, j] @@ -262,12 +304,12 @@ def cure_fraction_estimate( survival[i] *= 1 - 1 / at_risk # Estimate cure fraction as the plateau of the survival curve - # Use the last 10% of the survival curve if enough data points - tail_size = max(int(n * 0.1), 1) + # Use the last portion of the survival curve if enough data points + tail_size = max(int(n * _TAIL_FRACTION), 1) tail_survival = survival[-tail_size:] # Apply smoothing if there are enough data points - if tail_size > 3: + if tail_size > _SMOOTH_MIN_TAIL: # Use kernel smoothing weights = np.exp( -((np.arange(tail_size) - tail_size + 1) ** 2) diff --git a/gen_surv/piecewise.py b/gen_surv/piecewise.py index 63b2412..5441237 100644 --- a/gen_surv/piecewise.py +++ b/gen_surv/piecewise.py @@ -10,6 +10,15 @@ import numpy as np import pandas as pd +from ._validation import ( + ensure_censoring_model, + ensure_in_choices, + ensure_positive_sequence, + ensure_sequence_length, + ParameterError, +) +from .censoring import rexpocens, runifcens + def gen_piecewise_exponential( n: int, @@ -80,21 +89,16 @@ def gen_piecewise_exponential( np.random.seed(seed) # Validate inputs - if len(hazard_rates) != len(breakpoints) + 1: - raise ValueError( - f"Expected {len(breakpoints) + 1} hazard rates, got {len(hazard_rates)}" - ) - - if not all(b > 0 for b in breakpoints): - raise ValueError("All breakpoints must be positive") - - if not all(h > 0 for h in hazard_rates): - raise ValueError("All hazard rates must be positive") + ensure_sequence_length(hazard_rates, len(breakpoints) + 1, "hazard_rates") + ensure_positive_sequence(breakpoints, "breakpoints") + ensure_positive_sequence(hazard_rates, "hazard_rates") + if np.any(np.diff(breakpoints) <= 0): + raise ParameterError("breakpoints", breakpoints, "must be in ascending order") - if not all( - breakpoints[i] < breakpoints[i + 1] for i in range(len(breakpoints) - 1) - ): - raise ValueError("Breakpoints must be in ascending order") + ensure_censoring_model(model_cens) + ensure_in_choices( + covariate_dist, "covariate_dist", {"normal", "uniform", "binary"} + ) # Set default covariate parameters if not provided if covariate_params is None: @@ -104,8 +108,6 @@ def gen_piecewise_exponential( covariate_params = {"low": 0.0, "high": 1.0} elif covariate_dist == "binary": covariate_params = {"p": 0.5} - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") # Set default betas if not provided if betas is None: @@ -131,8 +133,12 @@ def gen_piecewise_exponential( X = np.random.binomial( 1, covariate_params.get("p", 0.5), size=(n, n_covariates) ) - else: - raise ValueError(f"Unknown covariate distribution: {covariate_dist}") + else: # pragma: no cover - validated above + raise ParameterError( + "covariate_dist", + covariate_dist, + "must be one of {'normal', 'uniform', 'binary'}", + ) # Calculate linear predictor linear_predictor = X @ betas @@ -187,12 +193,8 @@ def gen_piecewise_exponential( survival_times[i] = total_time + remaining_time / hazard # Generate censoring times - if model_cens == "uniform": - cens_times = np.random.uniform(0, cens_par, size=n) - elif model_cens == "exponential": - cens_times = np.random.exponential(scale=cens_par, size=n) - else: - raise ValueError("model_cens must be 'uniform' or 'exponential'") + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(n, cens_par) # Determine observed time and status observed_times = np.minimum(survival_times, cens_times) diff --git a/gen_surv/sklearn_adapter.py b/gen_surv/sklearn_adapter.py index ca2718c..275c401 100644 --- a/gen_surv/sklearn_adapter.py +++ b/gen_surv/sklearn_adapter.py @@ -2,7 +2,10 @@ from typing import Any, Optional +import pandas as pd + from .interface import generate +from ._validation import ensure_in_choices try: # pragma: no cover - only imported if sklearn is installed from sklearn.base import BaseEstimator @@ -25,15 +28,18 @@ def fit( ) -> "GenSurvDataGenerator": return self - def transform(self, X: Optional[Any] = None) -> Any: + def transform( + self, X: Optional[Any] = None + ) -> pd.DataFrame | dict[str, list[Any]]: df = generate(self.model, **self.kwargs) + ensure_in_choices(self.return_type, "return_type", {"df", "dict"}) if self.return_type == "df": return df if self.return_type == "dict": return df.to_dict(orient="list") - raise ValueError("return_type must be 'df' or 'dict'") + raise AssertionError("Unreachable due to validation") def fit_transform( self, X: Optional[Any] = None, y: Optional[Any] = None, **fit_params: Any - ) -> Any: + ) -> pd.DataFrame | dict[str, list[Any]]: return self.fit(X, y).transform(X) diff --git a/gen_surv/summary.py b/gen_surv/summary.py index dfd094d..49ad8cc 100644 --- a/gen_surv/summary.py +++ b/gen_surv/summary.py @@ -5,19 +5,21 @@ check data quality, and identify potential issues. """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import pandas as pd +from ._validation import ParameterError + def summarize_survival_dataset( data: pd.DataFrame, time_col: str = "time", status_col: str = "status", - id_col: Optional[str] = None, - covariate_cols: Optional[List[str]] = None, + id_col: str | None = None, + covariate_cols: list[str] | None = None, verbose: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Generate a comprehensive summary of a survival dataset. @@ -39,7 +41,7 @@ def summarize_survival_dataset( Returns ------- - Dict[str, Any] + dict[str, Any] Dictionary containing all summary statistics. Examples @@ -57,10 +59,10 @@ def summarize_survival_dataset( # Validate input columns for col in [time_col, status_col]: if col not in data.columns: - raise ValueError(f"Column '{col}' not found in data") + raise ParameterError("column", col, "not found in data") if id_col is not None and id_col not in data.columns: - raise ValueError(f"ID column '{id_col}' not found in data") + raise ParameterError("id_col", id_col, "not found in data") # Determine covariate columns if covariate_cols is None: @@ -71,7 +73,9 @@ def summarize_survival_dataset( else: missing_cols = [col for col in covariate_cols if col not in data.columns] if missing_cols: - raise ValueError(f"Covariate columns not found in data: {missing_cols}") + raise ParameterError( + "covariate_cols", missing_cols, "not found in data" + ) # Basic dataset information n_subjects = len(data) @@ -172,12 +176,12 @@ def check_survival_data_quality( data: pd.DataFrame, time_col: str = "time", status_col: str = "status", - id_col: Optional[str] = None, + id_col: str | None = None, min_time: float = 0.0, - max_time: Optional[float] = None, - status_values: Optional[List[int]] = None, + max_time: float | None = None, + status_values: list[int] | None = None, fix_issues: bool = False, -) -> Tuple[pd.DataFrame, Dict[str, Any]]: +) -> tuple[pd.DataFrame, dict[str, Any]]: """ Check for common issues in survival data and optionally fix them. @@ -202,7 +206,7 @@ def check_survival_data_quality( Returns ------- - Tuple[pd.DataFrame, Dict[str, Any]] + tuple[pd.DataFrame, dict[str, Any]] Tuple containing (possibly fixed) DataFrame and issues report. Examples @@ -300,18 +304,18 @@ def check_survival_data_quality( def _print_summary( - summary: Dict[str, Any], + summary: dict[str, Any], time_col: str, status_col: str, - id_col: Optional[str], - covariate_cols: List[str], + id_col: str | None, + covariate_cols: list[str], ) -> None: """ Print a formatted summary of survival data. Parameters ---------- - summary : Dict[str, Any] + summary : dict[str, Any] Summary dictionary from summarize_survival_dataset. time_col : str Name of the time column. @@ -319,7 +323,7 @@ def _print_summary( Name of the status column. id_col : str, optional Name of the ID column. - covariate_cols : List[str] + covariate_cols : list[str] List of covariate column names. """ print("=" * 60) @@ -403,17 +407,17 @@ def _print_summary( def compare_survival_datasets( - datasets: Dict[str, pd.DataFrame], + datasets: dict[str, pd.DataFrame], time_col: str = "time", status_col: str = "status", - covariate_cols: Optional[List[str]] = None, + covariate_cols: list[str] | None = None, ) -> pd.DataFrame: """ Compare multiple survival datasets and summarize their differences. Parameters ---------- - datasets : Dict[str, pd.DataFrame] + datasets : dict[str, pd.DataFrame] Dictionary mapping dataset names to DataFrames. time_col : str, default="time" Name of the time column in each dataset. @@ -445,7 +449,7 @@ def compare_survival_datasets( >>> print(comparison) """ if not datasets: - raise ValueError("No datasets provided for comparison") + raise ParameterError("datasets", datasets, "at least one dataset is required") # Find common columns if covariate_cols not specified if covariate_cols is None: diff --git a/gen_surv/tdcm.py b/gen_surv/tdcm.py index 3ea538c..a0c538b 100644 --- a/gen_surv/tdcm.py +++ b/gen_surv/tdcm.py @@ -1,12 +1,22 @@ import numpy as np import pandas as pd +from numpy.typing import NDArray +from typing import Sequence from gen_surv.bivariate import sample_bivariate_distribution -from gen_surv.censoring import rexpocens, runifcens +from gen_surv.censoring import CensoringFunc, rexpocens, runifcens from gen_surv.validate import validate_gen_tdcm_inputs -def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): +def generate_censored_observations( + n: int, + dist_par: Sequence[float], + model_cens: str, + cens_par: float, + beta: Sequence[float], + lam: float, + b: NDArray[np.float64], +) -> NDArray[np.float64]: """ Generate censored TDCM observations. @@ -24,35 +34,42 @@ def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, - np.ndarray: Shape (n, 6) with columns: [id, start, stop, status, covariate1 (z1), covariate2 (z2)] """ - rfunc = runifcens if model_cens == "uniform" else rexpocens - observations = np.zeros((n, 6)) - - for k in range(n): - z1 = b[k, 1] - c = rfunc(1, cens_par)[0] - u = np.random.uniform() - - # Determine path based on u threshold - threshold = 1 - np.exp(-lam * b[k, 0] * np.exp(beta[0] * z1)) - if u < threshold: - t = -np.log(1 - u) / (lam * np.exp(beta[0] * z1)) - z2 = 0 - else: - t = ( - -np.log(1 - u) - + lam * b[k, 0] * np.exp(beta[0] * z1) * (1 - np.exp(beta[1])) - ) / (lam * np.exp(beta[0] * z1 + beta[1])) - z2 = 1 - - time = min(t, c) - status = int(t <= c) - - observations[k] = [k + 1, 0, time, status, z1, z2] - - return observations - - -def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): + rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens + + z1 = b[:, 1] + x = lam * b[:, 0] * np.exp(beta[0] * z1) + u = np.random.uniform(size=n) + c = rfunc(n, cens_par) + + threshold = 1 - np.exp(-x) + exp_b0_z1 = np.exp(beta[0] * z1) + log_term = -np.log(1 - u) + t1 = log_term / (lam * exp_b0_z1) + t2 = (log_term + x * (1 - np.exp(beta[1]))) / ( + lam * np.exp(beta[0] * z1 + beta[1]) + ) + mask = u < threshold + t = np.where(mask, t1, t2) + z2 = (~mask).astype(float) + + time = np.minimum(t, c) + status = (t <= c).astype(float) + + ids = np.arange(1, n + 1, dtype=float) + zeros = np.zeros(n, dtype=float) + return np.column_stack((ids, zeros, time, status, z1, z2)) + + +def gen_tdcm( + n: int, + dist: str, + corr: float, + dist_par: Sequence[float], + model_cens: str, + cens_par: float, + beta: Sequence[float], + lam: float, +) -> pd.DataFrame: """ Generate TDCM (Time-Dependent Covariate Model) survival data. diff --git a/gen_surv/thmm.py b/gen_surv/thmm.py index c8b7aa7..a4e086a 100644 --- a/gen_surv/thmm.py +++ b/gen_surv/thmm.py @@ -1,13 +1,25 @@ import numpy as np import pandas as pd +from typing import Sequence, TypedDict -from gen_surv.censoring import rexpocens, runifcens +from gen_surv.censoring import CensoringFunc, rexpocens, runifcens from gen_surv.validate import validate_gen_thmm_inputs +class TransitionTimes(TypedDict): + c: float + t12: float + t13: float + t23: float + + def calculate_transitions( - z1: float, cens_par: float, beta: list, rate: list, rfunc -) -> dict: + z1: float, + cens_par: float, + beta: Sequence[float], + rate: Sequence[float], + rfunc: CensoringFunc, +) -> TransitionTimes: """ Calculate transition and censoring times for THMM. @@ -33,7 +45,14 @@ def calculate_transitions( return {"c": c, "t12": t12, "t13": t13, "t23": t23} -def gen_thmm(n, model_cens, cens_par, beta, covariate_range, rate): +def gen_thmm( + n: int, + model_cens: str, + cens_par: float, + beta: Sequence[float], + covariate_range: float, + rate: Sequence[float], +) -> pd.DataFrame: """ Generate THMM (Time-Homogeneous Markov Model) survival data. @@ -49,7 +68,7 @@ def gen_thmm(n, model_cens, cens_par, beta, covariate_range, rate): - pd.DataFrame: Columns = ["id", "time", "state", "X0"] """ validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) - rfunc = runifcens if model_cens == "uniform" else rexpocens + rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens records = [] for k in range(n): diff --git a/gen_surv/validate.py b/gen_surv/validate.py index 9a99792..9f23779 100644 --- a/gen_surv/validate.py +++ b/gen_surv/validate.py @@ -1,301 +1,328 @@ -def validate_gen_cphm_inputs( - n: int, model_cens: str, cens_par: float, covariate_range: float -): - """ - Validates input parameters for CPHM data generation. +"""Validation utilities for data generators.""" - Parameters: - - n (int): Number of data points to generate. - - model_cens (str): Censoring model, must be "uniform" or "exponential". - - cens_par (float): Parameter for the censoring model, must be > 0. - - covariate_range (float): Upper bound for covariate values, must be > 0. +from __future__ import annotations - Raises: - - ValueError: If any input is invalid. - """ - if n <= 0: - raise ValueError("Argument 'n' must be greater than 0") - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'" - ) - if cens_par <= 0: - raise ValueError("Argument 'cens_par' must be greater than 0") - if covariate_range <= 0: - raise ValueError("Argument 'covariate_range' must be greater than 0") +from collections.abc import Sequence + +from ._validation import ( + ensure_censoring_model, + ensure_in_choices, + ensure_numeric_sequence, + ensure_positive, + ensure_positive_int, + ensure_positive_sequence, + ensure_sequence_length, + ListOfListsError, + ParameterError, +) + + +_BETA_LEN = 3 +_CMM_RATE_LEN = 6 +_THMM_RATE_LEN = 3 +_WEIBULL_DIST_PAR_LEN = 4 +_EXP_DIST_PAR_LEN = 2 + + +def _validate_base(n: int, model_cens: str, cens_par: float) -> None: + """Common checks for sample size and censoring model.""" + ensure_positive_int(n, "n") + ensure_censoring_model(model_cens) + ensure_positive(cens_par, "cens_par") + + +def _validate_beta(beta: Sequence[float]) -> None: + """Ensure beta is a numeric sequence of length three.""" + ensure_sequence_length(beta, _BETA_LEN, "beta") + ensure_numeric_sequence(beta, "beta") + + +def _validate_aft_common( + n: int, beta: Sequence[float], model_cens: str, cens_par: float +) -> None: + """Shared validation logic for AFT generators.""" + _validate_base(n, model_cens, cens_par) + ensure_numeric_sequence(beta, "beta") + + +def validate_gen_cphm_inputs( + n: int, model_cens: str, cens_par: float, covariate_range: float +) -> None: + """Validate input parameters for CPHM data generation.""" + _validate_base(n, model_cens, cens_par) + ensure_positive(covariate_range, "covariate_range") def validate_gen_cmm_inputs( n: int, model_cens: str, cens_par: float, - beta: list, + beta: Sequence[float], covariate_range: float, - rate: list, -): - """ - Validate inputs for generating CMM (Continuous-Time Markov Model) data. - - Parameters: - - n (int): Number of individuals. - - model_cens (str): Censoring model, must be "uniform" or "exponential". - - cens_par (float): Parameter for censoring distribution, must be > 0. - - beta (list): Regression coefficients, must have length 3. - - covariate_range (float): Upper bound for covariate values, must be > 0. - - rate (list): Transition rates, must have length 6. - - Raises: - - ValueError: If any parameter is invalid. + rate: Sequence[float], +) -> None: + """Validate inputs for generating CMM (Continuous-Time Markov Model) data. + + Parameters + ---------- + n : int + Sample size. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + beta : Sequence[float] + Regression coefficients. + covariate_range : float + Range of the uniform covariate distribution. + rate : Sequence[float] + Six transition rate parameters. """ - if n <= 0: - raise ValueError("Argument 'n' must be greater than 0") - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'" - ) - if cens_par <= 0: - raise ValueError("Argument 'cens_par' must be greater than 0") - if len(beta) != 3: - raise ValueError("Argument 'beta' must be a list of length 3") - if covariate_range <= 0: - raise ValueError("Argument 'covariate_range' must be greater than 0") - if len(rate) != 6: - raise ValueError("Argument 'rate' must be a list of length 6") + + _validate_base(n, model_cens, cens_par) + _validate_beta(beta) + ensure_positive(covariate_range, "covariate_range") + ensure_sequence_length(rate, _CMM_RATE_LEN, "rate") def validate_gen_tdcm_inputs( n: int, dist: str, corr: float, - dist_par: list, + dist_par: Sequence[float], model_cens: str, cens_par: float, - beta: list, + beta: Sequence[float], lam: float, -): - """ - Validate inputs for generating TDCM (Time-Dependent Covariate Model) data. - - Parameters: - - n (int): Number of observations. - - dist (str): "weibull" or "exponential". - - corr (float): Correlation coefficient. - - dist_par (list): Distribution parameters. - - model_cens (str): "uniform" or "exponential". - - cens_par (float): Censoring parameter. - - beta (list): Length-2 list of regression coefficients. - - lam (float): Lambda parameter, must be > 0. - - Raises: - - ValueError: For any invalid input. +) -> None: + """Validate inputs for generating TDCM (Time-Dependent Covariate Model) data. + + Parameters + ---------- + n : int + Sample size. + dist : {"weibull", "exponential"} + Distribution used to generate correlated covariates. + corr : float + Correlation coefficient for the bivariate distribution. + dist_par : Sequence[float] + Parameters of the chosen distribution. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + beta : Sequence[float] + Regression coefficients. + lam : float + Baseline hazard rate. """ - if n <= 0: - raise ValueError("Argument 'n' must be greater than 0") - if dist not in {"weibull", "exponential"}: - raise ValueError("Argument 'dist' must be one of 'weibull' or 'exponential'") + _validate_base(n, model_cens, cens_par) + ensure_in_choices(dist, "dist", {"weibull", "exponential"}) if dist == "weibull": if not (0 < corr <= 1): - raise ValueError("With dist='weibull', 'corr' must be in (0,1]") - if len(dist_par) != 4 or any(p <= 0 for p in dist_par): - raise ValueError( - "With dist='weibull', 'dist_par' must be a positive list of length 4" - ) + raise ParameterError("corr", corr, "with dist='weibull' must be in (0,1]") + ensure_sequence_length(dist_par, _WEIBULL_DIST_PAR_LEN, "dist_par") + ensure_positive_sequence(dist_par, "dist_par") if dist == "exponential": if not (-1 <= corr <= 1): - raise ValueError("With dist='exponential', 'corr' must be in [-1,1]") - if len(dist_par) != 2 or any(p <= 0 for p in dist_par): - raise ValueError( - "With dist='exponential', 'dist_par' must be a positive list of length 2" - ) + raise ParameterError("corr", corr, "with dist='exponential' must be in [-1,1]") + ensure_sequence_length(dist_par, _EXP_DIST_PAR_LEN, "dist_par") + ensure_positive_sequence(dist_par, "dist_par") - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'" - ) - - if cens_par <= 0: - raise ValueError("Argument 'cens_par' must be greater than 0") - - if not isinstance(beta, list) or len(beta) != 3: - raise ValueError("Argument 'beta' must be a list of length 3") - - if lam <= 0: - raise ValueError("Argument 'lambda' must be greater than 0") + _validate_beta(beta) + ensure_positive(lam, "lambda") def validate_gen_thmm_inputs( n: int, model_cens: str, cens_par: float, - beta: list, + beta: Sequence[float], covariate_range: float, - rate: list, -): - """ - Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. - - Parameters: - - n (int): Number of samples, must be > 0. - - model_cens (str): Must be "uniform" or "exponential". - - cens_par (float): Must be > 0. - - beta (list): List of length 3 (regression coefficients). - - covariate_range (float): Positive upper bound for covariate values. - - rate (list): List of length 3 (transition rates). - - Raises: - - ValueError if any input is invalid. + rate: Sequence[float], +) -> None: + """Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. + + Parameters + ---------- + n : int + Sample size. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + beta : Sequence[float] + Regression coefficients. + covariate_range : float + Range of the uniform covariate distribution. + rate : Sequence[float] + Three transition rate parameters. """ - if not isinstance(n, int) or n <= 0: - raise ValueError("Argument 'n' must be a positive integer.") - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'" - ) - - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("Argument 'cens_par' must be a positive number.") - - if not isinstance(beta, list) or len(beta) != 3: - raise ValueError("Argument 'beta' must be a list of length 3.") - - if not isinstance(covariate_range, (int, float)) or covariate_range <= 0: - raise ValueError("Argument 'covariate_range' must be greater than 0.") - - if not isinstance(rate, list) or len(rate) != 3: - raise ValueError("Argument 'rate' must be a list of length 3.") - - -def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): + _validate_base(n, model_cens, cens_par) + _validate_beta(beta) + ensure_positive(covariate_range, "covariate_range") + ensure_sequence_length(rate, _THMM_RATE_LEN, "rate") + + +def validate_dg_biv_inputs( + n: int, dist: str, corr: float, dist_par: Sequence[float] +) -> None: + """Validate inputs for the :func:`sample_bivariate_distribution` helper. + + Parameters + ---------- + n : int + Number of samples. + dist : {"weibull", "exponential"} + Bivariate marginal distribution. + corr : float + Correlation coefficient. + dist_par : Sequence[float] + Parameters for the selected distribution. """ - Validate inputs for the sample_bivariate_distribution function. - Parameters: - - n (int): Number of samples to generate. - - dist (str): Must be "weibull" or "exponential". - - corr (float): Must be between -1 and 1. - - dist_par (list): Must contain positive values, and correct length for the distribution. - - Raises: - - ValueError if any input is invalid. - """ - if not isinstance(n, int) or n <= 0: - raise ValueError("Argument 'n' must be a positive integer.") - - if dist not in {"weibull", "exponential"}: - raise ValueError("Argument 'dist' must be one of 'weibull' or 'exponential'.") + ensure_positive_int(n, "n") + ensure_in_choices(dist, "dist", {"weibull", "exponential"}) if not isinstance(corr, (int, float)) or not (-1 < corr < 1): - raise ValueError("Argument 'corr' must be a numeric value between -1 and 1.") - - if not isinstance(dist_par, list) or len(dist_par) == 0: - raise ValueError( - "Argument 'dist_par' must be a non-empty list of positive values." - ) - - if any(p <= 0 for p in dist_par): - raise ValueError("All elements in 'dist_par' must be greater than 0.") - - if dist == "exponential" and len(dist_par) != 2: - raise ValueError( - "Exponential distribution requires exactly 2 positive parameters." - ) - - if dist == "weibull" and len(dist_par) != 4: - raise ValueError("Weibull distribution requires exactly 4 positive parameters.") - - -def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par): - if not isinstance(n, int) or n <= 0: - raise ValueError("n must be a positive integer") - - if not isinstance(beta, (list, tuple)) or not all( - isinstance(b, (int, float)) for b in beta - ): - raise ValueError("beta must be a list of numbers") - - if not isinstance(sigma, (int, float)) or sigma <= 0: - raise ValueError("sigma must be a positive number") - - if model_cens not in ("uniform", "exponential"): - raise ValueError("model_cens must be 'uniform' or 'exponential'") - - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("cens_par must be a positive number") + raise ParameterError("corr", corr, "must be a numeric value between -1 and 1") + ensure_positive_sequence(dist_par, "dist_par") + if dist == "exponential": + ensure_sequence_length(dist_par, _EXP_DIST_PAR_LEN, "dist_par") + if dist == "weibull": + ensure_sequence_length(dist_par, _WEIBULL_DIST_PAR_LEN, "dist_par") -def validate_gen_aft_weibull_inputs(n, beta, shape, scale, model_cens, cens_par): - if not isinstance(n, int) or n <= 0: - raise ValueError("n must be a positive integer") - if not isinstance(beta, (list, tuple)) or not all( - isinstance(b, (int, float)) for b in beta - ): - raise ValueError("beta must be a list of numbers") +def validate_gen_aft_log_normal_inputs( + n: int, + beta: Sequence[float], + sigma: float, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for the log-normal AFT generator. + + Parameters + ---------- + n : int + Sample size. + beta : Sequence[float] + Regression coefficients. + sigma : float + Scale parameter of the log-normal distribution. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + """ - if not isinstance(shape, (int, float)) or shape <= 0: - raise ValueError("shape must be a positive number") + _validate_aft_common(n, beta, model_cens, cens_par) + ensure_positive(sigma, "sigma") - if not isinstance(scale, (int, float)) or scale <= 0: - raise ValueError("scale must be a positive number") - if model_cens not in ("uniform", "exponential"): - raise ValueError("model_cens must be 'uniform' or 'exponential'") +def validate_gen_aft_weibull_inputs( + n: int, + beta: Sequence[float], + shape: float, + scale: float, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for the Weibull AFT generator. + + Parameters + ---------- + n : int + Sample size. + beta : Sequence[float] + Regression coefficients. + shape : float + Shape parameter of the Weibull distribution. + scale : float + Scale parameter of the Weibull distribution. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + """ - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("cens_par must be a positive number") + _validate_aft_common(n, beta, model_cens, cens_par) + ensure_positive(shape, "shape") + ensure_positive(scale, "scale") -def validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens_par): - if not isinstance(n, int) or n <= 0: - raise ValueError("n must be a positive integer") +def validate_gen_aft_log_logistic_inputs( + n: int, + beta: Sequence[float], + shape: float, + scale: float, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for the log-logistic AFT generator. + + Parameters + ---------- + n : int + Sample size. + beta : Sequence[float] + Regression coefficients. + shape : float + Shape parameter of the log-logistic distribution. + scale : float + Scale parameter of the log-logistic distribution. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + """ - if not isinstance(beta, (list, tuple)) or not all( - isinstance(b, (int, float)) for b in beta - ): - raise ValueError("beta must be a list of numbers") + _validate_aft_common(n, beta, model_cens, cens_par) + ensure_positive(shape, "shape") + ensure_positive(scale, "scale") - if not isinstance(shape, (int, float)) or shape <= 0: - raise ValueError("shape must be a positive number") - if not isinstance(scale, (int, float)) or scale <= 0: - raise ValueError("scale must be a positive number") +def validate_competing_risks_inputs( + n: int, + n_risks: int, + baseline_hazards: Sequence[float] | None, + betas: Sequence[Sequence[float]] | None, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for competing risks data generation. + + Parameters + ---------- + n : int + Sample size. + n_risks : int + Number of competing risks. + baseline_hazards : Sequence[float] or None + Baseline hazard for each risk. + betas : Sequence[Sequence[float]] or None + Regression coefficients for each risk. + model_cens : str + Censoring model identifier. + cens_par : float + Censoring distribution parameter. + """ - if model_cens not in ("uniform", "exponential"): - raise ValueError("model_cens must be 'uniform' or 'exponential'") + _validate_base(n, model_cens, cens_par) + ensure_positive_int(n_risks, "n_risks") - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("cens_par must be a positive number") + if baseline_hazards is not None: + ensure_sequence_length(baseline_hazards, n_risks, "baseline_hazards") + ensure_positive_sequence(baseline_hazards, "baseline_hazards") + if betas is not None: + if not isinstance(betas, list) or any(not isinstance(b, list) for b in betas): + raise ListOfListsError("betas", betas) + for b in betas: + ensure_numeric_sequence(b, "betas") -def validate_competing_risks_inputs( - n, n_risks, baseline_hazards, betas, model_cens, cens_par -): - if not isinstance(n, int) or n <= 0: - raise ValueError("n must be a positive integer") - - if not isinstance(n_risks, int) or n_risks <= 0: - raise ValueError("n_risks must be a positive integer") - - if baseline_hazards is not None and ( - not isinstance(baseline_hazards, (list, tuple)) - or len(baseline_hazards) != n_risks - or any(h <= 0 for h in baseline_hazards) - ): - raise ValueError( - "baseline_hazards must be a list of positive numbers with length n_risks" - ) - - if betas is not None and ( - not isinstance(betas, list) or any(not isinstance(b, list) for b in betas) - ): - raise ValueError("betas must be a list of lists") - - if model_cens not in ("uniform", "exponential"): - raise ValueError("model_cens must be 'uniform' or 'exponential'") - - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("cens_par must be a positive number") diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py index c3bcb89..3a80c11 100644 --- a/gen_surv/visualization.py +++ b/gen_surv/visualization.py @@ -7,8 +7,6 @@ survival analysis. """ -from typing import Dict, Optional, Tuple - import matplotlib.pyplot as plt import pandas as pd from matplotlib.axes import Axes @@ -19,12 +17,12 @@ def plot_survival_curve( data: pd.DataFrame, time_col: str = "time", status_col: str = "status", - group_col: Optional[str] = None, + group_col: str | None = None, confidence_intervals: bool = True, title: str = "Kaplan-Meier Survival Curve", - figsize: Tuple[float, float] = (10, 6), + figsize: tuple[float, float] = (10, 6), ci_alpha: float = 0.2, -) -> Tuple[Figure, Axes]: +) -> tuple[Figure, Axes]: """ Plot Kaplan-Meier survival curves from simulated data. @@ -128,13 +126,13 @@ def plot_survival_curve( def plot_hazard_comparison( - models: Dict[str, pd.DataFrame], + models: dict[str, pd.DataFrame], time_col: str = "time", status_col: str = "status", title: str = "Hazard Function Comparison", - figsize: Tuple[float, float] = (10, 6), + figsize: tuple[float, float] = (10, 6), bandwidth: float = 0.5, -) -> Tuple[Figure, Axes]: +) -> tuple[Figure, Axes]: """ Compare hazard functions from multiple generated datasets. @@ -215,9 +213,9 @@ def plot_covariate_effect( status_col: str = "status", n_groups: int = 3, title: str = "Effect of Covariate on Survival", - figsize: Tuple[float, float] = (10, 6), + figsize: tuple[float, float] = (10, 6), ci_alpha: float = 0.2, -) -> Tuple[Figure, Axes]: +) -> tuple[Figure, Axes]: """ Visualize the effect of a continuous covariate on survival by discretizing it. diff --git a/tests/test_aft.py b/tests/test_aft.py index 0cf4a18..d3b747f 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -12,6 +12,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull +from gen_surv._validation import ChoiceError, PositiveValueError def test_gen_aft_log_logistic_runs(): @@ -37,7 +38,7 @@ def test_gen_aft_log_logistic_runs(): def test_gen_aft_log_logistic_invalid_shape(): """Test that the Log-Logistic AFT generator raises error for invalid shape.""" - with pytest.raises(ValueError, match="shape parameter must be positive"): + with pytest.raises(PositiveValueError): gen_aft_log_logistic( n=10, beta=[0.5, -0.2], @@ -51,7 +52,7 @@ def test_gen_aft_log_logistic_invalid_shape(): def test_gen_aft_log_logistic_invalid_scale(): """Test that the Log-Logistic AFT generator raises error for invalid scale.""" - with pytest.raises(ValueError, match="scale parameter must be positive"): + with pytest.raises(PositiveValueError): gen_aft_log_logistic( n=10, beta=[0.5, -0.2], @@ -167,7 +168,7 @@ def test_gen_aft_weibull_runs(): def test_gen_aft_weibull_invalid_shape(): """Test that the Weibull AFT generator raises error for invalid shape.""" - with pytest.raises(ValueError, match="shape parameter must be positive"): + with pytest.raises(PositiveValueError): gen_aft_weibull( n=10, beta=[0.5, -0.2], @@ -180,7 +181,7 @@ def test_gen_aft_weibull_invalid_shape(): def test_gen_aft_weibull_invalid_scale(): """Test that the Weibull AFT generator raises error for invalid scale.""" - with pytest.raises(ValueError, match="scale parameter must be positive"): + with pytest.raises(PositiveValueError): gen_aft_weibull( n=10, beta=[0.5, -0.2], @@ -193,9 +194,7 @@ def test_gen_aft_weibull_invalid_scale(): def test_gen_aft_weibull_invalid_cens_model(): """Test that the Weibull AFT generator raises error for invalid censoring model.""" - with pytest.raises( - ValueError, match="model_cens must be 'uniform' or 'exponential'" - ): + with pytest.raises(ChoiceError): gen_aft_weibull( n=10, beta=[0.5, -0.2], diff --git a/tests/test_bivariate.py b/tests/test_bivariate.py index 403cdbb..d2362ea 100644 --- a/tests/test_bivariate.py +++ b/tests/test_bivariate.py @@ -7,6 +7,7 @@ import pytest from gen_surv.bivariate import sample_bivariate_distribution +from gen_surv._validation import ChoiceError, LengthError def test_sample_bivariate_exponential_shape(): @@ -17,18 +18,18 @@ def test_sample_bivariate_exponential_shape(): def test_sample_bivariate_invalid_dist(): - """Unsupported distributions should raise ValueError.""" - with pytest.raises(ValueError): + """Unsupported distributions should raise ChoiceError.""" + with pytest.raises(ChoiceError): sample_bivariate_distribution(10, "invalid", 0.0, [1, 1]) def test_sample_bivariate_exponential_param_length_error(): - """Exponential distribution with wrong param length should raise ValueError.""" - with pytest.raises(ValueError): + """Exponential distribution with wrong param length should raise LengthError.""" + with pytest.raises(LengthError): sample_bivariate_distribution(5, "exponential", 0.0, [1.0]) def test_sample_bivariate_weibull_param_length_error(): - """Weibull distribution with wrong param length should raise ValueError.""" - with pytest.raises(ValueError): + """Weibull distribution with wrong param length should raise LengthError.""" + with pytest.raises(LengthError): sample_bivariate_distribution(5, "weibull", 0.0, [1.0, 1.0]) diff --git a/tests/test_censoring.py b/tests/test_censoring.py index 8db38c7..d94b85d 100644 --- a/tests/test_censoring.py +++ b/tests/test_censoring.py @@ -1,6 +1,15 @@ import numpy as np -from gen_surv.censoring import rexpocens, runifcens +from gen_surv.censoring import ( + WeibullCensoring, + LogNormalCensoring, + GammaCensoring, + rexpocens, + runifcens, + rweibcens, + rlognormcens, + rgammacens, +) def test_runifcens_range(): @@ -16,3 +25,48 @@ def test_rexpocens_nonnegative(): assert isinstance(times, np.ndarray) assert len(times) == 5 assert np.all(times >= 0) + + +def test_rweibcens_nonnegative(): + times = rweibcens(5, 1.0, 1.5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + + +def test_rlognormcens_positive(): + times = rlognormcens(5, 0.0, 1.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) + + +def test_rgammacens_positive(): + times = rgammacens(5, 2.0, 1.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) + + +def test_weibull_censoring_class(): + model = WeibullCensoring(scale=1.0, shape=1.5) + times = model(5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + + +def test_lognormal_censoring_class(): + model = LogNormalCensoring(mean=0.0, sigma=1.0) + times = model(5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) + + +def test_gamma_censoring_class(): + model = GammaCensoring(shape=2.0, scale=1.0) + times = model(5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) diff --git a/tests/test_competing_risks.py b/tests/test_competing_risks.py index 10c0574..193b442 100644 --- a/tests/test_competing_risks.py +++ b/tests/test_competing_risks.py @@ -14,6 +14,7 @@ gen_competing_risks, gen_competing_risks_weibull, ) +from gen_surv._validation import ChoiceError, LengthError, ParameterError def test_gen_competing_risks_basic(): @@ -60,7 +61,7 @@ def test_gen_competing_risks_weibull_basic(): def test_competing_risks_parameters(): """Test parameter validation in competing risks model.""" # Test with invalid number of baseline hazards - with pytest.raises(ValueError, match="Expected 3 baseline hazards"): + with pytest.raises(LengthError): gen_competing_risks( n=10, n_risks=3, @@ -69,7 +70,7 @@ def test_competing_risks_parameters(): ) # Test with invalid number of beta coefficient sets - with pytest.raises(ValueError, match="Expected 2 sets of coefficients"): + with pytest.raises(LengthError): gen_competing_risks( n=10, n_risks=2, @@ -78,16 +79,14 @@ def test_competing_risks_parameters(): ) # Test with invalid censoring model - with pytest.raises( - ValueError, match="model_cens must be 'uniform' or 'exponential'" - ): + with pytest.raises(ChoiceError): gen_competing_risks(n=10, n_risks=2, model_cens="invalid", seed=42) def test_competing_risks_weibull_parameters(): """Test parameter validation in Weibull competing risks model.""" # Test with invalid number of shape parameters - with pytest.raises(ValueError, match="Expected 3 shape parameters"): + with pytest.raises(LengthError): gen_competing_risks_weibull( n=10, n_risks=3, @@ -96,7 +95,7 @@ def test_competing_risks_weibull_parameters(): ) # Test with invalid number of scale parameters - with pytest.raises(ValueError, match="Expected 3 scale parameters"): + with pytest.raises(LengthError): gen_competing_risks_weibull( n=10, n_risks=3, @@ -123,7 +122,7 @@ def test_cause_specific_cumulative_incidence(): assert cif["incidence"].is_monotonic_increasing # Test with invalid cause - with pytest.raises(ValueError, match="Cause 3 not found in the data"): + with pytest.raises(ParameterError): cause_specific_cumulative_incidence(df, time_points, cause=3) diff --git a/tests/test_summary_more.py b/tests/test_summary_more.py index e039eae..0100b93 100644 --- a/tests/test_summary_more.py +++ b/tests/test_summary_more.py @@ -6,18 +6,19 @@ check_survival_data_quality, summarize_survival_dataset, ) +from gen_surv._validation import ParameterError def test_summarize_survival_dataset_errors(): df = pd.DataFrame({"time": [1, 2], "status": [1, 0]}) # Missing time column - with pytest.raises(ValueError): + with pytest.raises(ParameterError): summarize_survival_dataset(df.drop(columns=["time"])) # Missing ID column when specified - with pytest.raises(ValueError): + with pytest.raises(ParameterError): summarize_survival_dataset(df, id_col="id") # Missing covariate columns - with pytest.raises(ValueError): + with pytest.raises(ParameterError): summarize_survival_dataset(df, covariate_cols=["bad"]) diff --git a/tests/test_validate.py b/tests/test_validate.py index 8fa6e9b..c7110f7 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1,6 +1,13 @@ import pytest import gen_surv.validate as v +from gen_surv._validation import ( + ChoiceError, + ParameterError, + PositiveIntegerError, + ensure_censoring_model, + ensure_positive_int, +) def test_validate_gen_cphm_inputs_valid(): @@ -212,3 +219,27 @@ def test_validate_gen_thmm_inputs_invalid( def test_validate_gen_thmm_inputs_valid(): v.validate_gen_thmm_inputs(1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]) + + +def test_positive_integer_error(): + with pytest.raises(PositiveIntegerError): + ensure_positive_int(-1, "n") + + +def test_censoring_model_choice_error(): + with pytest.raises(ChoiceError): + ensure_censoring_model("bad") + + +def test_parameter_error_from_validator(): + with pytest.raises(ParameterError): + v.validate_gen_tdcm_inputs( + 1, + "weibull", + 0.0, + [1, 2, 3, 4], + "uniform", + 1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + )