Skip to content

Commit 5a184ca

Browse files
feat: add dataset export utility (#41)
1 parent 4ba1b79 commit 5a184ca

File tree

6 files changed

+76
-4
lines changed

6 files changed

+76
-4
lines changed

CITATION.cff

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ preferred-citation:
1010
authors:
1111
- family-names: Ribeiro
1212
given-names: Diogo
13+
alias: DiogoRibeiro7
1314
orcid: "https://orcid.org/0009-0001-2022-7072"
1415
affiliation: "ESMAD - Instituto Politécnico do Porto"
1516

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
```bash
2222
poetry install
2323
```
24+
This package requires **Python 3.10–3.12** (`>=3.10,<3.13`).
2425
## ✨ Features
2526

2627
- Consistent interface across models
@@ -31,6 +32,7 @@ poetry install
3132
- Mixture cure and piecewise exponential models
3233
- Competing risks generators (constant and Weibull hazards)
3334
- Command-line interface powered by `Typer`
35+
- Export utilities for CSV, JSON, and Feather formats
3436

3537
## 🧪 Example
3638

@@ -98,6 +100,7 @@ python -m gen_surv dataset aft_ln --n 100 > data.csv
98100
| `sample_bivariate_distribution()` | Sample correlated Weibull or exponential times |
99101
| `runifcens()` | Generate uniform censoring times |
100102
| `rexpocens()` | Generate exponential censoring times |
103+
| `export_dataset()` | Save a dataset to CSV, JSON or Feather |
101104

102105

103106
```text

gen_surv/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .competing_risks import gen_competing_risks, gen_competing_risks_weibull
1818
from .mixture import gen_mixture_cure, cure_fraction_estimate
1919
from .piecewise import gen_piecewise_exponential
20+
from .export import export_dataset
2021

2122
# Helper functions
2223
from .bivariate import sample_bivariate_distribution
@@ -61,6 +62,7 @@
6162
"sample_bivariate_distribution",
6263
"runifcens",
6364
"rexpocens",
65+
"export_dataset",
6466
]
6567

6668
# Add visualization tools to __all__ if available

gen_surv/export.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Data export utilities for gen_surv.
2+
3+
This module provides helper functions to save generated
4+
survival datasets in various formats.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import os
10+
from typing import Optional
11+
12+
import pandas as pd
13+
14+
15+
def export_dataset(
    df: pd.DataFrame,
    path: str | os.PathLike[str],
    fmt: Optional[str] = None,
) -> None:
    """Save a DataFrame to disk.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing survival data.
    path : str or os.PathLike
        File path to write to. The extension is used to infer the format
        when ``fmt`` is ``None``.
    fmt : {"csv", "json", "feather", "ft"}, optional
        Format to use. Matching is case-insensitive and ignores surrounding
        whitespace and a leading dot, so ``"CSV"`` and ``".csv"`` both
        select CSV. ``"ft"`` is accepted as an alias for ``"feather"``.
        If omitted, the format is inferred from the extension of ``path``.

    Raises
    ------
    ValueError
        If the format is not one of the supported types. This includes the
        case where ``fmt`` is omitted and ``path`` has no extension.
    """
    if fmt is None:
        fmt = os.path.splitext(path)[1]

    # Normalise explicit and inferred formats the same way, so an explicit
    # "CSV" or ".csv" behaves exactly like a "data.CSV" file extension.
    fmt = fmt.strip().lstrip(".").lower()

    if fmt == "csv":
        # The row index is not written to the CSV file.
        df.to_csv(path, index=False)
    elif fmt == "json":
        # orient="table" embeds a schema alongside the data.
        df.to_json(path, orient="table")
    elif fmt in {"feather", "ft"}:
        # The index is dropped before writing; pandas documents that
        # to_feather requires a default index.
        df.reset_index(drop=True).to_feather(path)
    else:
        raise ValueError(f"Unsupported export format: {fmt}")
44+

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ classifiers = [
1717
"Topic :: Scientific/Engineering :: Medical Science Apps.",
1818
"Topic :: Scientific/Engineering :: Mathematics",
1919
"Programming Language :: Python :: 3",
20-
"Programming Language :: Python :: 3.9",
2120
"Programming Language :: Python :: 3.10",
2221
"Programming Language :: Python :: 3.11",
22+
"Programming Language :: Python :: 3.12",
2323
"License :: OSI Approved :: MIT License",
2424
]
2525

2626
[tool.poetry.dependencies]
27-
python = "^3.9"
27+
python = ">=3.10,<3.13"
2828
numpy = "^1.26"
2929
pandas = "^2.2.3"
3030
typer = "^0.12.3"
@@ -62,7 +62,7 @@ build_command = ""
6262

6363
[tool.black]
6464
line-length = 88
65-
target-version = ['py39']
65+
target-version = ['py310']
6666
include = '\.pyi?$'
6767

6868
[tool.isort]
@@ -74,7 +74,7 @@ max-line-length = 88
7474
extend-ignore = ["E203", "W503", "E501", "W291", "W293", "W391", "F401", "F841", "E402", "E302", "E305"]
7575

7676
[tool.mypy]
77-
python_version = "3.9"
77+
python_version = "3.10"
7878
warn_return_any = true
7979
warn_unused_configs = true
8080
disallow_untyped_defs = true

tests/test_export.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import os
2+
import pandas as pd
3+
from gen_surv import generate, export_dataset
4+
5+
6+
def test_export_dataset_csv(tmp_path):
    """Export a dataset from ``generate`` to CSV and check the round trip."""
    frame = generate(
        model="cphm",
        n=5,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
    )
    target = tmp_path / "data.csv"

    # Format is inferred from the ".csv" extension.
    export_dataset(frame, str(target))
    assert target.exists()

    expected = frame.reset_index(drop=True)
    reloaded = pd.read_csv(target)
    pd.testing.assert_frame_equal(expected, reloaded)
13+
14+
15+
def test_export_dataset_json(tmp_path):
    """Export a dataset from ``generate`` to JSON and check the round trip."""
    frame = generate(
        model="cphm",
        n=5,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
    )
    target = tmp_path / "data.json"

    # Format is inferred from the ".json" extension.
    export_dataset(frame, str(target))
    assert target.exists()

    expected = frame.reset_index(drop=True)
    # Read back with the same orient the exporter writes with.
    reloaded = pd.read_json(target, orient="table")
    pd.testing.assert_frame_equal(expected, reloaded)
22+

0 commit comments

Comments
 (0)