Skip to content

Commit d22c286

Browse files
Merge pull request #77 from CamDavidsonPilon/adding-conditional-time-to
Adding conditional time to
2 parents 38f3c09 + ffeb6ad commit d22c286

File tree

5 files changed

+129
-33
lines changed

5 files changed

+129
-33
lines changed

CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
### Changelogs
22

3-
3+
####0.4.3
4+
- refactoring of `qth_survival_times`: it can now accept an iterable (or a scalar still) of probabilities in the q argument, and will return a DataFrame with these as columns. If len(q)==1 and a single survival function is given, will return a scalar, not a DataFrame. Also some good speed improvements.
5+
- KaplanMeierFitter and NelsonAalenFitter now have a `_label` property that is passed in during the fit.
6+
- KaplanMeierFitter/NelsonAalenFitter's initial `alpha` value is overwritten if a new `alpha` value is passed
7+
in during the `fit`.
8+
- New method for KaplanMeierFitter: `conditional_time_to`. This returns a DataFrame of the estimate:
9+
med(S(t | T>s)) - s; in plain terms: the estimated median remaining lifetime, given that an individual has already survived to age s.
10+
- Adds option `include_likelihood` to CoxPHFitter fit method to save the final log-likelihood value.
411

512
####0.4.2
613

lifelines/estimation.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,13 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None,
8080
self._additive_f, self._variance_f, False)
8181

8282
# estimates
83-
self.cumulative_hazard_ = pd.DataFrame(cumulative_hazard_, columns=[label])
83+
self._label = label
84+
self.cumulative_hazard_ = pd.DataFrame(cumulative_hazard_, columns=[self._label])
8485
self.confidence_interval_ = self._bounds(cumulative_sq_[:, None], alpha if alpha else self.alpha, ci_labels)
8586
self._cumulative_sq = cumulative_sq_
8687

8788
# estimation functions
88-
self.predict = _predict(self, "cumulative_hazard_", label)
89+
self.predict = _predict(self, "cumulative_hazard_", self._label)
8990
self.subtract = _subtract(self, "cumulative_hazard_")
9091
self.divide = _divide(self, "cumulative_hazard_")
9192

@@ -99,10 +100,9 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None,
99100
def _bounds(self, cumulative_sq_, alpha, ci_labels):
100101
alpha2 = inv_normal_cdf(1 - (1 - alpha) / 2)
101102
df = pd.DataFrame(index=self.timeline)
102-
name = self.cumulative_hazard_.columns[0]
103103

104104
if ci_labels is None:
105-
ci_labels = ["%s_upper_%.2f" % (name, self.alpha), "%s_lower_%.2f" % (name, self.alpha)]
105+
ci_labels = ["%s_upper_%.2f" % (self._label, self.alpha), "%s_lower_%.2f" % (self._label, self.alpha)]
106106
assert len(ci_labels) == 2, "ci_labels should be a length 2 array."
107107
self.ci_labels = ci_labels
108108

@@ -206,7 +206,8 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None, label='
206206

207207
v = preprocess_inputs(durations, event_observed, timeline, entry)
208208
self.durations, self.event_observed, self.timeline, self.entry, self.event_table = v
209-
209+
self._label = label
210+
self.alpha = alpha if alpha else self.alpha
210211
log_survival_function, cumulative_sq_ = _additive_estimate(self.event_table, self.timeline,
211212
self._additive_f, self._additive_var,
212213
left_censorship)
@@ -219,12 +220,12 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None, label='
219220
net_population = (self.event_table['entrance'] - self.event_table['removed']).cumsum()
220221
if net_population.iloc[:int(n / 2)].min() == 0:
221222
ix = net_population.iloc[:int(n / 2)].argmin()
222-
raise StatError("""There are too few early truncation times and too many events. S(t)==0 for all t>%.1f. Recommend BFH estimator.""" % ix)
223+
raise StatError("""There are too few early truncation times and too many events. S(t)==0 for all t>%.1f. Recommend BreslowFlemingHarringtonFitter.""" % ix)
223224

224225
# estimation
225-
setattr(self, estimate_name, pd.DataFrame(np.exp(log_survival_function), columns=[label]))
226+
setattr(self, estimate_name, pd.DataFrame(np.exp(log_survival_function), columns=[self._label]))
226227
self.__estimate = getattr(self, estimate_name)
227-
self.confidence_interval_ = self._bounds(cumulative_sq_[:, None], alpha if alpha else self.alpha, ci_labels)
228+
self.confidence_interval_ = self._bounds(cumulative_sq_[:, None], ci_labels)
228229
self.median_ = median_survival_times(self.__estimate)
229230

230231
# estimation methods
@@ -237,15 +238,14 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None, label='
237238
setattr(self, "plot_" + estimate_name, self.plot)
238239
return self
239240

240-
def _bounds(self, cumulative_sq_, alpha, ci_labels):
241+
def _bounds(self, cumulative_sq_, ci_labels):
241242
# See http://courses.nus.edu.sg/course/stacar/internet/st3242/handouts/notes2.pdfg
242-
alpha2 = inv_normal_cdf((1. + alpha) / 2.)
243+
alpha2 = inv_normal_cdf((1. + self.alpha) / 2.)
243244
df = pd.DataFrame(index=self.timeline)
244-
name = self.__estimate.columns[0]
245245
v = np.log(self.__estimate.values)
246246

247247
if ci_labels is None:
248-
ci_labels = ["%s_upper_%.2f" % (name, self.alpha), "%s_lower_%.2f" % (name, self.alpha)]
248+
ci_labels = ["%s_upper_%.2f" % (self._label, self.alpha), "%s_lower_%.2f" % (self._label, self.alpha)]
249249
assert len(ci_labels) == 2, "ci_labels should be a length 2 array."
250250

251251
df[ci_labels[0]] = np.exp(-np.exp(np.log(-v) + alpha2 * np.sqrt(cumulative_sq_) / v))
@@ -260,6 +260,23 @@ def _additive_var(self, population, deaths):
260260
np.seterr(divide='ignore')
261261
return (1. * deaths / (population * (population - deaths))).replace([np.inf], 0)
262262

263+
def conditional_time_to(self):
264+
"""
265+
Return a DataFrame, with index equal to survival_function_, that estimates the median
266+
duration remaining until the death event, given survival up until time t. For example, if an
267+
individual exists until age 1, their expected life remaining *given they lived to time 1*
268+
might be 9 years.
269+
270+
Returns:
271+
conditional_time_to_: DataFrame, with index equal to survival_function_
272+
273+
"""
274+
age = self.survival_function_.index.values[:, None]
275+
columns = ['%s - Conditional time remaining to event' % self._label]
276+
return pd.DataFrame(qth_survival_times(self.survival_function_[self._label] * 0.5, self.survival_function_).T.sort(ascending=False).values,
277+
index=self.survival_function_.index,
278+
columns=columns) - age
279+
263280

264281
class BreslowFlemingHarringtonFitter(BaseFitter):
265282

@@ -868,7 +885,7 @@ def _get_efron_values(self, X, beta, T, E, include_likelihood=False):
868885
return hessian, gradient
869886

870887
def _newton_rhaphson(self, X, T, E, initial_beta=None, step_size=1.,
871-
epsilon=10e-5, show_progress=True):
888+
epsilon=10e-5, show_progress=True, include_likelihood=False):
872889
"""
873890
Newton-Raphson algorithm for fitting CPH model.
874891
@@ -883,6 +900,7 @@ def _newton_rhaphson(self, X, T, E, initial_beta=None, step_size=1.,
883900
step_size: 0 < float <= 1 to determine a step size in NR algorithm.
884901
epsilon: the convergence halts if the norm of delta between
885902
successive positions is less than epsilon.
903+
include_likelihood: saves the final log-likelihood to the CoxPHFitter under _log_likelihood.
886904
887905
Returns:
888906
beta: (1,d) numpy array.
@@ -914,7 +932,8 @@ def _newton_rhaphson(self, X, T, E, initial_beta=None, step_size=1.,
914932
i = 1
915933
converging = True
916934
while converging:
917-
hessian, gradient = get_gradients(X, beta, T, E)
935+
output = get_gradients(X, beta, T, E, include_likelihood=include_likelihood)
936+
hessian, gradient = output[:2]
918937
delta = solve(-hessian, step_size * gradient.T)
919938
beta = delta + beta
920939
if pd.isnull(delta).sum() > 1:
@@ -928,12 +947,14 @@ def _newton_rhaphson(self, X, T, E, initial_beta=None, step_size=1.,
928947

929948
self._hessian_ = hessian
930949
self._score_ = gradient
950+
if include_likelihood:
951+
self._log_likelihood = output[2]
931952
if show_progress:
932953
print("Convergence completed after %d iterations." % (i))
933954
return beta
934955

935956
def fit(self, df, duration_col='T', event_col='E',
936-
show_progress=False, initial_beta=None):
957+
show_progress=False, initial_beta=None, include_likelihood=False):
937958
"""
938959
Fit the Cox Proportional Hazard model to a dataset. Tied survival times
939960
are handled using Efron's tie-method.
@@ -951,6 +972,9 @@ def fit(self, df, duration_col='T', event_col='E',
951972
diagnostics.
952973
initial_beta: initialize the starting point of the iterative
953974
algorithm. Default is the zero vector.
975+
include_likelihood: saves the final log-likelihood to the CoxPHFitter under
976+
the property _log_likelihood.
977+
954978
955979
Returns:
956980
self, with additional properties: hazards_
@@ -969,7 +993,8 @@ def fit(self, df, duration_col='T', event_col='E',
969993
self._check_values(df)
970994

971995
hazards_ = self._newton_rhaphson(df, T, E, initial_beta=initial_beta,
972-
show_progress=show_progress)
996+
show_progress=show_progress,
997+
include_likelihood=include_likelihood)
973998

974999
self.hazards_ = pd.DataFrame(hazards_.T, columns=df.columns,
9751000
index=['coef'])
@@ -1047,7 +1072,7 @@ def predict_percentile(self, X, p=0.5):
10471072
Returns the median lifetimes for the individuals.
10481073
http://stats.stackexchange.com/questions/102986/percentile-loss-functions
10491074
"""
1050-
return qth_survival_times(0.5, self.predict_survival_function(X))
1075+
return qth_survival_times(p, self.predict_survival_function(X))[p]
10511076

10521077
def predict_median(self, X):
10531078
"""
@@ -1201,17 +1226,25 @@ def qth_survival_times(q, survival_functions):
12011226
If numpy array, will return indices.
12021227
12031228
Returns:
1204-
v: an array containing the first times the value was crossed.
1205-
np.inf if infinity.
1229+
v: if d==1, returns a float, np.inf if infinity.
1230+
if d > 1, a DataFrame containing the first times the value was crossed.
1231+
1232+
"""
1233+
q = pd.Series(q)
1234+
survival_functions = pd.DataFrame(survival_functions)
1235+
if survival_functions.shape[1] == 1 and q.shape == (1,):
1236+
return survival_functions.apply(lambda s: qth_survival_time(q[0], s)).ix[0]
1237+
else:
1238+
return pd.DataFrame({_q: survival_functions.apply(lambda s: qth_survival_time(_q, s)) for _q in q})
1239+
1240+
1241+
def qth_survival_time(q, survival_function):
1242+
"""
1243+
Expects a Pandas series, returns the time when the qth probability is reached.
12061244
"""
1207-
assert 0. <= q <= 1., "q must be between 0. and 1."
1208-
sv_b = (1.0 * (survival_functions < q)).cumsum() > 0
1209-
try:
1210-
v = sv_b.idxmax(0)
1211-
v[sv_b.iloc[-1,:] == 0] = np.inf
1212-
except:
1213-
v = sv_b.argmax(0)
1214-
v[sv_b[-1,:] == 0] = np.inf
1245+
if survival_function.iloc[-1] > q:
1246+
return np.inf
1247+
v = (survival_function <= q).idxmax(0)
12151248
return v
12161249

12171250

lifelines/tests/test_suite.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
from collections import Counter
1818
import matplotlib.pyplot as plt
1919
import pandas as pd
20+
from pandas.util.testing import assert_frame_equal
2021

2122
from ..estimation import KaplanMeierFitter, NelsonAalenFitter, AalenAdditiveFitter, \
2223
median_survival_times, BreslowFlemingHarringtonFitter, BayesianFitter, \
23-
CoxPHFitter
24+
CoxPHFitter, qth_survival_times, qth_survival_time
2425

2526
from ..statistics import (logrank_test, multivariate_logrank_test,
2627
pairwise_logrank_test, concordance_index)
@@ -33,6 +34,44 @@
3334

3435
class MiscTests(unittest.TestCase):
3536

37+
def test_qth_survival_times_with_varying_datatype_inputs(self):
38+
sf_list = [1.0, 0.75, 0.5, 0.25, 0.0]
39+
sf_array = np.array([1.0, 0.75, 0.5, 0.25, 0.0])
40+
sf_df_no_index = pd.DataFrame([1.0, 0.75, 0.5, 0.25, 0.0])
41+
sf_df_index = pd.DataFrame([1.0, 0.75, 0.5, 0.25, 0.0], index=[10, 20, 30, 40, 50])
42+
sf_series_index = pd.Series([1.0, 0.75, 0.5, 0.25, 0.0], index=[10, 20, 30, 40, 50])
43+
sf_series_no_index = pd.Series([1.0, 0.75, 0.5, 0.25, 0.0])
44+
45+
q = 0.5
46+
47+
assert qth_survival_times(q, sf_list) == 2
48+
assert qth_survival_times(q, sf_array) == 2
49+
assert qth_survival_times(q, sf_df_no_index) == 2
50+
assert qth_survival_times(q, sf_df_index) == 30
51+
assert qth_survival_times(q, sf_series_index) == 30
52+
assert qth_survival_times(q, sf_series_no_index) == 2
53+
54+
def test_qth_survival_times_multi_dim_input(self):
55+
sf = np.linspace(1, 0, 50)
56+
sf_multi_df = pd.DataFrame({'sf': sf, 'sf**2': sf ** 2})
57+
58+
medians = qth_survival_times(0.5, sf_multi_df)
59+
assert medians.ix['sf'][0.5] == 25
60+
assert medians.ix['sf**2'][0.5] == 15
61+
62+
def test_qth_survival_time_returns_inf(self):
63+
sf = pd.Series([1., 0.7, 0.6])
64+
assert qth_survival_time(0.5, sf) == np.inf
65+
66+
def test_qth_survival_times_with_multivariate_q(self):
67+
sf = np.linspace(1, 0, 50)
68+
sf_multi_df = pd.DataFrame({'sf': sf, 'sf**2': sf ** 2})
69+
70+
assert_frame_equal(qth_survival_times([0.2, 0.5], sf_multi_df), pd.DataFrame([[40, 25], [28, 15]], columns=[0.2, 0.5], index=['sf', 'sf**2']))
71+
assert_frame_equal(qth_survival_times([0.2, 0.5], sf_multi_df['sf']), pd.DataFrame([[40, 25]], columns=[0.2, 0.5], index=['sf']))
72+
assert_frame_equal(qth_survival_times(0.5, sf_multi_df), pd.DataFrame([[25], [15]], columns=[0.5], index=['sf', 'sf**2']))
73+
assert qth_survival_times(0.5, sf_multi_df['sf']) == 25
74+
3675
def test_datetimes_to_durations_days(self):
3776
start_date = ['2013-10-10 0:00:00', '2013-10-09', '2012-10-10']
3877
end_date = ['2013-10-13', '2013-10-10 0:00:00', '2013-10-15']
@@ -108,6 +147,19 @@ def test_cross_validator_with_predictor_and_kwargs(self):
108147
duration_col='T', event_col='E', k=3,
109148
predictor="predict_percentile", predictor_kwargs={'p': 0.6})
110149

150+
def test_label_is_a_property(self):
151+
kmf = KaplanMeierFitter()
152+
kmf.fit(LIFETIMES, label='Test Name')
153+
assert kmf._label == 'Test Name'
154+
assert kmf.confidence_interval_.columns[0] == 'Test Name_upper_0.95'
155+
assert kmf.confidence_interval_.columns[1] == 'Test Name_lower_0.95'
156+
157+
naf = NelsonAalenFitter()
158+
naf.fit(LIFETIMES, label='Test Name')
159+
assert naf._label == 'Test Name'
160+
assert naf.confidence_interval_.columns[0] == 'Test Name_upper_0.95'
161+
assert naf.confidence_interval_.columns[1] == 'Test Name_lower_0.95'
162+
111163

112164
class StatisticalTests(unittest.TestCase):
113165

@@ -140,7 +192,7 @@ def test_censor_kaplan_meier(self):
140192

141193
def test_median(self):
142194
sv = pd.DataFrame(1 - np.linspace(0, 1, 1000))
143-
self.assertTrue(median_survival_times(sv).ix[0] == 500)
195+
self.assertTrue(median_survival_times(sv) == 500)
144196

145197
def test_not_to_break(self):
146198
try:
@@ -731,6 +783,11 @@ def test_flat_style_no_censor(self):
731783

732784
class CoxRegressionTests(unittest.TestCase):
733785

786+
def test_log_likelihood_is_available_in_output(self):
787+
cox = CoxPHFitter()
788+
cox.fit(data_nus, duration_col='t', event_col='E', include_likelihood=True)
789+
assert abs( cox._log_likelihood - -12.7601409152 ) < 0.001
790+
734791
def test_efron_computed_by_hand_examples(self):
735792
cox = CoxPHFitter()
736793

lifelines/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from datetime import datetime
55

66
import numpy as np
7-
from numpy.random import permutation
87
import pandas as pd
98
from pandas import to_datetime
109

@@ -217,7 +216,7 @@ def datetimes_to_durations(start_times, end_times, fill_date=datetime.today(), f
217216
end_times = pd.Series(end_times).copy()
218217
start_times_ = to_datetime(start_times, dayfirst=dayfirst)
219218

220-
C = ~(pd.isnull(end_times).values + (end_times == "") + (end_times == na_values))
219+
C = ~(pd.isnull(end_times).values | (end_times == "") | (end_times == na_values))
221220
end_times[~C] = fill_date
222221
"""
223222
c = (to_datetime(end_times, dayfirst=dayfirst, coerce=True) > fill_date)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def read(fname):
2222

2323
setup(
2424
name="lifelines",
25-
version="0.4.2",
25+
version="0.4.3",
2626
author="Cameron Davidson-Pilon",
2727
author_email="[email protected]",
2828
description="Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression",

0 commit comments

Comments
 (0)