@@ -258,6 +258,80 @@ def test_wrong_knots_fails(
258258 ):
259259 knots .get_knot_info (n_times = 200 , knots = knots_arg , is_national = is_national )
260260
261+ @parameterized .named_parameters (
262+ dict (
263+ testcase_name = "geo" ,
264+ is_national = False ,
265+ n_times = 5 ,
266+ expected_n_knots = 5 ,
267+ expected_knot_locations = [0 , 1 , 2 , 3 , 4 ],
268+ ),
269+ dict (
270+ testcase_name = "national" ,
271+ is_national = True ,
272+ n_times = 5 ,
273+ expected_n_knots = 1 ,
274+ expected_knot_locations = [0 ],
275+ ),
276+ )
277+ def test_get_knot_info_aks_returns_empty_falls_back_to_defaults (
278+ self ,
279+ is_national ,
280+ n_times ,
281+ expected_n_knots ,
282+ expected_knot_locations ,
283+ ):
284+ """Tests that if AKS returns empty knots, we fall back to default logic."""
285+ mock_result = mock .create_autospec (knots .AKSResult , instance = True )
286+ mock_result .knots = np .array ([], dtype = int )
287+ mock_aks = self .enter_context (
288+ mock .patch .object (
289+ knots .AKS , "automatic_knot_selection" , autospec = True , spec_set = True
290+ )
291+ )
292+ mock_aks .return_value = mock_result
293+
294+ info = knots .get_knot_info (
295+ n_times = n_times ,
296+ knots = None ,
297+ enable_aks = True ,
298+ data = mock .create_autospec (
299+ input_data .InputData , instance = True , spec_set = True
300+ ),
301+ is_national = is_national ,
302+ )
303+
304+ self .assertEqual (info .n_knots , expected_n_knots )
305+ np .testing .assert_array_equal (info .knot_locations , expected_knot_locations )
306+
307+ def test_get_knot_info_aks_returns_knots_uses_them (self ):
308+ mock_result = mock .create_autospec (
309+ knots .AKSResult ,
310+ instance = True ,
311+ )
312+ mock_result .knots = np .array ([2 , 4 ], dtype = int )
313+ mock_aks = self .enter_context (
314+ mock .patch .object (
315+ knots .AKS ,
316+ "automatic_knot_selection" ,
317+ autospec = True ,
318+ spec_set = True ,
319+ )
320+ )
321+ mock_aks .return_value = mock_result
322+
323+ info = knots .get_knot_info (
324+ n_times = 10 ,
325+ knots = None ,
326+ enable_aks = True ,
327+ data = mock .create_autospec (
328+ input_data .InputData , instance = True , spec_set = True
329+ ),
330+ )
331+
332+ self .assertEqual (info .n_knots , 2 )
333+ np .testing .assert_array_equal (info .knot_locations , [2 , 4 ])
334+
261335
262336class AKSTest (parameterized .TestCase ):
263337 """Tests for knots.AKS class."""
@@ -389,7 +463,23 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
389463 ),
390464 "non_revenue" ,
391465 ),
392- expected_knots = [11 , 14 , 38 , 39 , 41 , 43 , 45 , 48 , 50 , 55 , 87 , 89 , 90 ],
466+ expected_knots = [
467+ 0 ,
468+ 11 ,
469+ 14 ,
470+ 38 ,
471+ 39 ,
472+ 41 ,
473+ 43 ,
474+ 45 ,
475+ 48 ,
476+ 50 ,
477+ 55 ,
478+ 87 ,
479+ 89 ,
480+ 90 ,
481+ 116 ,
482+ ],
393483 ),
394484 dict (
395485 testcase_name = "national_geos" ,
@@ -404,6 +494,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
404494 "non_revenue" ,
405495 ),
406496 expected_knots = [
497+ 0 ,
407498 1 ,
408499 2 ,
409500 3 ,
@@ -457,6 +548,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
457548 103 ,
458549 104 ,
459550 114 ,
551+ 116 ,
460552 ],
461553 ),
462554 dict (
@@ -472,6 +564,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
472564 "non_revenue" ,
473565 ),
474566 expected_knots = [
567+ 0 ,
475568 4 ,
476569 17 ,
477570 20 ,
@@ -490,6 +583,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
490583 77 ,
491584 78 ,
492585 81 ,
586+ 116 ,
493587 ],
494588 ),
495589 dict (
@@ -504,7 +598,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
504598 ),
505599 "non_revenue" ,
506600 ),
507- expected_knots = [2 , 7 , 24 , 25 , 38 , 39 , 49 , 114 ],
601+ expected_knots = [0 , 2 , 7 , 24 , 25 , 38 , 39 , 49 , 114 , 116 ],
508602 ),
509603 dict (
510604 testcase_name = "50_times" ,
@@ -518,7 +612,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
518612 ),
519613 "non_revenue" ,
520614 ),
521- expected_knots = [1 , 5 , 13 , 15 , 16 , 23 , 27 , 31 , 32 , 38 , 42 ],
615+ expected_knots = [0 , 1 , 5 , 13 , 15 , 16 , 23 , 27 , 31 , 32 , 38 , 42 , 49 ],
522616 ),
523617 dict (
524618 testcase_name = "200_times" ,
@@ -533,6 +627,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
533627 "non_revenue" ,
534628 ),
535629 expected_knots = [
630+ 0 ,
536631 4 ,
537632 10 ,
538633 12 ,
@@ -579,6 +674,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
579674 195 ,
580675 196 ,
581676 197 ,
677+ 199 ,
582678 ],
583679 ),
584680 dict (
@@ -594,7 +690,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
594690 ),
595691 "non_revenue" ,
596692 ),
597- expected_knots = [17 , 25 ],
693+ expected_knots = [0 , 17 , 25 , 49 ],
598694 ),
599695 dict (
600696 testcase_name = "seasonal" ,
@@ -610,6 +706,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
610706 "non_revenue" ,
611707 ),
612708 expected_knots = [
709+ 0 ,
613710 1 ,
614711 4 ,
615712 5 ,
@@ -629,6 +726,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
629726 41 ,
630727 45 ,
631728 47 ,
729+ 49 ,
632730 ],
633731 ),
634732 dict (
@@ -644,7 +742,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
644742 ),
645743 "non_revenue" ,
646744 ),
647- expected_knots = [24 , 25 , 26 ],
745+ expected_knots = [0 , 24 , 25 , 26 , 49 ],
648746 ),
649747 dict (
650748 testcase_name = "minimal_initial_knots" ,
@@ -658,7 +756,7 @@ def test_aks_internal_knots_guardrail_raises(self, data, expected_error):
658756 ),
659757 "non_revenue" ,
660758 ),
661- expected_knots = [3 ],
759+ expected_knots = [0 , 3 , 14 ],
662760 ),
663761 )
664762 def test_aks (self , data : input_data .InputData , expected_knots : list [int ]):
@@ -782,6 +880,7 @@ def test_user_provided_base_penalty(self):
782880 self .assertListEqual (
783881 actual_knots .tolist (),
784882 [
883+ 0 ,
785884 2 ,
786885 7 ,
787886 8 ,
@@ -841,6 +940,7 @@ def test_user_provided_base_penalty(self):
841940 110 ,
842941 113 ,
843942 114 ,
943+ 116 ,
844944 ],
845945 )
846946
@@ -849,13 +949,13 @@ def test_user_provided_base_penalty(self):
849949 testcase_name = "min_equals_max" ,
850950 min_internal_knots = 8 ,
851951 max_internal_knots = 8 ,
852- expected_knots = [2 , 7 , 24 , 25 , 38 , 39 , 49 , 114 ],
952+ expected_knots = [0 , 2 , 7 , 24 , 25 , 38 , 39 , 49 , 114 , 116 ],
853953 ),
854954 dict (
855955 testcase_name = "min_lt_max_" ,
856956 min_internal_knots = 2 ,
857957 max_internal_knots = 15 ,
858- expected_knots = [2 , 7 , 24 , 25 , 38 , 39 , 49 , 114 ],
958+ expected_knots = [0 , 2 , 7 , 24 , 25 , 38 , 39 , 49 , 114 , 116 ],
859959 ),
860960 )
861961 def test_aks_user_provided_min_max_internal_knots (
0 commit comments