Skip to content

Commit 79f6a8b

Browse files
committed
chore: fixes unit test
1 parent db49738 commit 79f6a8b

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

gen_surv/visualization.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""
22
Visualization utilities for survival data.
33
4-
This module provides functions to visualize survival data generated by gen_surv,
5-
including Kaplan-Meier survival curves and other commonly used plots in survival analysis.
4+
This module provides functions to visualize survival data generated by
5+
gen_surv,
6+
including Kaplan-Meier survival curves and other commonly used plots in
7+
survival analysis.
68
"""
79

8-
from typing import Dict, List, Optional, Tuple, Union, Any
10+
from typing import Dict, Optional, Tuple, Union, Any
911
import numpy as np
1012
import pandas as pd
1113
import matplotlib.pyplot as plt
@@ -217,7 +219,8 @@ def plot_covariate_effect(
217219
ci_alpha: float = 0.2,
218220
) -> Tuple[Figure, Axes]:
219221
"""
220-
Visualize the effect of a continuous covariate on survival by discretizing it.
222+
Visualize the effect of a continuous covariate on survival by discretizing
223+
it.
221224
222225
Parameters
223226
----------

tests/test_aft.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""
22
Tests for Accelerated Failure Time (AFT) models.
33
"""
4-
4+
import os
5+
import sys
56
import pandas as pd
67
import pytest
78
from hypothesis import given, strategies as st
89

10+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
911
from gen_surv.aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic
1012

1113

tests/test_bivariate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import os
2+
import sys
13
import numpy as np
4+
5+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
26
from gen_surv.bivariate import sample_bivariate_distribution
37
import pytest
48

tests/test_cli.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
import sys
2+
import os
3+
import runpy
4+
15
import pandas as pd
6+
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
28
from gen_surv.cli import dataset
3-
import runpy
49

510

611
def test_cli_dataset_stdout(monkeypatch, capsys):
7-
"""Dataset command prints CSV to stdout when no output file is given."""
12+
"""
13+
Test that the 'dataset' CLI command prints the generated CSV data to stdout when no output file is specified.
14+
This test patches the 'generate' function to return a simple DataFrame, invokes the CLI command directly,
15+
and asserts that the expected CSV header appears in the captured standard output.
16+
"""
817

918
def fake_generate(model: str, n: int):
1019
return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]})

0 commit comments

Comments
 (0)