Skip to content

Commit b850e12

Browse files
committed
chore: tighten tooling and fix visualization
1 parent f92a584 commit b850e12

File tree

10 files changed

+59
-73
lines changed

10 files changed

+59
-73
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
22
max-line-length = 120
3-
extend-ignore = E501,W291,W293,W391,F401,F841,E402,E302,E305
3+
extend-ignore = E501,W291,W293,W391,E402,E302,E305

.github/workflows/ci.yml

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
runs-on: ubuntu-latest
1313
strategy:
1414
matrix:
15-
python-version: ["3.9", "3.10", "3.11"]
15+
python-version: ["3.10", "3.11", "3.12"]
1616

1717
steps:
1818
- name: Checkout code
@@ -23,6 +23,12 @@ jobs:
2323
with:
2424
python-version: ${{ matrix.python-version }}
2525

26+
- name: Cache Poetry dependencies
27+
uses: actions/cache@v4
28+
with:
29+
path: ~/.cache/pypoetry
30+
key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
31+
2632
- name: Install Poetry
2733
run: |
2834
curl -sSL https://install.python-poetry.org | python3 -
@@ -53,13 +59,19 @@ jobs:
5359
with:
5460
python-version: "3.11"
5561

56-
- name: Install Poetry
57-
run: |
58-
curl -sSL https://install.python-poetry.org | python3 -
59-
echo "$HOME/.local/bin" >> $GITHUB_PATH
62+
- name: Cache Poetry dependencies
63+
uses: actions/cache@v4
64+
with:
65+
path: ~/.cache/pypoetry
66+
key: ${{ runner.os }}-poetry-3.11-${{ hashFiles('**/poetry.lock') }}
6067

61-
- name: Install dependencies
62-
run: poetry install --with dev
68+
- name: Install Poetry
69+
run: |
70+
curl -sSL https://install.python-poetry.org | python3 -
71+
echo "$HOME/.local/bin" >> $GITHUB_PATH
72+
73+
- name: Install dependencies
74+
run: poetry install --with dev --no-root
6375

64-
- name: Run pre-commit checks
65-
run: poetry run pre-commit run --all-files
76+
- name: Run pre-commit checks
77+
run: poetry run pre-commit run --all-files

.github/workflows/docs.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,15 @@ jobs:
1414
- uses: actions/setup-python@v5
1515
with:
1616
python-version: '3.11'
17+
- name: Cache Poetry dependencies
18+
uses: actions/cache@v4
19+
with:
20+
path: ~/.cache/pypoetry
21+
key: ${{ runner.os }}-poetry-3.11-${{ hashFiles('**/poetry.lock') }}
1722
- name: Install Poetry
18-
run: pip install poetry
23+
run: |
24+
curl -sSL https://install.python-poetry.org | python3 -
25+
echo "$HOME/.local/bin" >> $GITHUB_PATH
1926
- name: Install dependencies
2027
run: poetry install --with docs
2128
- name: Build documentation

.github/workflows/test.yml

Lines changed: 0 additions & 41 deletions
This file was deleted.

docs/source/conf.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
import os
2-
import sys
3-
from pathlib import Path
4-
5-
# Add the package to the Python path using an absolute path
6-
project_root = Path(__file__).resolve().parent.parent.parent
7-
sys.path.insert(0, str(project_root / "gen_surv"))
2+
from datetime import datetime
3+
from importlib import metadata
84

95
# Project information
106
project = "gen_surv"
11-
copyright = "2025, Diogo Ribeiro"
7+
copyright = f"{datetime.now().year}, Diogo Ribeiro"
128
author = "Diogo Ribeiro"
13-
release = "1.0.9"
14-
version = "1.0.9"
9+
release = metadata.version("gen_surv")
10+
version = release
1511

1612
# General configuration
1713
extensions = [
@@ -21,10 +17,10 @@
2117
"sphinx.ext.intersphinx",
2218
"sphinx.ext.autosummary",
2319
"sphinx.ext.githubpages",
24-
"sphinx.ext.plot_directive",
2520
"myst_parser",
2621
"sphinx_copybutton",
2722
"sphinx_design",
23+
"sphinx_autodoc_typehints",
2824
]
2925

3026
# MyST Parser configuration

docs/source/examples/cmm.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
Visualize transition times from the CMM generator:
66

7-
```{plot}
7+
```python
88
import numpy as np
99
import matplotlib.pyplot as plt
1010
from gen_surv import generate

docs/source/examples/tdcm.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
A basic visualization of event times produced by the TDCM generator:
66

7-
```{plot}
7+
```python
88
import numpy as np
99
import matplotlib.pyplot as plt
1010
from gen_surv import generate

docs/source/examples/thmm.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
An example of event times generated by the THMM:
66

7-
```{plot}
7+
```python
88
import numpy as np
99
import matplotlib.pyplot as plt
1010
from gen_surv import generate

gen_surv/cli.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using the gen_surv package.
66
"""
77

8-
from typing import List, Optional, TypeVar, cast
8+
from typing import Any, Dict, List, Optional, TypeVar, cast
99

1010
import typer
1111

@@ -91,8 +91,8 @@ def _val(v: T | OptionInfo) -> T:
9191
return v if not isinstance(v, OptionInfo) else cast(T, v.default)
9292

9393
# Prepare arguments based on the selected model
94-
model_str = _val(model)
95-
kwargs = {
94+
model_str: str = _val(model)
95+
kwargs: Dict[str, Any] = {
9696
"model": model_str,
9797
"n": _val(n),
9898
"model_cens": _val(model_cens),
@@ -103,7 +103,8 @@ def _val(v: T | OptionInfo) -> T:
103103
# Add model-specific parameters
104104
if model_str in ["cphm", "cmm", "thmm"]:
105105
# These models use a single beta and covariate range
106-
kwargs["beta"] = _val(beta)[0] if len(_val(beta)) > 0 else 0.5
106+
beta_values = cast(List[float], _val(beta))
107+
kwargs["beta"] = beta_values[0] if len(beta_values) > 0 else 0.5
107108
kwargs["covariate_range"] = _val(covariate_range)
108109

109110
elif model_str == "aft_ln":
@@ -153,10 +154,10 @@ def _val(v: T | OptionInfo) -> T:
153154

154155
# Generate the data
155156
try:
156-
df = generate(**kwargs)
157+
df = generate(**kwargs) # type: ignore[arg-type]
157158
except TypeError:
158159
# Fallback for tests where generate accepts only model and n
159-
df = generate(model=model_str, n=_val(n))
160+
df = generate(model=model_str, n=_val(n)) # type: ignore[arg-type]
160161

161162
# Output the data
162163
if output:
@@ -228,6 +229,7 @@ def visualize(
228229

229230
# Save the plot
230231
plt.savefig(output, dpi=300, bbox_inches="tight")
232+
plt.close(fig)
231233
typer.echo(f"Plot saved to {output}")
232234

233235

pyproject.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ python = ">=3.10,<3.13"
2828
numpy = "^1.26"
2929
pandas = "^2.2.3"
3030
typer = "^0.12.3"
31-
matplotlib = "^3.10"
31+
matplotlib = "~3.8"
3232
lifelines = "^0.30"
3333
pyarrow = "^14"
3434
pyreadr = "^0.5"
@@ -45,6 +45,7 @@ black = "^24.1.0"
4545
isort = "^5.13.2"
4646
flake8 = "^6.1.0"
4747
scikit-survival = "^0.24.1"
48+
pre-commit = "^3.8"
4849

4950
[tool.poetry.group.docs.dependencies]
5051
sphinx = ">=6.0"
@@ -78,14 +79,23 @@ line_length = 88
7879

7980
[tool.flake8]
8081
max-line-length = 88
81-
extend-ignore = ["E203", "W503", "E501", "W291", "W293", "W391", "F401", "F841", "E402", "E302", "E305"]
82+
extend-ignore = ["E203", "W503", "E501", "W291", "W293", "W391", "E402", "E302", "E305"]
8283

8384
[tool.mypy]
8485
python_version = "3.10"
8586
warn_return_any = true
8687
warn_unused_configs = true
8788
disallow_untyped_defs = true
8889
disallow_incomplete_defs = true
90+
ignore_missing_imports = true
91+
92+
[[tool.mypy.overrides]]
93+
module = [
94+
"gen_surv.interface",
95+
"gen_surv.competing_risks",
96+
"gen_surv.mixture",
97+
"gen_surv.sklearn_adapter",
98+
]
8999
ignore_errors = true
90100

91101
[build-system]

0 commit comments

Comments
 (0)