Skip to content

Commit 93b7f1e

Browse files
committed
chore: solve issues with formatting
1 parent 79f6a8b commit 93b7f1e

File tree

1 file changed

+54
-42
lines changed

1 file changed

+54
-42
lines changed

tests/test_aft.py

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Tests for Accelerated Failure Time (AFT) models.
33
"""
4+
45
import os
56
import sys
67
import pandas as pd
@@ -20,7 +21,7 @@ def test_gen_aft_log_logistic_runs():
2021
scale=2.0,
2122
model_cens="uniform",
2223
cens_par=5.0,
23-
seed=42
24+
seed=42,
2425
)
2526
assert isinstance(df, pd.DataFrame)
2627
assert not df.empty
@@ -32,38 +33,45 @@ def test_gen_aft_log_logistic_runs():
3233

3334

3435
def test_gen_aft_log_logistic_invalid_shape():
35-
"""Test that the Log-Logistic AFT generator raises error for invalid shape."""
36+
"""Test that the Log-Logistic AFT generator raises error
37+
for invalid shape."""
3638
with pytest.raises(ValueError, match="shape parameter must be positive"):
3739
gen_aft_log_logistic(
3840
n=10,
3941
beta=[0.5, -0.2],
4042
shape=-1.0, # Invalid negative shape
4143
scale=2.0,
4244
model_cens="uniform",
43-
cens_par=5.0
45+
cens_par=5.0,
4446
)
4547

4648

4749
def test_gen_aft_log_logistic_invalid_scale():
48-
"""Test that the Log-Logistic AFT generator raises error for invalid scale."""
50+
"""Test that the Log-Logistic AFT generator raises error
51+
for invalid scale."""
4952
with pytest.raises(ValueError, match="scale parameter must be positive"):
5053
gen_aft_log_logistic(
5154
n=10,
5255
beta=[0.5, -0.2],
5356
shape=1.5,
5457
scale=0.0, # Invalid zero scale
5558
model_cens="uniform",
56-
cens_par=5.0
59+
cens_par=5.0,
5760
)
5861

5962

60-
6163
@given(
6264
n=st.integers(min_value=1, max_value=20),
63-
shape=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False),
64-
scale=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False),
65-
cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False),
66-
seed=st.integers(min_value=0, max_value=1000)
65+
shape=st.floats(
66+
min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False
67+
),
68+
scale=st.floats(
69+
min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False
70+
),
71+
cens_par=st.floats(
72+
min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False
73+
),
74+
seed=st.integers(min_value=0, max_value=1000),
6775
)
6876
def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed):
6977
"""Property-based test for the Log-Logistic AFT generator."""
@@ -74,7 +82,7 @@ def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed):
7482
scale=scale,
7583
model_cens="uniform",
7684
cens_par=cens_par,
77-
seed=seed
85+
seed=seed,
7886
)
7987
assert df.shape[0] == n
8088
assert set(df["status"].unique()).issubset({0, 1})
@@ -83,52 +91,48 @@ def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed):
8391

8492

8593
def test_gen_aft_log_logistic_reproducibility():
86-
"""Test that the Log-Logistic AFT generator is reproducible with the same seed."""
94+
"""Test that the Log-Logistic AFT generator is reproducible
95+
with the same seed."""
8796
df1 = gen_aft_log_logistic(
8897
n=10,
8998
beta=[0.5, -0.2],
9099
shape=1.5,
91100
scale=2.0,
92101
model_cens="uniform",
93102
cens_par=5.0,
94-
seed=42
103+
seed=42,
95104
)
96-
105+
97106
df2 = gen_aft_log_logistic(
98107
n=10,
99108
beta=[0.5, -0.2],
100109
shape=1.5,
101110
scale=2.0,
102111
model_cens="uniform",
103112
cens_par=5.0,
104-
seed=42
113+
seed=42,
105114
)
106-
115+
107116
pd.testing.assert_frame_equal(df1, df2)
108-
117+
109118
df3 = gen_aft_log_logistic(
110119
n=10,
111120
beta=[0.5, -0.2],
112121
shape=1.5,
113122
scale=2.0,
114123
model_cens="uniform",
115124
cens_par=5.0,
116-
seed=43 # Different seed
125+
seed=43, # Different seed
117126
)
118-
127+
119128
with pytest.raises(AssertionError):
120129
pd.testing.assert_frame_equal(df1, df3)
121130

122131

123132
def test_gen_aft_log_normal_runs():
124133
"""Test that the log-normal AFT generator runs without errors."""
125134
df = gen_aft_log_normal(
126-
n=10,
127-
beta=[0.5, -0.2],
128-
sigma=1.0,
129-
model_cens="uniform",
130-
cens_par=5.0,
131-
seed=42
135+
n=10, beta=[0.5, -0.2], sigma=1.0, model_cens="uniform", cens_par=5.0, seed=42
132136
)
133137
assert isinstance(df, pd.DataFrame)
134138
assert not df.empty
@@ -148,7 +152,7 @@ def test_gen_aft_weibull_runs():
148152
scale=2.0,
149153
model_cens="uniform",
150154
cens_par=5.0,
151-
seed=42
155+
seed=42,
152156
)
153157
assert isinstance(df, pd.DataFrame)
154158
assert not df.empty
@@ -168,7 +172,7 @@ def test_gen_aft_weibull_invalid_shape():
168172
shape=-1.0, # Invalid negative shape
169173
scale=2.0,
170174
model_cens="uniform",
171-
cens_par=5.0
175+
cens_par=5.0,
172176
)
173177

174178

@@ -181,29 +185,37 @@ def test_gen_aft_weibull_invalid_scale():
181185
shape=1.5,
182186
scale=0.0, # Invalid zero scale
183187
model_cens="uniform",
184-
cens_par=5.0
188+
cens_par=5.0,
185189
)
186190

187191

188192
def test_gen_aft_weibull_invalid_cens_model():
189193
"""Test that the Weibull AFT generator raises error for invalid censoring model."""
190-
with pytest.raises(ValueError, match="model_cens must be 'uniform' or 'exponential'"):
194+
with pytest.raises(
195+
ValueError, match="model_cens must be 'uniform' or 'exponential'"
196+
):
191197
gen_aft_weibull(
192198
n=10,
193199
beta=[0.5, -0.2],
194200
shape=1.5,
195201
scale=2.0,
196202
model_cens="invalid", # Invalid censoring model
197-
cens_par=5.0
203+
cens_par=5.0,
198204
)
199205

200206

201207
@given(
202208
n=st.integers(min_value=1, max_value=20),
203-
shape=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False),
204-
scale=st.floats(min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False),
205-
cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False),
206-
seed=st.integers(min_value=0, max_value=1000)
209+
shape=st.floats(
210+
min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False
211+
),
212+
scale=st.floats(
213+
min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False
214+
),
215+
cens_par=st.floats(
216+
min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False
217+
),
218+
seed=st.integers(min_value=0, max_value=1000),
207219
)
208220
def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed):
209221
"""Property-based test for the Weibull AFT generator."""
@@ -214,7 +226,7 @@ def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed):
214226
scale=scale,
215227
model_cens="uniform",
216228
cens_par=cens_par,
217-
seed=seed
229+
seed=seed,
218230
)
219231
assert df.shape[0] == n
220232
assert set(df["status"].unique()).issubset({0, 1})
@@ -231,30 +243,30 @@ def test_gen_aft_weibull_reproducibility():
231243
scale=2.0,
232244
model_cens="uniform",
233245
cens_par=5.0,
234-
seed=42
246+
seed=42,
235247
)
236-
248+
237249
df2 = gen_aft_weibull(
238250
n=10,
239251
beta=[0.5, -0.2],
240252
shape=1.5,
241253
scale=2.0,
242254
model_cens="uniform",
243255
cens_par=5.0,
244-
seed=42
256+
seed=42,
245257
)
246-
258+
247259
pd.testing.assert_frame_equal(df1, df2)
248-
260+
249261
df3 = gen_aft_weibull(
250262
n=10,
251263
beta=[0.5, -0.2],
252264
shape=1.5,
253265
scale=2.0,
254266
model_cens="uniform",
255267
cens_par=5.0,
256-
seed=43 # Different seed
268+
seed=43, # Different seed
257269
)
258-
270+
259271
with pytest.raises(AssertionError):
260272
pd.testing.assert_frame_equal(df1, df3)

0 commit comments

Comments
 (0)