@@ -22,6 +22,10 @@ def create_test_data(
2222
2323 Returns:
2424 pd.DataFrame: Test data in long format.
25+
26+ Raises:
27+ ValueError: If `freq` is not one of "monthly", "quarterly", or "yearly".
28+
2529 """
2630 rng = np .random .default_rng (seed ) if seed else np .random .default_rng ()
2731
@@ -31,40 +35,21 @@ def create_test_data(
3135 industry_codes = ["B" , "C" , "F" , "G" , "H" , "J" , "M" , "N" , "S" ]
3236 industries = rng .choice (industry_codes , size = n , replace = True )
3337
34- # Generate time periods
38+ # Generate time periods as List[str]
3539 if freq == "monthly" :
36- time_periods = (
37- pd .date_range (
38- start = "2020-01-01" ,
39- periods = n_periods ,
40- freq = "ME" ,
41- )
42- .to_period ("M" )
43- .astype (str )
44- )
45- if freq == "quarterly" :
46- time_periods_str = (
47- pd .date_range (
48- start = "2020-01-01" ,
49- periods = n_periods ,
50- freq = "QE" ,
51- )
52- .to_period ("Q" )
53- .astype (str )
54- )
55- time_periods = [f"{ p [:4 ]} -Q{ p [5 :]} " for p in time_periods_str .tolist ()]
56- if freq == "yearly" :
57- time_periods = (
58- pd .date_range (
59- start = "2020-01-01" ,
60- periods = n_periods ,
61- freq = "YE" ,
62- )
63- .to_period ("Y" )
64- .astype (str )
65- )
66-
67- # Create Cartesian product of industries and periods
40+ periods = pd .period_range (start = "2020-01-01" , periods = n_periods , freq = "M" )
41+ time_periods : list [str ] = [f"{ p .year } -{ p .month :02d} " for p in periods ]
42+ elif freq == "quarterly" :
43+ periods = pd .period_range (start = "2020-01-01" , periods = n_periods , freq = "Q-DEC" )
44+ time_periods = [f"{ p .year } -Q{ p .quarter } " for p in periods ]
45+ elif freq == "yearly" :
46+ periods = pd .period_range (start = "2020-01-01" , periods = n_periods , freq = "Y" )
47+ time_periods = [f"{ p .year } " for p in periods ]
48+ else :
49+ mes = "freq must be one of: 'monthly', 'quarterly', 'yearly'"
50+ raise ValueError (mes )
51+
52+ # Create product of industries and periods
6853 data = pd .DataFrame (
6954 [(id_company , period ) for id_company in company_ids for period in time_periods ],
7055 columns = ["id_company" , "time_period" ],
0 commit comments