Skip to content

Commit ec73bd7

Browse files
fix some tests and use initial conditions
1 parent 9145fbb commit ec73bd7

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

lifelines/fitters/piecewise_exponential_fitter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import autograd.numpy as np
33
from lifelines.fitters import KnownModelParametricUnivariateFitter
4+
from lifelines import utils
45

56

67
class PiecewiseExponentialFitter(KnownModelParametricUnivariateFitter):
@@ -81,6 +82,7 @@ def __init__(self, breakpoints, *args, **kwargs):
8182
super(PiecewiseExponentialFitter, self).__init__(*args, **kwargs)
8283

8384
def _cumulative_hazard(self, params, times):
85+
times = np.atleast_1d(times)
8486
n = times.shape[0]
8587
times = times.reshape((n, 1))
8688
bp = self.breakpoints

lifelines/fitters/spline_fitter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(self, knot_locations: np.ndarray, *args, **kwargs):
9090
self._bounds = [(None, None)] * (self.n_knots)
9191
super(SplineFitter, self).__init__(*args, **kwargs)
9292

93+
def _create_initial_point(self, Ts, E, entry, weights):
94+
return 0.1 * np.ones(self.n_knots)
95+
9396
def _cumulative_hazard(self, params, t):
9497
phis = params
9598
lT = np.log(t)

lifelines/tests/test_estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs):
118118

119119
class SplineFitterTesting(SplineFitter):
120120
def __init__(self, *args, **kwargs):
121-
super(SplineFitterTesting, self).__init__([0.0, 50.0], *args, **kwargs)
121+
super(SplineFitterTesting, self).__init__([0.0, 40.0], *args, **kwargs)
122122

123123

124124
class CustomRegressionModelTesting(ParametricRegressionFitter):

0 commit comments

Comments
 (0)