1+ import pandas as pd
2+ import pytest
3+ import typer
4+
15from gen_surv import generate
2- from gen_surv .visualization import plot_survival_curve
6+ from gen_surv .cli import visualize
7+ from gen_surv .visualization import (
8+ describe_survival ,
9+ plot_covariate_effect ,
10+ plot_hazard_comparison ,
11+ plot_survival_curve ,
12+ )
313
414
515def test_plot_survival_curve_runs ():
@@ -14,3 +24,97 @@ def test_plot_survival_curve_runs():
1424 fig , ax = plot_survival_curve (df )
1525 assert fig is not None
1626 assert ax is not None
27+
28+
29+ def test_plot_hazard_comparison_runs ():
30+ df1 = generate (
31+ model = "cphm" ,
32+ n = 5 ,
33+ model_cens = "uniform" ,
34+ cens_par = 1.0 ,
35+ beta = 0.5 ,
36+ covariate_range = 1.0 ,
37+ )
38+ df2 = generate (
39+ model = "aft_weibull" ,
40+ n = 5 ,
41+ beta = [0.5 ],
42+ shape = 1.5 ,
43+ scale = 2.0 ,
44+ model_cens = "uniform" ,
45+ cens_par = 1.0 ,
46+ )
47+ models = {"cphm" : df1 , "aft_weibull" : df2 }
48+ fig , ax = plot_hazard_comparison (models )
49+ assert fig is not None
50+ assert ax is not None
51+
52+
53+ def test_plot_covariate_effect_runs ():
54+ df = generate (
55+ model = "cphm" ,
56+ n = 10 ,
57+ model_cens = "uniform" ,
58+ cens_par = 1.0 ,
59+ beta = 0.5 ,
60+ covariate_range = 2.0 ,
61+ )
62+ fig , ax = plot_covariate_effect (df , covariate_col = "X0" , n_groups = 2 )
63+ assert fig is not None
64+ assert ax is not None
65+
66+
67+ def test_describe_survival_summary ():
68+ df = generate (
69+ model = "cphm" ,
70+ n = 10 ,
71+ model_cens = "uniform" ,
72+ cens_par = 1.0 ,
73+ beta = 0.5 ,
74+ covariate_range = 2.0 ,
75+ )
76+ summary = describe_survival (df )
77+ expected_metrics = [
78+ "Total Observations" ,
79+ "Number of Events" ,
80+ "Number Censored" ,
81+ "Event Rate" ,
82+ "Median Survival Time" ,
83+ "Min Time" ,
84+ "Max Time" ,
85+ "Mean Time" ,
86+ ]
87+ assert list (summary ["Metric" ]) == expected_metrics
88+ assert summary .shape [0 ] == len (expected_metrics )
89+
90+
91+ def test_cli_visualize (tmp_path , capsys ):
92+ df = pd .DataFrame ({"time" : [1 , 2 , 3 ], "status" : [1 , 0 , 1 ]})
93+ csv_path = tmp_path / "d.csv"
94+ df .to_csv (csv_path , index = False )
95+ out_file = tmp_path / "out.png"
96+ visualize (
97+ str (csv_path ),
98+ time_col = "time" ,
99+ status_col = "status" ,
100+ group_col = None ,
101+ output = str (out_file ),
102+ )
103+ assert out_file .exists ()
104+ captured = capsys .readouterr ()
105+ assert "Plot saved to" in captured .out
106+
107+
108+ def test_cli_visualize_missing_column (tmp_path , capsys ):
109+ df = pd .DataFrame ({"time" : [1 , 2 ], "event" : [1 , 0 ]})
110+ csv_path = tmp_path / "bad.csv"
111+ df .to_csv (csv_path , index = False )
112+ with pytest .raises (typer .Exit ):
113+ visualize (
114+ str (csv_path ),
115+ time_col = "time" ,
116+ status_col = "status" ,
117+ group_col = None ,
118+ )
119+ captured = capsys .readouterr ()
120+ assert "Status column 'status' not found in data" in captured .out
0 commit comments