Commit a43f62d

V0.18.5 (#633)
* v0.18.5
* docs fix kmf and naf plotting
* bump version
* lint
* better image
* fix plotting test
1 parent 5a09f11 commit a43f62d

17 files changed, 523 additions and 190 deletions

.travis.yml

Lines changed: 0 additions & 2 deletions
@@ -22,8 +22,6 @@ script:
   - make test
 after_success:
   - coveralls
-  # run linter but don't fail for errors
-  - make lint
 # Don't want notifications
 notifications:
   email: false

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,12 @@
 ### Changelogs

+### 0.18.5
+ - added new plotting methods to parametric univariate models: `plot_survival_function`, `plot_hazard` and `plot_cumulative_hazard`. The last one is an alias for `plot`.
+ - added new properties to parametric univariate models: `confidence_interval_survival_function_`, `confidence_interval_hazard_`, `confidence_interval_cumulative_hazard_`. The last one is an alias for `confidence_interval_`.
+ - Fixed some overflow issues with `AalenJohansenFitter`'s variance calculations when using large datasets.
+ - Fixed an edge case in `AalenJohansenFitter` that caused some datasets to be jittered too often.
+ - Added a new kwarg to `AalenJohansenFitter`, `calculate_variance`, that can be used to turn off variance calculations, since these can take a long time for large datasets. Thanks @pzivich!
+
 ### 0.18.4
  - fixed confidence intervals in cumulative hazards for parametric univariate models. They were previously
    severely depressed.

Makefile

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ lint:
 ifeq ($(TRAVIS_PYTHON_VERSION), 2.7)
 	echo "Skip linting for Python2.7"
 else
+	black lifelines/ -l 120 --fast
+	black tests/ -l 120 --fast
 	prospector --output-format grouped
 endif

docs/Survival analysis with lifelines.rst

Lines changed: 34 additions & 6 deletions
@@ -570,17 +570,45 @@ Similarly, there are other parametric models in *lifelines*. Generally, which pa
     llf = LogLogisticFitter().fit(T, E, label='LogLogisticFitter')
     pwf = PiecewiseExponentialFitter([40, 60]).fit(T, E, label='PiecewiseExponentialFitter')

-    wbf.plot(ax=axes[0][0])
-    exf.plot(ax=axes[0][1])
-    lnf.plot(ax=axes[0][2])
-    naf.plot(ax=axes[1][0])
-    llf.plot(ax=axes[1][1])
-    pwf.plot(ax=axes[1][2])
+    wbf.plot_cumulative_hazard(ax=axes[0][0])
+    exf.plot_cumulative_hazard(ax=axes[0][1])
+    lnf.plot_cumulative_hazard(ax=axes[0][2])
+    naf.plot_cumulative_hazard(ax=axes[1][0])
+    llf.plot_cumulative_hazard(ax=axes[1][1])
+    pwf.plot_cumulative_hazard(ax=axes[1][2])

 .. image:: images/waltons_cumulative_hazard.png

 *lifelines* can also be used to define your own parametric model. There is a tutorial on this available, see `Piecewise Exponential Models and Creating Custom Models`_.

+Parametric models can also be used to create and plot the survival function. Below we compare the parametric models with the non-parametric Kaplan-Meier estimate:
+
+.. code:: python
+
+    from lifelines import KaplanMeierFitter
+
+    fig, axes = plt.subplots(2, 3, figsize=(9, 5))
+
+    T = data['T']
+    E = data['E']
+
+    kmf = KaplanMeierFitter().fit(T, E, label='KaplanMeierFitter')
+    wbf = WeibullFitter().fit(T, E, label='WeibullFitter')
+    exf = ExponentialFitter().fit(T, E, label='ExponentialFitter')
+    lnf = LogNormalFitter().fit(T, E, label='LogNormalFitter')
+    llf = LogLogisticFitter().fit(T, E, label='LogLogisticFitter')
+    pwf = PiecewiseExponentialFitter([40, 60]).fit(T, E, label='PiecewiseExponentialFitter')
+
+    wbf.plot_survival_function(ax=axes[0][0])
+    exf.plot_survival_function(ax=axes[0][1])
+    lnf.plot_survival_function(ax=axes[0][2])
+    kmf.plot_survival_function(ax=axes[1][0])
+    llf.plot_survival_function(ax=axes[1][1])
+    pwf.plot_survival_function(ax=axes[1][2])
+
+.. image:: images/waltons_survival_function.png
+
+
 Other types of censoring
 ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

docs/conf.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
 #
 # The short X.Y version.

-version = "0.18.4"
+version = "0.18.5"
 # The full version, including dev info
 release = version

docs/jupyter_notebooks/Modelling time-lagged conversion rates.ipynb

Lines changed: 192 additions & 42 deletions
Large diffs are not rendered by default.

docs/jupyter_notebooks/Piecewise Exponential Models and Creating Custom Models.ipynb

Lines changed: 75 additions & 47 deletions
Large diffs are not rendered by default.

lifelines/fitters/__init__.py

Lines changed: 101 additions & 27 deletions
@@ -20,7 +20,7 @@
 from numpy.linalg import inv, pinv


-from lifelines.plotting import plot_estimate
+from lifelines.plotting import _plot_estimate
 from lifelines.utils import (
     qth_survival_times,
     _to_array,
@@ -84,18 +84,20 @@ def _update_docstrings(self):
                 self._estimate_name, self.__class__.__name__
             )
             self.__class__.predict.__func__.__doc__ = self.predict.__doc__.format(self.__class__.__name__)
-            self.__class__.plot.__func__.__doc__ = plot_estimate.__doc__.format(
+            self.__class__.plot.__func__.__doc__ = _plot_estimate.__doc__.format(
                 self.__class__.__name__, self._estimate_name
             )
         elif PY3:
             self.__class__.subtract.__doc__ = self.subtract.__doc__.format(self._estimate_name, self.__class__.__name__)
             self.__class__.divide.__doc__ = self.divide.__doc__.format(self._estimate_name, self.__class__.__name__)
             self.__class__.predict.__doc__ = self.predict.__doc__.format(self.__class__.__name__)
-            self.__class__.plot.__doc__ = plot_estimate.__doc__.format(self.__class__.__name__, self._estimate_name)
+            self.__class__.plot.__doc__ = _plot_estimate.__doc__.format(self.__class__.__name__, self._estimate_name)

     @_must_call_fit_first
-    def plot(self, *args, **kwargs):
-        return plot_estimate(self, *args, **kwargs)
+    def plot(self, **kwargs):
+        return _plot_estimate(
+            self, estimate=getattr(self, self._estimate_name), confidence_intervals=self.confidence_interval_, **kwargs
+        )

     @_must_call_fit_first
     def subtract(self, other):
@@ -204,17 +206,29 @@ def _conditional_time_to_event_(self):
         )

     @_must_call_fit_first
-    def hazard_at_times(self, times):
+    def hazard_at_times(self, times, label=None):
         raise NotImplementedError

     @_must_call_fit_first
-    def survival_function_at_times(self, times):
+    def survival_function_at_times(self, times, label=None):
         raise NotImplementedError

     @_must_call_fit_first
-    def cumulative_hazard_at_times(self, times):
+    def cumulative_hazard_at_times(self, times, label=None):
         raise NotImplementedError

+    @_must_call_fit_first
+    def plot_cumulative_hazard(self, **kwargs):
+        raise NotImplementedError()
+
+    @_must_call_fit_first
+    def plot_survival_function(self, **kwargs):
+        raise NotImplementedError()
+
+    @_must_call_fit_first
+    def plot_hazard(self, **kwargs):
+        raise NotImplementedError()
+

 class ParametericUnivariateFitter(UnivariateFitter):
     """
@@ -228,7 +242,6 @@ class ParametericUnivariateFitter(UnivariateFitter):
     def __init__(self, *args, **kwargs):
         super(ParametericUnivariateFitter, self).__init__(*args, **kwargs)
         self._estimate_name = "cumulative_hazard_"
-        self.plot_cumulative_hazard = self.plot
         if not hasattr(self, "_hazard"):
             # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
             self._hazard = egrad(self._cumulative_hazard, argnum=1)
@@ -302,9 +315,9 @@ def _buffer_bounds(self, bounds):
             if lb is None and ub is None:
                 yield (None, None)
             elif lb is None:
-                yield (None, self._MIN_PARAMETER_VALUE)
+                yield (None, ub - self._MIN_PARAMETER_VALUE)
             elif ub is None:
-                yield (self._MIN_PARAMETER_VALUE, None)
+                yield (lb + self._MIN_PARAMETER_VALUE, None)
             else:
                 yield (lb + self._MIN_PARAMETER_VALUE, ub - self._MIN_PARAMETER_VALUE)
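The `_buffer_bounds` hunk changes how one-sided bounds are buffered: the finite bound is now nudged inward by `_MIN_PARAMETER_VALUE` instead of being replaced by that constant. A minimal standalone sketch of the corrected behaviour follows. The module-level `MIN_PARAMETER_VALUE` and its value of `1e-9` are assumptions made for illustration; the value lifelines uses is not shown in this diff.

```python
MIN_PARAMETER_VALUE = 1e-9  # assumed value, for illustration only


def buffer_bounds(bounds):
    """Shrink each finite bound inward by a small epsilon; leave None (unbounded) alone."""
    for lb, ub in bounds:
        new_lb = None if lb is None else lb + MIN_PARAMETER_VALUE
        new_ub = None if ub is None else ub - MIN_PARAMETER_VALUE
        yield (new_lb, new_ub)


buffered = list(buffer_bounds([(None, None), (None, 5.0), (2.0, None), (0.0, 1.0)]))
print(buffered)
```

Under the removed lines, a bound such as `(2.0, None)` would have become `(_MIN_PARAMETER_VALUE, None)`, which loosens the lower bound rather than tightening it slightly.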

@@ -327,13 +340,31 @@ def _negative_log_likelihood(self, params, T, E, entry):
         return -ll / n

     def _compute_confidence_bounds_of_cumulative_hazard(self, alpha, ci_labels):
+        return self._compute_confidence_bounds_of_transform(self._cumulative_hazard, alpha, ci_labels)
+
+    def _compute_confidence_bounds_of_transform(self, transform, alpha, ci_labels):
+        """
+        This computes the confidence intervals of a transform of the parameters. Ex: take
+        the fitted parameters, a function/transform and the variance matrix and give me
+        back confidence intervals of the transform.
+
+        Parameters
+        -----------
+        transform: function
+            must be a function of two parameters:
+                ``params``, an iterable that stores the parameters
+                ``times``, a numpy vector representing some timeline
+            the function must use autograd imports (scipy and numpy)
+        alpha: float
+            confidence level
+        ci_labels: tuple
+
+        """
         alpha2 = inv_normal_cdf((1.0 + alpha) / 2.0)
         df = pd.DataFrame(index=self.timeline)

         # pylint: disable=no-value-for-parameter
-        gradient_of_cum_hazard_at_mle = make_jvp_reversemode(self._cumulative_hazard)(
-            self._fitted_parameters_, self.timeline
-        )
+        gradient_of_cum_hazard_at_mle = make_jvp_reversemode(transform)(self._fitted_parameters_, self.timeline)

         gradient_at_times = np.vstack(
             [gradient_of_cum_hazard_at_mle(basis) for basis in np.eye(len(self._fitted_parameters_))]
@@ -346,8 +377,9 @@ def _compute_confidence_bounds_of_cumulative_hazard(self, alpha, ci_labels):
         if ci_labels is None:
             ci_labels = ["%s_upper_%.2f" % (self._label, alpha), "%s_lower_%.2f" % (self._label, alpha)]
         assert len(ci_labels) == 2, "ci_labels should be a length 2 array."
-        df[ci_labels[0]] = self.cumulative_hazard_at_times(self.timeline) + alpha2 * std_cumulative_hazard
-        df[ci_labels[1]] = self.cumulative_hazard_at_times(self.timeline) - alpha2 * std_cumulative_hazard
+
+        df[ci_labels[0]] = transform(self._fitted_parameters_, self.timeline) + alpha2 * std_cumulative_hazard
+        df[ci_labels[1]] = transform(self._fitted_parameters_, self.timeline) - alpha2 * std_cumulative_hazard
         return df

     def _fit_model(self, T, E, entry, show_progress=True):
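`_compute_confidence_bounds_of_transform` reads like an application of the delta method: the gradient of the transform with respect to the parameters, evaluated at the fitted values, is combined with the parameter variance matrix to get a standard error at each time point. The sketch below illustrates that idea using only the standard library. It swaps autograd's `make_jvp_reversemode` for central finite differences, and the function name `delta_method_bounds`, the exponential example, and the variance value are all hypothetical choices for illustration, not lifelines code. As in the diff, `alpha` is treated as the confidence level (for example 0.95).

```python
import math
from statistics import NormalDist


def delta_method_bounds(transform, params, variance_matrix, times, alpha=0.95, eps=1e-6):
    """Return a list of (time, lower, upper) bounds for transform(params, t)."""
    z = NormalDist().inv_cdf((1.0 + alpha) / 2.0)
    rows = []
    for t in times:
        # Central finite-difference gradient with respect to each parameter.
        grad = []
        for i in range(len(params)):
            up = list(params)
            down = list(params)
            up[i] += eps
            down[i] -= eps
            grad.append((transform(up, t) - transform(down, t)) / (2.0 * eps))
        # Delta method: Var[f(theta_hat)] is approximately grad^T * Sigma * grad.
        var = sum(
            grad[i] * variance_matrix[i][j] * grad[j]
            for i in range(len(params))
            for j in range(len(params))
        )
        std = math.sqrt(var)
        estimate = transform(params, t)
        rows.append((t, estimate - z * std, estimate + z * std))
    return rows


def exponential_cumulative_hazard(params, t):
    # Exponential model with scale parameter lambda_: H(t) = t / lambda_.
    (lambda_,) = params
    return t / lambda_


bounds = delta_method_bounds(exponential_cumulative_hazard, [10.0], [[4.0]], times=[1.0, 5.0])
for t, lower, upper in bounds:
    print(t, round(lower, 4), round(upper, 4))
```

Because the transform is passed in as an argument, the same routine can serve the hazard, cumulative hazard, and survival function, which appears to be the motivation for the refactor in this hunk.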
@@ -538,7 +570,8 @@ def fit(
             self.timeline = np.linspace(self.durations.min(), self.durations.max(), self.durations.shape[0])

         self._label = label
-        alpha = alpha if alpha is not None else self.alpha
+        self._ci_labels = ci_labels
+        self.alpha = coalesce(alpha, self.alpha)

         # estimation
         self._fitted_parameters_, self._log_likelihood, self._hessian_ = self._fit_model(
@@ -576,30 +609,71 @@ def fit(
         self._predict_label = label
         self._update_docstrings()

-        self.survival_function_ = self.survival_function_at_times(self.timeline).to_frame(name=self._label)
-        self.hazard_ = self.hazard_at_times(self.timeline).to_frame(self._label)
-        self.cumulative_hazard_ = self.cumulative_hazard_at_times(self.timeline).to_frame(self._label)
+        self.survival_function_ = self.survival_function_at_times(self.timeline).to_frame()
+        self.hazard_ = self.hazard_at_times(self.timeline).to_frame()
+        self.cumulative_hazard_ = self.cumulative_hazard_at_times(self.timeline).to_frame()

-        self.confidence_interval_ = self._compute_confidence_bounds_of_cumulative_hazard(alpha, ci_labels)
         return self

     @_must_call_fit_first
-    def survival_function_at_times(self, times):
-        return pd.Series(self._survival_function(self._fitted_parameters_, times), index=_to_array(times))
+    def survival_function_at_times(self, times, label=None):
+        label = coalesce(label, self._label)
+        return pd.Series(self._survival_function(self._fitted_parameters_, times), index=_to_array(times), name=label)

     @_must_call_fit_first
-    def cumulative_hazard_at_times(self, times):
-        return pd.Series(self._cumulative_hazard(self._fitted_parameters_, times), index=_to_array(times))
+    def cumulative_hazard_at_times(self, times, label=None):
+        label = coalesce(label, self._label)
+        return pd.Series(self._cumulative_hazard(self._fitted_parameters_, times), index=_to_array(times), name=label)

     @_must_call_fit_first
-    def hazard_at_times(self, times):
-        return pd.Series(self._hazard(self._fitted_parameters_, times), index=_to_array(times))
+    def hazard_at_times(self, times, label=None):
+        label = coalesce(label, self._label)
+        return pd.Series(self._hazard(self._fitted_parameters_, times), index=_to_array(times), name=label)

     @property
     @_must_call_fit_first
     def median_(self):
         return median_survival_times(self.survival_function_)

+    @property
+    @_must_call_fit_first
+    def confidence_interval_(self):
+        return self._compute_confidence_bounds_of_cumulative_hazard(self.alpha, self._ci_labels)
+
+    @property
+    @_must_call_fit_first
+    def confidence_interval_cumulative_hazard_(self):
+        return self.confidence_interval_
+
+    @property
+    @_must_call_fit_first
+    def confidence_interval_hazard_(self):
+        return self._compute_confidence_bounds_of_transform(self._hazard, self.alpha, self._ci_labels)
+
+    @property
+    @_must_call_fit_first
+    def confidence_interval_survival_function_(self):
+        return self._compute_confidence_bounds_of_transform(self._survival_function, self.alpha, self._ci_labels)
+
+    @_must_call_fit_first
+    def plot_cumulative_hazard(self, **kwargs):
+        return self.plot(**kwargs)
+
+    @_must_call_fit_first
+    def plot_survival_function(self, **kwargs):
+        return _plot_estimate(
+            self,
+            estimate=getattr(self, "survival_function_"),
+            confidence_intervals=self.confidence_interval_survival_function_,
+            **kwargs
+        )
+
+    @_must_call_fit_first
+    def plot_hazard(self, **kwargs):
+        return _plot_estimate(
+            self, estimate=getattr(self, "hazard_"), confidence_intervals=self.confidence_interval_hazard_, **kwargs
+        )
+

 class KnownModelParametericUnivariateFitter(ParametericUnivariateFitter):
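This last hunk makes two design changes: confidence intervals become properties computed on access instead of attributes set in `fit`, and the `*_at_times` methods accept an optional `label` that falls back to the fitted label through `coalesce`. The toy sketch below mirrors that shape without depending on lifelines. `ToyFitter`, its constant rate, and the placeholder interval width are hypothetical, and `coalesce` is a local stand-in assumed to return its first non-`None` argument.

```python
def coalesce(*args):
    """Return the first argument that is not None (assumed behaviour of the helper)."""
    return next(arg for arg in args if arg is not None)


class ToyFitter:
    """Hypothetical stand-in for a parametric fitter with a constant hazard rate."""

    def __init__(self, alpha=0.95):
        self.alpha = alpha

    def fit(self, rate, label=None, alpha=None):
        self._rate = rate
        self._label = coalesce(label, "ToyFitter_estimate")
        # Store the settings only; intervals are derived later, on access.
        self.alpha = coalesce(alpha, self.alpha)
        return self

    def cumulative_hazard_at_times(self, times, label=None):
        # An explicit label wins; otherwise fall back to the label given to fit().
        label = coalesce(label, self._label)
        return {"name": label, "values": [self._rate * t for t in times]}

    @property
    def confidence_interval_(self):
        # Recomputed on every read, so it always reflects the current self.alpha.
        half_width = 0.1 * self.alpha  # placeholder width, not a real interval
        return (self._rate - half_width, self._rate + half_width)

    @property
    def confidence_interval_cumulative_hazard_(self):
        # Alias for confidence_interval_, mirroring the diff.
        return self.confidence_interval_


fitter = ToyFitter().fit(rate=0.5, label="toy")
print(fitter.cumulative_hazard_at_times([1, 2])["name"])
print(fitter.cumulative_hazard_at_times([1, 2], label="custom")["name"])
print(fitter.confidence_interval_cumulative_hazard_ == fitter.confidence_interval_)
```

One trade-off of the property approach is that the interval is recomputed on each access, including each call to a plotting method, in exchange for not paying that cost inside `fit`.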
