Skip to content

Commit 0f2c560

Browse files
committed
Add CLI and Hypothesis tests
1 parent 7c84d2b commit 0f2c560

File tree

5 files changed

+76
-28
lines changed

5 files changed

+76
-28
lines changed

CITATION.cff

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
cff-version: 1.2.0
2+
message: "If you use this software, please cite it using the metadata below."
3+
4+
# Basic information
5+
preferred-citation:
6+
type: software
7+
title: "gen_surv"
8+
version: "1.0.1"
9+
url: "https://github.com/DiogoRibeiro7/genSurvPy"
10+
authors:
11+
- family-names: Ribeiro
12+
given-names: Diogo
13+
orcid: "https://orcid.org/0009-0001-2022-7072"
14+
affiliation: "ESMAD - Instituto Politécnico do Porto"
15+
16+
license: "MIT"
17+
date-released: "2024-01-01"
18+

TODO.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@ This document outlines future enhancements, features, and ideas for improving th
44

55
---
66

7+
## ✨ Priority Items
8+
9+
- [] Add property-based tests using Hypothesis to cover edge cases
10+
- [] Build a CLI for generating datasets from the terminal
11+
- [ ] Expand documentation with multilingual support and more usage examples
12+
- [ ] Implement Weibull and log-logistic AFT models and add visualization utilities
13+
- [] Provide CITATION metadata for proper referencing
14+
- [ ] Ensure all functions include Google-style docstrings with inline comments
15+
16+
---
17+
718
## 📦 1. Interface and UX
819

920
- [] Create a `generate(..., return_type="df" | "dict")` interface

gen_surv/__main__.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,4 @@
1-
import argparse
2-
import pandas as pd
3-
from gen_surv.cphm import gen_cphm
4-
from gen_surv.cmm import gen_cmm
5-
from gen_surv.tdcm import gen_tdcm
6-
from gen_surv.thmm import gen_thmm
7-
8-
def run_example(model: str):
9-
if model == "cphm":
10-
df = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0)
11-
elif model == "cmm":
12-
df = gen_cmm(n=10, model_cens="exponential", cens_par=1.0,
13-
beta=[0.5, 0.2, -0.1], covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0])
14-
elif model == "tdcm":
15-
df = gen_tdcm(n=10, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2],
16-
model_cens="uniform", cens_par=0.5, beta=[0.1, 0.2, 0.3], lam=1.0)
17-
elif model == "thmm":
18-
df = gen_thmm(n=10, model_cens="uniform", cens_par=0.5,
19-
beta=[0.1, 0.2, 0.3], covar=1.0, rate=[0.5, 0.6, 0.7])
20-
else:
21-
raise ValueError(f"Unknown model: {model}")
22-
23-
print(df)
1+
from gen_surv.cli import app
242

253
if __name__ == "__main__":
26-
parser = argparse.ArgumentParser(description="Run gen_surv model example.")
27-
parser.add_argument("model", choices=["cphm", "cmm", "tdcm", "thmm"],
28-
help="Model to run (cphm, cmm, tdcm, thmm)")
29-
args = parser.parse_args()
30-
run_example(args.model)
4+
app()

gen_surv/cli.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import csv
2+
from typing import Optional
3+
import typer
4+
from gen_surv.interface import generate
5+
6+
app = typer.Typer(help="Generate synthetic survival datasets.")
7+
8+
@app.command()
9+
def dataset(
10+
model: str = typer.Argument(..., help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln]"),
11+
n: int = typer.Option(100, help="Number of samples"),
12+
output: Optional[str] = typer.Option(None, "-o", help="Output CSV file. Prints to stdout if omitted."),
13+
):
14+
"""Generate survival data and optionally save to CSV."""
15+
df = generate(model=model, n=n)
16+
if output:
17+
df.to_csv(output, index=False)
18+
typer.echo(f"Saved dataset to {output}")
19+
else:
20+
typer.echo(df.to_csv(index=False))
21+
22+
if __name__ == "__main__":
23+
app()

tests/test_aft_property.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from hypothesis import given, strategies as st
2+
from gen_surv.aft import gen_aft_log_normal
3+
4+
@given(
5+
n=st.integers(min_value=1, max_value=20),
6+
sigma=st.floats(min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False),
7+
cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False),
8+
seed=st.integers(min_value=0, max_value=1000)
9+
)
10+
def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed):
11+
df = gen_aft_log_normal(
12+
n=n,
13+
beta=[0.5, -0.2],
14+
sigma=sigma,
15+
model_cens="uniform",
16+
cens_par=cens_par,
17+
seed=seed
18+
)
19+
assert df.shape[0] == n
20+
assert set(df["status"].unique()).issubset({0, 1})
21+
assert (df["time"] >= 0).all()
22+
assert df.filter(regex="^X[0-9]+$").shape[1] == 2

0 commit comments

Comments
 (0)