Skip to content

Commit 1deeb2b

Browse files
riemanliThe Meridian Authors
authored andcommitted
If AKS selects no internal knots, the model degenerates to a single common intercept across all time points
PiperOrigin-RevId: 865082316
1 parent 18be981 commit 1deeb2b

File tree

3 files changed

+131
-22
lines changed

3 files changed

+131
-22
lines changed

meridian/data/test_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1953,13 +1953,12 @@ def sample_input_data_for_aks_with_expected_knot_info() -> (
19531953
),
19541954
'non_revenue',
19551955
)
1956+
expected_knots = np.array(
1957+
[0, 11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90, 116]
1958+
)
19561959
expected_knot_info = knots.KnotInfo(
1957-
n_knots=13,
1958-
knot_locations=np.array(
1959-
[11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90]
1960-
),
1961-
weights=knots.l1_distance_weights(
1962-
117, np.array([11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90])
1963-
),
1960+
n_knots=15,
1961+
knot_locations=expected_knots,
1962+
weights=knots.l1_distance_weights(117, expected_knots),
19641963
)
19651964
return data, expected_knot_info

meridian/model/knots.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,9 @@ def get_knot_info(
202202
)
203203
else:
204204
aks = AKS(data)
205-
knots = aks.automatic_knot_selection().knots
206-
n_knots = len(knots)
207-
knot_locations = knots
208-
elif isinstance(knots, int):
205+
selected_knots = aks.automatic_knot_selection().knots
206+
knots = selected_knots if selected_knots.size > 0 else None
207+
if isinstance(knots, int):
209208
if knots < 1:
210209
raise ValueError('If knots is an integer, it must be at least 1.')
211210
elif knots > n_times:
@@ -220,7 +219,7 @@ def get_knot_info(
220219
)
221220
n_knots = knots
222221
knot_locations = _get_equally_spaced_knot_locations(n_times, n_knots)
223-
elif isinstance(knots, Collection) and knots:
222+
elif isinstance(knots, Collection) and len(knots) > 0:
224223
if any(k < 0 for k in knots):
225224
raise ValueError('Knots must be all non-negative.')
226225
if any(k >= n_times for k in knots):
@@ -278,8 +277,11 @@ def automatic_knot_selection(
278277
value will be used.
279278
280279
Returns:
281-
Selected knots and the corresponding B-spline model.
280+
Selected knots and the corresponding B-spline model. If at least one knot
281+
is selected, boundary knots (min and max time) are added to ensure full
282+
time coverage.
282283
"""
284+
283285
if base_penalty is None:
284286
base_penalty = self._BASE_PENALTY
285287
n_times = len(self._data.time)
@@ -326,7 +328,15 @@ def automatic_knot_selection(
326328
np.where(information_criterion == min(information_criterion))[0]
327329
)
328330

329-
return AKSResult(knots_sel[opt_idx], model[opt_idx])
331+
selected_knots = knots_sel[opt_idx]
332+
if selected_knots.size > 0:
333+
start_knot = int(x.min())
334+
end_knot = int(x.max())
335+
selected_knots = np.unique(
336+
np.concatenate((selected_knots, [start_knot, end_knot]))
337+
)
338+
339+
return AKSResult(selected_knots, model[opt_idx])
330340

331341
def _get_bspline_matrix(self, x, knots):
332342
"""Replaces patsy.highlevel.dmatrix('bs(...)', ...)"""

meridian/model/knots_test.py

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,80 @@ def test_wrong_knots_fails(
258258
):
259259
knots.get_knot_info(n_times=200, knots=knots_arg, is_national=is_national)
260260

261+
@parameterized.named_parameters(
262+
dict(
263+
testcase_name="geo",
264+
is_national=False,
265+
n_times=5,
266+
expected_n_knots=5,
267+
expected_knot_locations=[0, 1, 2, 3, 4],
268+
),
269+
dict(
270+
testcase_name="national",
271+
is_national=True,
272+
n_times=5,
273+
expected_n_knots=1,
274+
expected_knot_locations=[0],
275+
),
276+
)
277+
def test_get_knot_info_aks_returns_empty_falls_back_to_defaults(
278+
self,
279+
is_national,
280+
n_times,
281+
expected_n_knots,
282+
expected_knot_locations,
283+
):
284+
"""Tests that if AKS returns empty knots, we fall back to default logic."""
285+
mock_result = mock.create_autospec(knots.AKSResult, instance=True)
286+
mock_result.knots = np.array([], dtype=int)
287+
mock_aks = self.enter_context(
288+
mock.patch.object(
289+
knots.AKS, "automatic_knot_selection", autospec=True, spec_set=True
290+
)
291+
)
292+
mock_aks.return_value = mock_result
293+
294+
info = knots.get_knot_info(
295+
n_times=n_times,
296+
knots=None,
297+
enable_aks=True,
298+
data=mock.create_autospec(
299+
input_data.InputData, instance=True, spec_set=True
300+
),
301+
is_national=is_national,
302+
)
303+
304+
self.assertEqual(info.n_knots, expected_n_knots)
305+
np.testing.assert_array_equal(info.knot_locations, expected_knot_locations)
306+
307+
def test_get_knot_info_aks_returns_knots_uses_them(self):
308+
mock_result = mock.create_autospec(
309+
knots.AKSResult,
310+
instance=True,
311+
)
312+
mock_result.knots = np.array([2, 4], dtype=int)
313+
mock_aks = self.enter_context(
314+
mock.patch.object(
315+
knots.AKS,
316+
"automatic_knot_selection",
317+
autospec=True,
318+
spec_set=True,
319+
)
320+
)
321+
mock_aks.return_value = mock_result
322+
323+
info = knots.get_knot_info(
324+
n_times=10,
325+
knots=None,
326+
enable_aks=True,
327+
data=mock.create_autospec(
328+
input_data.InputData, instance=True, spec_set=True
329+
),
330+
)
331+
332+
self.assertEqual(info.n_knots, 2)
333+
np.testing.assert_array_equal(info.knot_locations, [2, 4])
334+
261335

262336
class AKSTest(parameterized.TestCase):
263337
"""Tests for knots.AKS class."""
@@ -389,7 +463,23 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
389463
),
390464
"non_revenue",
391465
),
392-
expected_knots=[11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90],
466+
expected_knots=[
467+
0,
468+
11,
469+
14,
470+
38,
471+
39,
472+
41,
473+
43,
474+
45,
475+
48,
476+
50,
477+
55,
478+
87,
479+
89,
480+
90,
481+
116,
482+
],
393483
),
394484
dict(
395485
testcase_name="national_geos",
@@ -404,6 +494,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
404494
"non_revenue",
405495
),
406496
expected_knots=[
497+
0,
407498
1,
408499
2,
409500
3,
@@ -457,6 +548,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
457548
103,
458549
104,
459550
114,
551+
116,
460552
],
461553
),
462554
dict(
@@ -472,6 +564,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
472564
"non_revenue",
473565
),
474566
expected_knots=[
567+
0,
475568
4,
476569
17,
477570
20,
@@ -490,6 +583,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
490583
77,
491584
78,
492585
81,
586+
116,
493587
],
494588
),
495589
dict(
@@ -504,7 +598,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
504598
),
505599
"non_revenue",
506600
),
507-
expected_knots=[2, 7, 24, 25, 38, 39, 49, 114],
601+
expected_knots=[0, 2, 7, 24, 25, 38, 39, 49, 114, 116],
508602
),
509603
dict(
510604
testcase_name="50_times",
@@ -518,7 +612,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
518612
),
519613
"non_revenue",
520614
),
521-
expected_knots=[1, 5, 13, 15, 16, 23, 27, 31, 32, 38, 42],
615+
expected_knots=[0, 1, 5, 13, 15, 16, 23, 27, 31, 32, 38, 42, 49],
522616
),
523617
dict(
524618
testcase_name="200_times",
@@ -533,6 +627,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
533627
"non_revenue",
534628
),
535629
expected_knots=[
630+
0,
536631
4,
537632
10,
538633
12,
@@ -579,6 +674,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
579674
195,
580675
196,
581676
197,
677+
199,
582678
],
583679
),
584680
dict(
@@ -594,7 +690,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
594690
),
595691
"non_revenue",
596692
),
597-
expected_knots=[17, 25],
693+
expected_knots=[0, 17, 25, 49],
598694
),
599695
dict(
600696
testcase_name="seasonal",
@@ -610,6 +706,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
610706
"non_revenue",
611707
),
612708
expected_knots=[
709+
0,
613710
1,
614711
4,
615712
5,
@@ -629,6 +726,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
629726
41,
630727
45,
631728
47,
729+
49,
632730
],
633731
),
634732
dict(
@@ -644,7 +742,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
644742
),
645743
"non_revenue",
646744
),
647-
expected_knots=[24, 25, 26],
745+
expected_knots=[0, 24, 25, 26, 49],
648746
),
649747
dict(
650748
testcase_name="minimal_initial_knots",
@@ -658,7 +756,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
658756
),
659757
"non_revenue",
660758
),
661-
expected_knots=[3],
759+
expected_knots=[0, 3, 14],
662760
),
663761
)
664762
def test_aks(self, data: input_data.InputData, expected_knots: list[int]):
@@ -782,6 +880,7 @@ def test_user_provided_base_penalty(self):
782880
self.assertListEqual(
783881
actual_knots.tolist(),
784882
[
883+
0,
785884
2,
786885
7,
787886
8,
@@ -841,6 +940,7 @@ def test_user_provided_base_penalty(self):
841940
110,
842941
113,
843942
114,
943+
116,
844944
],
845945
)
846946

@@ -849,13 +949,13 @@ def test_user_provided_base_penalty(self):
849949
testcase_name="min_equals_max",
850950
min_internal_knots=8,
851951
max_internal_knots=8,
852-
expected_knots=[2, 7, 24, 25, 38, 39, 49, 114],
952+
expected_knots=[0, 2, 7, 24, 25, 38, 39, 49, 114, 116],
853953
),
854954
dict(
855955
testcase_name="min_lt_max_",
856956
min_internal_knots=2,
857957
max_internal_knots=15,
858-
expected_knots=[2, 7, 24, 25, 38, 39, 49, 114],
958+
expected_knots=[0, 2, 7, 24, 25, 38, 39, 49, 114, 116],
859959
),
860960
)
861961
def test_aks_user_provided_min_max_internal_knots(

0 commit comments

Comments
 (0)