Skip to content

Commit 8e3272a

Browse files
Merge pull request #295 from CamDavidsonPilon/fix-scaling-in-cox-ph-prediction
Fix scaling in cox ph prediction
2 parents 0fe0d10 + a54a3c9 commit 8e3272a

File tree

4 files changed

+83
-10
lines changed

4 files changed

+83
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
### Changelogs
22

3+
#### 0.10.1
4+
- fixed the internal normalization used by the `CoxPHFitter` predict methods.
5+
36
#### 0.10.0
47
- corrected bug that was returning the wrong baseline survival and hazard values in `CoxPHFitter` when `normalize=True`.
58
- removed the `normalize` kwarg from `CoxPHFitter`. It was causing a lot of confusion for users and added code complexity, so it's really nice to be able to remove it.

lifelines/fitters/coxph_fitter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def fit(self, df, duration_col, event_col=None,
312312
self.durations = T
313313
self.event_observed = E
314314

315-
self.baseline_hazard_ = self._compute_baseline_hazards(normalize(df, 0, 1 / self._norm_std), T, E)
315+
self.baseline_hazard_ = self._compute_baseline_hazards(df * self._norm_std + self._norm_mean, T, E)
316316
self.baseline_cumulative_hazard_ = self.baseline_hazard_.cumsum()
317317
self.baseline_survival_ = self._compute_baseline_survival()
318318
return self
@@ -427,6 +427,7 @@ def predict_log_partial_hazard(self, X):
427427
X = X[order]
428428

429429
index = _get_index(X)
430+
X = normalize(X, self._norm_mean.values, 1)
430431
return pd.DataFrame(np.dot(X, self.hazards_.T), index=index)
431432

432433
def predict_log_hazard_relative_to_mean(self, X):

lifelines/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import unicode_literals
22

3-
__version__ = '0.10.0'
3+
__version__ = '0.10.1'

tests/test_estimation.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
NelsonAalenFitter, BreslowFlemingHarringtonFitter, ExponentialFitter, \
2222
WeibullFitter, BaseFitter
2323
from lifelines.datasets import load_larynx, load_waltons, load_kidney_transplant, load_rossi,\
24-
load_lcd, load_panel_test, load_g3, load_holly_molly_polly
24+
load_lcd, load_panel_test, load_g3, load_holly_molly_polly, load_regression_dataset
2525
from lifelines.generate_datasets import generate_hazard_rates, generate_random_lifetimes, cumulative_integral
2626
from lifelines.utils import concordance_index
2727

@@ -93,6 +93,10 @@ def rossi():
9393
return load_rossi()
9494

9595

96+
@pytest.fixture
97+
def regression_dataset():
98+
return load_regression_dataset()
99+
96100

97101
class TestBaseFitter():
98102

@@ -683,16 +687,14 @@ def test_fit_method(self, data_nus):
683687
assert np.abs(cf.hazards_.ix[0][0] - -0.0335) < 0.0001
684688

685689
def test_using_dataframes_vs_numpy_arrays(self, data_pred2):
686-
# First without normalization
687690
cf = CoxPHFitter()
688691
cf.fit(data_pred2, 't', 'E')
689692

690693
X = data_pred2[cf.data.columns]
691-
hazards = cf.predict_partial_hazard(X)
692-
693-
# A Numpy array should return the same result
694-
hazards_n = cf.predict_partial_hazard(np.array(X))
695-
assert np.all(hazards == hazards_n)
694+
assert_frame_equal(
695+
cf.predict_partial_hazard(np.array(X)),
696+
cf.predict_partial_hazard(X)
697+
)
696698

697699
def test_data_normalization(self, data_pred2):
698700
# During fit, CoxPH copies the training data and normalizes it.
@@ -925,12 +927,79 @@ def test_hazard_works_as_intended_with_strata_against_R_output(self, rossi):
925927
npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 0)].ix[[14, 35, 37, 43, 52]].values, [0.076600555, 0.169748261, 0.272088807, 0.396562717, 0.396562717], decimal=2)
926928
npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 1)].ix[[27, 43, 48, 52]].values, [0.095499001, 0.204196905, 0.338393113, 0.338393113], decimal=2)
927929

930+
def test_baseline_survival_is_the_same_indp_of_location(self, regression_dataset):
931+
df = regression_dataset.copy()
932+
cp1 = CoxPHFitter()
933+
cp1.fit(df, event_col='E', duration_col='T')
934+
935+
df_demeaned = regression_dataset.copy()
936+
df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - df_demeaned[['var1', 'var2', 'var3']].mean()
937+
cp2 = CoxPHFitter()
938+
cp2.fit(df_demeaned, event_col='E', duration_col='T')
939+
assert_frame_equal(cp2.baseline_survival_, cp1.baseline_survival_)
940+
941+
def test_baseline_cumulative_hazard_is_the_same_indp_of_location(self, regression_dataset):
942+
df = regression_dataset.copy()
943+
cp1 = CoxPHFitter()
944+
cp1.fit(df, event_col='E', duration_col='T')
945+
946+
df_demeaned = regression_dataset.copy()
947+
df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - df_demeaned[['var1', 'var2', 'var3']].mean()
948+
cp2 = CoxPHFitter()
949+
cp2.fit(df_demeaned, event_col='E', duration_col='T')
950+
assert_frame_equal(cp2.baseline_cumulative_hazard_, cp1.baseline_cumulative_hazard_)
951+
952+
def test_survival_prediction_is_the_same_indp_of_location(self, regression_dataset):
953+
df = regression_dataset.copy()
954+
955+
df_demeaned = regression_dataset.copy()
956+
mean = df_demeaned[['var1', 'var2', 'var3']].mean()
957+
df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - mean
958+
959+
cp1 = CoxPHFitter()
960+
cp1.fit(df, event_col='E', duration_col='T')
961+
962+
cp2 = CoxPHFitter()
963+
cp2.fit(df_demeaned, event_col='E', duration_col='T')
964+
965+
assert_frame_equal(
966+
cp1.predict_survival_function(df.ix[[0]][['var1', 'var2', 'var3']]),
967+
cp2.predict_survival_function(df_demeaned.ix[[0]][['var1', 'var2', 'var3']])
968+
)
969+
970+
def test_baseline_survival_is_the_same_indp_of_scale(self, regression_dataset):
971+
df = regression_dataset.copy()
972+
cp1 = CoxPHFitter()
973+
cp1.fit(df, event_col='E', duration_col='T')
974+
975+
df_descaled = regression_dataset.copy()
976+
df_descaled[['var1', 'var2', 'var3']] = df_descaled[['var1', 'var2', 'var3']] / df_descaled[['var1', 'var2', 'var3']].std()
977+
cp2 = CoxPHFitter()
978+
cp2.fit(df_descaled, event_col='E', duration_col='T')
979+
assert_frame_equal(cp2.baseline_survival_, cp1.baseline_survival_)
980+
981+
def test_survival_prediction_is_the_same_indp_of_scale(self, regression_dataset):
982+
df = regression_dataset.copy()
983+
984+
df_scaled = regression_dataset.copy()
985+
df_scaled[['var1', 'var2', 'var3']] = df_scaled[['var1', 'var2', 'var3']] * 10.0
986+
987+
cp1 = CoxPHFitter()
988+
cp1.fit(df, event_col='E', duration_col='T')
989+
990+
cp2 = CoxPHFitter()
991+
cp2.fit(df_scaled, event_col='E', duration_col='T')
992+
993+
assert_frame_equal(
994+
cp1.predict_survival_function(df.ix[[0]][['var1', 'var2', 'var3']]),
995+
cp2.predict_survival_function(df_scaled.ix[[0]][['var1', 'var2', 'var3']])
996+
)
997+
928998
def test_predict_log_hazard_relative_to_mean(self, rossi):
929999
cox = CoxPHFitter()
9301000
cox.fit(rossi, 'week', 'arrest')
9311001
log_relative_hazards = cox.predict_log_hazard_relative_to_mean(rossi)
9321002
means = rossi.mean(0).to_frame().T
933-
assert cox.predict_partial_hazard(means).values[0][0] != 1.0
9341003
assert_frame_equal(log_relative_hazards, np.log(cox.predict_partial_hazard(rossi) / cox.predict_partial_hazard(means).squeeze()))
9351004

9361005
def test_warning_is_raised_if_df_has_a_near_constant_column(self, rossi):

0 commit comments

Comments
 (0)