@@ -42,6 +42,31 @@ def test_validate_gen_cmm_inputs_invalid_beta_length():
4242 )
4343
4444
45+ @pytest .mark .parametrize (
46+ "n, model_cens, cens_par, cov_range, rate" ,
47+ [
48+ (0 , "uniform" , 0.5 , 1.0 , [0.1 ] * 6 ),
49+ (1 , "bad" , 0.5 , 1.0 , [0.1 ] * 6 ),
50+ (1 , "uniform" , 0.0 , 1.0 , [0.1 ] * 6 ),
51+ (1 , "uniform" , 0.5 , 0.0 , [0.1 ] * 6 ),
52+ (1 , "uniform" , 0.5 , 1.0 , [0.1 ] * 3 ),
53+ ],
54+ )
55+ def test_validate_gen_cmm_inputs_other_invalid (
56+ n , model_cens , cens_par , cov_range , rate
57+ ):
58+ with pytest .raises (ValueError ):
59+ v .validate_gen_cmm_inputs (
60+ n , model_cens , cens_par , [0.1 , 0.2 , 0.3 ], cov_range , rate
61+ )
62+
63+
64+ def test_validate_gen_cmm_inputs_valid ():
65+ v .validate_gen_cmm_inputs (
66+ 1 , "uniform" , 1.0 , [0.1 , 0.2 , 0.3 ], covariate_range = 1.0 , rate = [0.1 ] * 6
67+ )
68+
69+
4570def test_validate_gen_tdcm_inputs_invalid_lambda ():
4671 """Lambda <= 0 should raise a ValueError."""
4772 with pytest .raises (ValueError ):
@@ -57,6 +82,44 @@ def test_validate_gen_tdcm_inputs_invalid_lambda():
5782 )
5883
5984
85+ @pytest .mark .parametrize (
86+ "dist,corr,dist_par" ,
87+ [
88+ ("bad" , 0.5 , [1 , 2 ]),
89+ ("weibull" , 0.0 , [1 , 2 , 3 , 4 ]),
90+ ("weibull" , 0.5 , [1 , 2 , - 1 , 2 ]),
91+ ("weibull" , 0.5 , [1 , 2 , 3 ]),
92+ ("exponential" , 2.0 , [1 , 1 ]),
93+ ("exponential" , 0.5 , [1 ]),
94+ ],
95+ )
96+ def test_validate_gen_tdcm_inputs_invalid_dist (dist , corr , dist_par ):
97+ with pytest .raises (ValueError ):
98+ v .validate_gen_tdcm_inputs (
99+ 1 ,
100+ dist ,
101+ corr ,
102+ dist_par ,
103+ "uniform" ,
104+ 1.0 ,
105+ beta = [0.1 , 0.2 , 0.3 ],
106+ lam = 1.0 ,
107+ )
108+
109+
110+ def test_validate_gen_tdcm_inputs_valid ():
111+ v .validate_gen_tdcm_inputs (
112+ 1 ,
113+ "weibull" ,
114+ 0.5 ,
115+ [1 , 1 , 1 , 1 ],
116+ "uniform" ,
117+ 1.0 ,
118+ beta = [0.1 , 0.2 , 0.3 ],
119+ lam = 1.0 ,
120+ )
121+
122+
60123def test_validate_gen_aft_log_normal_inputs_valid ():
61124 """Valid parameters should not raise an error for AFT log-normal."""
62125 v .validate_gen_aft_log_normal_inputs (
@@ -68,19 +131,84 @@ def test_validate_gen_aft_log_normal_inputs_valid():
68131 )
69132
70133
134+ @pytest .mark .parametrize (
135+ "n,beta,sigma,model_cens,cens_par" ,
136+ [
137+ (0 , [0.1 ], 1.0 , "uniform" , 1.0 ),
138+ (1 , "bad" , 1.0 , "uniform" , 1.0 ),
139+ (1 , [0.1 ], 0.0 , "uniform" , 1.0 ),
140+ (1 , [0.1 ], 1.0 , "bad" , 1.0 ),
141+ (1 , [0.1 ], 1.0 , "uniform" , 0.0 ),
142+ ],
143+ )
144+ def test_validate_gen_aft_log_normal_inputs_invalid (
145+ n , beta , sigma , model_cens , cens_par
146+ ):
147+ with pytest .raises (ValueError ):
148+ v .validate_gen_aft_log_normal_inputs (n , beta , sigma , model_cens , cens_par )
149+
150+
71151def test_validate_dg_biv_inputs_valid_weibull ():
72152 """Valid parameters for a Weibull distribution should pass."""
73153 v .validate_dg_biv_inputs (5 , "weibull" , 0.1 , [1.0 , 1.0 , 1.0 , 1.0 ])
74154
75155
156+ def test_validate_dg_biv_inputs_invalid_corr_and_params ():
157+ with pytest .raises (ValueError ):
158+ v .validate_dg_biv_inputs (1 , "exponential" , - 2.0 , [1.0 , 1.0 ])
159+ with pytest .raises (ValueError ):
160+ v .validate_dg_biv_inputs (1 , "exponential" , 0.5 , [1.0 ])
161+ with pytest .raises (ValueError ):
162+ v .validate_dg_biv_inputs (1 , "weibull" , 0.5 , [1.0 , 1.0 ])
163+
164+
76165def test_validate_gen_aft_weibull_inputs_and_log_logistic ():
77166 with pytest .raises (ValueError ):
78167 v .validate_gen_aft_weibull_inputs (0 , [0.1 ], 1.0 , 1.0 , "uniform" , 1.0 )
79168 with pytest .raises (ValueError ):
80169 v .validate_gen_aft_log_logistic_inputs (1 , [0.1 ], - 1.0 , 1.0 , "uniform" , 1.0 )
81170
82171
172+ @pytest .mark .parametrize (
173+ "shape,scale" ,
174+ [(- 1.0 , 1.0 ), (1.0 , - 1.0 )],
175+ )
176+ def test_validate_gen_aft_weibull_invalid_params (shape , scale ):
177+ with pytest .raises (ValueError ):
178+ v .validate_gen_aft_weibull_inputs (1 , [0.1 ], shape , scale , "uniform" , 1.0 )
179+
180+
181+ def test_validate_gen_aft_weibull_valid ():
182+ v .validate_gen_aft_weibull_inputs (1 , [0.1 ], 1.0 , 1.0 , "uniform" , 1.0 )
183+
184+
185+ def test_validate_gen_aft_log_logistic_valid ():
186+ v .validate_gen_aft_log_logistic_inputs (1 , [0.1 ], 1.0 , 1.0 , "uniform" , 1.0 )
187+
188+
83189def test_validate_competing_risks_inputs ():
84190 with pytest .raises (ValueError ):
85191 v .validate_competing_risks_inputs (1 , 2 , [0.1 ], None , "uniform" , 1.0 )
86192 v .validate_competing_risks_inputs (1 , 1 , [0.5 ], [[0.1 ]], "uniform" , 0.5 )
193+
194+
195+ @pytest .mark .parametrize (
196+ "n,model_cens,cens_par,beta,cov_range,rate" ,
197+ [
198+ (0 , "uniform" , 1.0 , [0.1 , 0.2 , 0.3 ], 1.0 , [0.1 , 0.2 , 0.3 ]),
199+ (1 , "bad" , 1.0 , [0.1 , 0.2 , 0.3 ], 1.0 , [0.1 , 0.2 , 0.3 ]),
200+ (1 , "uniform" , 0.0 , [0.1 , 0.2 , 0.3 ], 1.0 , [0.1 , 0.2 , 0.3 ]),
201+ (1 , "uniform" , 1.0 , [0.1 , 0.2 ], 1.0 , [0.1 , 0.2 , 0.3 ]),
202+ (1 , "uniform" , 1.0 , [0.1 , 0.2 , 0.3 ], 0.0 , [0.1 , 0.2 , 0.3 ]),
203+ (1 , "uniform" , 1.0 , [0.1 , 0.2 , 0.3 ], 1.0 , [0.1 ]),
204+ ],
205+ )
206+ def test_validate_gen_thmm_inputs_invalid (
207+ n , model_cens , cens_par , beta , cov_range , rate
208+ ):
209+ with pytest .raises (ValueError ):
210+ v .validate_gen_thmm_inputs (n , model_cens , cens_par , beta , cov_range , rate )
211+
212+
213+ def test_validate_gen_thmm_inputs_valid ():
214+ v .validate_gen_thmm_inputs (1 , "uniform" , 1.0 , [0.1 , 0.2 , 0.3 ], 1.0 , [0.1 , 0.2 , 0.3 ])
0 commit comments