|
21 | 21 | NelsonAalenFitter, BreslowFlemingHarringtonFitter, ExponentialFitter, \ |
22 | 22 | WeibullFitter, BaseFitter |
23 | 23 | from lifelines.datasets import load_larynx, load_waltons, load_kidney_transplant, load_rossi,\ |
24 | | - load_lcd, load_panel_test, load_g3, load_holly_molly_polly |
| 24 | + load_lcd, load_panel_test, load_g3, load_holly_molly_polly, load_regression_dataset |
25 | 25 | from lifelines.generate_datasets import generate_hazard_rates, generate_random_lifetimes, cumulative_integral |
26 | 26 | from lifelines.utils import concordance_index |
27 | 27 |
|
@@ -93,6 +93,10 @@ def rossi(): |
93 | 93 | return load_rossi() |
94 | 94 |
|
95 | 95 |
|
| 96 | +@pytest.fixture |
| 97 | +def regression_dataset(): |
| 98 | + return load_regression_dataset() |
| 99 | + |
96 | 100 |
|
97 | 101 | class TestBaseFitter(): |
98 | 102 |
|
@@ -683,16 +687,14 @@ def test_fit_method(self, data_nus): |
683 | 687 | assert np.abs(cf.hazards_.ix[0][0] - -0.0335) < 0.0001 |
684 | 688 |
|
685 | 689 | def test_using_dataframes_vs_numpy_arrays(self, data_pred2): |
686 | | - # First without normalization |
687 | 690 | cf = CoxPHFitter() |
688 | 691 | cf.fit(data_pred2, 't', 'E') |
689 | 692 |
|
690 | 693 | X = data_pred2[cf.data.columns] |
691 | | - hazards = cf.predict_partial_hazard(X) |
692 | | - |
693 | | - # A Numpy array should return the same result |
694 | | - hazards_n = cf.predict_partial_hazard(np.array(X)) |
695 | | - assert np.all(hazards == hazards_n) |
| 694 | + assert_frame_equal( |
| 695 | + cf.predict_partial_hazard(np.array(X)), |
| 696 | + cf.predict_partial_hazard(X) |
| 697 | + ) |
696 | 698 |
|
697 | 699 | def test_data_normalization(self, data_pred2): |
698 | 700 | # During fit, CoxPH copies the training data and normalizes it. |
@@ -925,12 +927,79 @@ def test_hazard_works_as_intended_with_strata_against_R_output(self, rossi): |
925 | 927 | npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 0)].ix[[14, 35, 37, 43, 52]].values, [0.076600555, 0.169748261, 0.272088807, 0.396562717, 0.396562717], decimal=2) |
926 | 928 | npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 1)].ix[[27, 43, 48, 52]].values, [0.095499001, 0.204196905, 0.338393113, 0.338393113], decimal=2) |
927 | 929 |
|
| 930 | + def test_baseline_survival_is_the_same_indp_of_location(self, regression_dataset): |
| 931 | + df = regression_dataset.copy() |
| 932 | + cp1 = CoxPHFitter() |
| 933 | + cp1.fit(df, event_col='E', duration_col='T') |
| 934 | + |
| 935 | + df_demeaned = regression_dataset.copy() |
| 936 | + df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - df_demeaned[['var1', 'var2', 'var3']].mean() |
| 937 | + cp2 = CoxPHFitter() |
| 938 | + cp2.fit(df_demeaned, event_col='E', duration_col='T') |
| 939 | + assert_frame_equal(cp2.baseline_survival_, cp1.baseline_survival_) |
| 940 | + |
| 941 | + def test_baseline_cumulative_hazard_is_the_same_indp_of_location(self, regression_dataset): |
| 942 | + df = regression_dataset.copy() |
| 943 | + cp1 = CoxPHFitter() |
| 944 | + cp1.fit(df, event_col='E', duration_col='T') |
| 945 | + |
| 946 | + df_demeaned = regression_dataset.copy() |
| 947 | + df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - df_demeaned[['var1', 'var2', 'var3']].mean() |
| 948 | + cp2 = CoxPHFitter() |
| 949 | + cp2.fit(df_demeaned, event_col='E', duration_col='T') |
| 950 | + assert_frame_equal(cp2.baseline_cumulative_hazard_, cp1.baseline_cumulative_hazard_) |
| 951 | + |
| 952 | + def test_survival_prediction_is_the_same_indp_of_location(self, regression_dataset): |
| 953 | + df = regression_dataset.copy() |
| 954 | + |
| 955 | + df_demeaned = regression_dataset.copy() |
| 956 | + mean = df_demeaned[['var1', 'var2', 'var3']].mean() |
| 957 | + df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - mean |
| 958 | + |
| 959 | + cp1 = CoxPHFitter() |
| 960 | + cp1.fit(df, event_col='E', duration_col='T') |
| 961 | + |
| 962 | + cp2 = CoxPHFitter() |
| 963 | + cp2.fit(df_demeaned, event_col='E', duration_col='T') |
| 964 | + |
| 965 | + assert_frame_equal( |
| 966 | + cp1.predict_survival_function(df.ix[[0]][['var1', 'var2', 'var3']]), |
| 967 | + cp2.predict_survival_function(df_demeaned.ix[[0]][['var1', 'var2', 'var3']]) |
| 968 | + ) |
| 969 | + |
| 970 | + def test_baseline_survival_is_the_same_indp_of_scale(self, regression_dataset): |
| 971 | + df = regression_dataset.copy() |
| 972 | + cp1 = CoxPHFitter() |
| 973 | + cp1.fit(df, event_col='E', duration_col='T') |
| 974 | + |
| 975 | + df_descaled = regression_dataset.copy() |
| 976 | + df_descaled[['var1', 'var2', 'var3']] = df_descaled[['var1', 'var2', 'var3']] / df_descaled[['var1', 'var2', 'var3']].std() |
| 977 | + cp2 = CoxPHFitter() |
| 978 | + cp2.fit(df_descaled, event_col='E', duration_col='T') |
| 979 | + assert_frame_equal(cp2.baseline_survival_, cp1.baseline_survival_) |
| 980 | + |
| 981 | + def test_survival_prediction_is_the_same_indp_of_scale(self, regression_dataset): |
| 982 | + df = regression_dataset.copy() |
| 983 | + |
| 984 | + df_scaled = regression_dataset.copy() |
| 985 | + df_scaled[['var1', 'var2', 'var3']] = df_scaled[['var1', 'var2', 'var3']] * 10.0 |
| 986 | + |
| 987 | + cp1 = CoxPHFitter() |
| 988 | + cp1.fit(df, event_col='E', duration_col='T') |
| 989 | + |
| 990 | + cp2 = CoxPHFitter() |
| 991 | + cp2.fit(df_scaled, event_col='E', duration_col='T') |
| 992 | + |
| 993 | + assert_frame_equal( |
| 994 | + cp1.predict_survival_function(df.ix[[0]][['var1', 'var2', 'var3']]), |
| 995 | + cp2.predict_survival_function(df_scaled.ix[[0]][['var1', 'var2', 'var3']]) |
| 996 | + ) |
| 997 | + |
928 | 998 | def test_predict_log_hazard_relative_to_mean(self, rossi): |
929 | 999 | cox = CoxPHFitter() |
930 | 1000 | cox.fit(rossi, 'week', 'arrest') |
931 | 1001 | log_relative_hazards = cox.predict_log_hazard_relative_to_mean(rossi) |
932 | 1002 | means = rossi.mean(0).to_frame().T |
933 | | - assert cox.predict_partial_hazard(means).values[0][0] != 1.0 |
934 | 1003 | assert_frame_equal(log_relative_hazards, np.log(cox.predict_partial_hazard(rossi) / cox.predict_partial_hazard(means).squeeze())) |
935 | 1004 |
|
936 | 1005 | def test_warning_is_raised_if_df_has_a_near_constant_column(self, rossi): |
|
0 commit comments