55using the gen_surv package.
66"""
77
8- from typing import List , Optional , TypeVar , cast
8+ from typing import Any , Dict , List , Optional , TypeVar , cast
99
1010import typer
1111
@@ -91,8 +91,8 @@ def _val(v: T | OptionInfo) -> T:
9191 return v if not isinstance (v , OptionInfo ) else cast (T , v .default )
9292
9393 # Prepare arguments based on the selected model
94- model_str = _val (model )
95- kwargs = {
94+ model_str : str = _val (model )
95+ kwargs : Dict [ str , Any ] = {
9696 "model" : model_str ,
9797 "n" : _val (n ),
9898 "model_cens" : _val (model_cens ),
@@ -103,7 +103,8 @@ def _val(v: T | OptionInfo) -> T:
103103 # Add model-specific parameters
104104 if model_str in ["cphm" , "cmm" , "thmm" ]:
105105 # These models use a single beta and covariate range
106- kwargs ["beta" ] = _val (beta )[0 ] if len (_val (beta )) > 0 else 0.5
106+ beta_values = cast (List [float ], _val (beta ))
107+ kwargs ["beta" ] = beta_values [0 ] if len (beta_values ) > 0 else 0.5
107108 kwargs ["covariate_range" ] = _val (covariate_range )
108109
109110 elif model_str == "aft_ln" :
@@ -153,10 +154,10 @@ def _val(v: T | OptionInfo) -> T:
153154
154155 # Generate the data
155156 try :
156- df = generate (** kwargs )
157+ df = generate (** kwargs ) # type: ignore[arg-type]
157158 except TypeError :
158159 # Fallback for tests where generate accepts only model and n
159- df = generate (model = model_str , n = _val (n ))
160+ df = generate (model = model_str , n = _val (n )) # type: ignore[arg-type]
160161
161162 # Output the data
162163 if output :
@@ -228,6 +229,7 @@ def visualize(
228229
229230 # Save the plot
230231 plt .savefig (output , dpi = 300 , bbox_inches = "tight" )
232+ plt .close (fig )
231233 typer .echo (f"Plot saved to { output } " )
232234
233235
0 commit comments