Skip to content

Commit 481fe8c

Browse files
committed
Fix for mypy and documentation
1 parent 9170b2d commit 481fe8c

File tree

1 file changed

+18
-33
lines changed

1 file changed

+18
-33
lines changed

src/vaskify/createdata.py

Lines changed: 18 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,10 @@ def create_test_data(
2222
2323
Returns:
2424
pd.DataFrame: Test data in long format.
25+
26+
Raises:
27+
ValueError: If `freq` is not one of "monthly", "quarterly", or "yearly".
28+
2529
"""
2630
rng = np.random.default_rng(seed) if seed else np.random.default_rng()
2731

@@ -31,40 +35,21 @@ def create_test_data(
3135
industry_codes = ["B", "C", "F", "G", "H", "J", "M", "N", "S"]
3236
industries = rng.choice(industry_codes, size=n, replace=True)
3337

34-
# Generate time periods
38+
# Generate time periods as List[str]
3539
if freq == "monthly":
36-
time_periods = (
37-
pd.date_range(
38-
start="2020-01-01",
39-
periods=n_periods,
40-
freq="ME",
41-
)
42-
.to_period("M")
43-
.astype(str)
44-
)
45-
if freq == "quarterly":
46-
time_periods_str = (
47-
pd.date_range(
48-
start="2020-01-01",
49-
periods=n_periods,
50-
freq="QE",
51-
)
52-
.to_period("Q")
53-
.astype(str)
54-
)
55-
time_periods = [f"{p[:4]}-Q{p[5:]}" for p in time_periods_str.tolist()]
56-
if freq == "yearly":
57-
time_periods = (
58-
pd.date_range(
59-
start="2020-01-01",
60-
periods=n_periods,
61-
freq="YE",
62-
)
63-
.to_period("Y")
64-
.astype(str)
65-
)
66-
67-
# Create Cartesian product of industries and periods
40+
periods = pd.period_range(start="2020-01-01", periods=n_periods, freq="M")
41+
time_periods: list[str] = [f"{p.year}-{p.month:02d}" for p in periods]
42+
elif freq == "quarterly":
43+
periods = pd.period_range(start="2020-01-01", periods=n_periods, freq="Q-DEC")
44+
time_periods = [f"{p.year}-Q{p.quarter}" for p in periods]
45+
elif freq == "yearly":
46+
periods = pd.period_range(start="2020-01-01", periods=n_periods, freq="Y")
47+
time_periods = [f"{p.year}" for p in periods]
48+
else:
49+
mes = "freq must be one of: 'monthly', 'quarterly', 'yearly'"
50+
raise ValueError(mes)
51+
52+
# Create product of industries and periods
6853
data = pd.DataFrame(
6954
[(id_company, period) for id_company in company_ids for period in time_periods],
7055
columns=["id_company", "time_period"],

0 commit comments

Comments
 (0)