Skip to content

Commit 9cc2aa1

Browse files
adding test for it too
1 parent e6e9350 commit 9cc2aa1

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

lifelines/tests/test_suite.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,24 @@ class AalenRegressionTests(unittest.TestCase):
551551
def setUp(self):
552552
self.aaf = AalenAdditiveFitter(penalizer=0.1, fit_intercept=False)
553553

554+
def test_input_column_order_is_equal_to_output_hazards_order(self):
555+
rossi = load_rossi()
556+
aaf = AalenAdditiveFitter()
557+
expected = ['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']
558+
aaf.fit(rossi, event_col='week', duration_col='arrest')
559+
assert list(aaf.cumulative_hazards_.columns.drop('baseline')) == expected
560+
561+
def test_swapping_order_of_columns_in_a_df_is_okay(self):
562+
rossi = load_rossi()
563+
aaf = AalenAdditiveFitter()
564+
aaf.fit(rossi, event_col='week', duration_col='arrest')
565+
566+
misorder = ['age', 'race', 'wexp', 'mar', 'paro', 'prio', 'fin']
567+
natural_order = rossi.columns.drop(['week','arrest'])
568+
deleted_order = rossi.columns - ['week','arrest']
569+
assert_frame_equal(aaf.predict_median(rossi[natural_order]), aaf.predict_median(rossi[misorder]))
570+
assert_frame_equal(aaf.predict_median(rossi[natural_order]), aaf.predict_median(rossi[deleted_order]))
571+
554572
def test_large_dimensions_for_recursion_error(self):
555573
n = 500
556574
d = 50
@@ -899,6 +917,13 @@ def test_flat_style_no_censor(self):
899917

900918
class CoxRegressionTests(unittest.TestCase):
901919

920+
def test_input_column_order_is_equal_to_output_hazards_order(self):
921+
rossi = load_rossi()
922+
cp = CoxPHFitter()
923+
expected = ['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']
924+
cp.fit(rossi, event_col='week', duration_col='arrest')
925+
assert list(cp.hazards_.columns) == expected
926+
902927
def test_log_likelihood_is_available_in_output(self):
903928
cox = CoxPHFitter()
904929
cox.fit(data_nus, duration_col='t', event_col='E', include_likelihood=True)

0 commit comments

Comments
 (0)