Skip to content

Commit c676f00

Browse files
authored
Added check and stratification
Dev checks
2 parents 37ceb19 + 4b686ef commit c676f00

File tree

3 files changed

+202
-42
lines changed

3 files changed

+202
-42
lines changed

src/vaskify/createdata.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,6 @@ def create_test_data(
8181
2,
8282
) # check if all get same random or not...
8383

84+
data["id_company"] = data["id_company"].astype(str)
85+
8486
return data

src/vaskify/detect.py

Lines changed: 167 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
# %%
1111
import logging
12+
import re
1213

1314
import numpy as np
1415
import pandas as pd
@@ -31,9 +32,14 @@ def __init__(
3132
id_nr: String variable for the name of the variable to identify units with.
3233
logger_level: Detail level for information output. Choose between 'debug','info','warning','error' and 'critical'.
3334
"""
35+
# Check data
36+
self._check_data(data, id_nr=id_nr)
37+
38+
# Create self variables
3439
self.data = data
3540
self.id_nr = id_nr
3641

42+
# Start logging
3743
logging_dict = {
3844
"debug": 10,
3945
"info": 20,
@@ -52,6 +58,54 @@ def __init__(
5258
console_handler.setFormatter(formatter)
5359
self.logger.addHandler(console_handler)
5460

61+
@staticmethod
62+
def _check_data(
63+
data: pd.DataFrame,
64+
y_var: str = "",
65+
time_var: str = "",
66+
id_nr: str = "",
67+
) -> None:
68+
"""Check if the data contains the necessary columns, correct data types, and valid date format.
69+
70+
Args:
71+
data: The DataFrame to check.
72+
y_var: The variable of interest to check.
73+
time_var: String variable for indicating the time period.
74+
id_nr: String variable for the identifier.
75+
76+
Raises:
77+
ValueError: If any of the checks fail.
78+
"""
79+
required_columns = [y_var, time_var, id_nr]
80+
for col in required_columns:
81+
if col and col not in data.columns:
82+
mes = f"Missing column: {col}"
83+
raise ValueError(mes)
84+
if id_nr and not pd.api.types.is_string_dtype(data[id_nr]):
85+
mes = f"{id_nr} should be a string."
86+
raise ValueError(mes)
87+
88+
if y_var and not pd.api.types.is_numeric_dtype(data[y_var]):
89+
mes = f"{y_var} should be numeric."
90+
raise ValueError(mes)
91+
92+
if time_var:
93+
if not pd.api.types.is_string_dtype(data[time_var]):
94+
mes = f"{time_var} should be a string."
95+
raise ValueError(mes)
96+
97+
date_format_pattern = re.compile(
98+
r"^\d{4}(-\d{2}(-\d{2})?|-(Q[1-4]|W(0[1-9]|[1-4][0-9]|5[0-3]))|-\d{3)$",
99+
)
100+
101+
if (
102+
not data[time_var]
103+
.apply(lambda x: bool(date_format_pattern.match(x)))
104+
.all()
105+
):
106+
mes = f"{time_var} should be in the format 'YYYY', 'YYYY-Qq', 'YYYY-MM','YYYY-Www','YYYY-MM-DD', 'YYYY-DDD'."
107+
raise ValueError(mes)
108+
55109
def change_logging_level(self, logger_level: str) -> None:
56110
"""Change the logging print level.
57111
@@ -82,7 +136,7 @@ def thousand_error(
82136
83137
Args:
84138
y_var: The variable of interest to check.
85-
time_var: String variable for indicating the time period. This should be in a standard format: 'YYYY', 'YYYY-Mm', 'YYYY-Kk'.
139+
time_var: String variable for indicating the time period. This should be in a ISO 8601 standard format for example: 'YYYY', 'YYYY-MM', 'YYYY-MM-DD' or a SSB standard like 'YYYY-Qq'.
86140
lower_bound: Float variable for the lower bound log factor for defining an outlier.
87141
upper_bound: Float variable for the upper bound log factor for defining an outlier.
88142
flag: String for the name of the flag variable to add to the data. Default is 'flag_thousand'.
@@ -93,13 +147,14 @@ def thousand_error(
93147
Returns:
94148
Data frame containing a flag variable for identified outliers or a dataframe containing only the outliers.
95149
"""
150+
# Check data
151+
self._check_data(self.data, y_var=y_var, time_var=time_var)
152+
96153
if (not impute_var) and (impute):
97154
impute_var = f"{y_var}_imputed"
98155
mes = f"No impute variable given so using {impute_var}"
99156
self.logger.info(mes)
100157

101-
# check data - add in
102-
103158
# Find differences by sorting first - not efficient but works
104159
data = self.data.sort_values(by=[self.id_nr, time_var]).reset_index(drop=True)
105160
log10_diff = data.groupby(self.id_nr)[y_var].transform(
@@ -130,7 +185,9 @@ def thousand_error(
130185
mask_outlier_units = data[self.id_nr].isin(outlier_ids)
131186
output = data.loc[mask_outlier_units, :]
132187
else:
133-
self.logger.warning("output_format is not valid. Use 'data' or 'outliers'")
188+
output = data
189+
mes = "output_format is not valid. Use 'data' or 'outliers'. Returning 'data' format."
190+
self.logger.warning(mes)
134191

135192
return output
136193

@@ -148,7 +205,7 @@ def accumulation_error(
148205
149206
Args:
150207
y_var: The variable of interest to check.
151-
time_var: String variable for indicating the time period. This should be in a standard format: 'YYYY', 'YYYY-Mm', 'YYYY-Kk'.
208+
time_var: String variable for indicating the time period. This should be in a ISO 8601 standard format for example: 'YYYY', 'YYYY-MM', 'YYYY-MM-DD' or a SSB standard like 'YYYY-Qq'.
152209
error: Float for the allowed error factor.
153210
flag: String for the name of the flag variable to add to the data. Default is 'flag_thousand'.
154211
impute: Boolean for whether to impute the flagged observations. Default is False. (NOT IMPLEMENTED)
@@ -158,13 +215,14 @@ def accumulation_error(
158215
Returns:
159216
Data frame containing a flag variable for identified outliers or a dataframe containing only the outliers.
160217
"""
218+
# Check data
219+
self._check_data(self.data, y_var=y_var, time_var=time_var)
220+
161221
if (not impute_var) and (impute):
162222
impute_var = f"{y_var}_imputed"
163223
mes = f"No imputed variable name given so {impute_var} is being used"
164224
self.logger.info(mes)
165225

166-
# check data
167-
168226
# Sort and get previous period data
169227
data = self.data.sort_values(by=[self.id_nr, time_var]).reset_index(drop=True)
170228
expected_turnover = data.groupby(self.id_nr)[y_var].shift(1)
@@ -201,10 +259,48 @@ def accumulation_error(
201259

202260
return output
203261

262+
@staticmethod
263+
def _calculate_hb(
264+
x1: pd.Series, # type: ignore[type-arg]
265+
x2: pd.Series, # type: ignore[type-arg]
266+
pu: float,
267+
pa: float,
268+
pc: float,
269+
percentiles: tuple[float, float],
270+
) -> pd.DataFrame:
271+
"""Calculate HB method."""
272+
rat = x1 / x2
273+
med_ratio = rat.median()
274+
s_ratio = np.where(
275+
rat >= med_ratio,
276+
rat / med_ratio - 1,
277+
1 - med_ratio / rat,
278+
)
279+
280+
max_y = pd.concat([x1, x2], axis=1).max(axis=1)
281+
e_ratio = s_ratio * max_y**pu
282+
283+
e_ratio_q = e_ratio.quantile([percentiles[0], 0.5, percentiles[1]]).to_numpy()
284+
q1, q2, q3 = e_ratio_q
285+
286+
if q2 != 0:
287+
ell = q2 - pc * max(q2 - q1, abs(q2 * pa))
288+
eul = q2 + pc * max(q3 - q2, abs(q2 * pa))
289+
else:
290+
ell = q2 - pc * max(q2 - q1, pa)
291+
eul = q2 + pc * max(q3 - q2, pa)
292+
293+
lower_limit = med_ratio * max_y**pu / (max_y**pu - ell)
294+
upper_limit = med_ratio * (max_y**pu + eul) / max_y**pu
295+
296+
return pd.DataFrame({"lower_limit": lower_limit, "upper_limit": upper_limit})
297+
204298
def hb(
205299
self,
206300
y_var: str,
207301
time_var: str,
302+
time_periods: list[str] | None = None,
303+
strata_var: str = "",
208304
pu: float = 0.5,
209305
pa: float = 0.05,
210306
pc: float = 20,
@@ -218,74 +314,102 @@ def hb(
218314
219315
Args:
220316
y_var: String for the name of the variable of interest to check.
221-
time_var: String variable for indicating the time period. This should be in a standard format: 'YYYY', 'YYYY-Mm', 'YYYY-Kk'.
317+
time_var: String variable for indicating the time period. This should be in a ISO 8601 standard format for example: 'YYYY', 'YYYY-MM', 'YYYY-MM-DD' or a SSB standard like 'YYYY-Qq'.
318+
time_periods: List of strings for the two time periods to compare. Default None, in which case it is assumed that the time variable contains exactly two time periods.
319+
strata_var: String variable for stratification. Default is blank ("").
222320
pu: Parameter that adjusts for different level of the variables. Default value 0.5.
223321
pa: Parameter that adjusts for small differences between the median and the 1st or 3rd quartile. Default value 0.05.
224-
pc: Parameter that controls the width of the confidence interval. Default value 4.
322+
pc: Parameter that controls the width of the confidence interval. Default value 20.
225323
percentiles: Tuple for percentile values to use.
226324
flag: String variable name to use to indicate outliers.
227325
output_format: String for format to return. Can be 'wide','long','outliers'.
228326
229327
Returns:
230328
Dataframe with flags or with identified units
231329
"""
232-
# check data ...
330+
# Check data
331+
self._check_data(self.data, y_var=y_var, time_var=time_var)
233332
data = self.data.copy()
234333

235-
# Get time levesl
334+
# Add in check if number of companies in each strata is too low.
335+
336+
# Filter time periods
337+
if time_periods:
338+
if len(time_periods) != 2:
339+
mes = "Two time periods should be specified."
340+
self.logger.error(mes)
341+
data = data.loc[data[time_var].isin(time_periods), :]
342+
343+
# Get time levels
236344
time_levels = np.unique(data[time_var])
237345
if len(time_levels) != 2:
238346
mes = "The time variable must have exactly two unique levels."
239347
self.logger.error(mes)
240-
x1 = time_levels[1] # t
241-
x2 = time_levels[0] # t-1
348+
time1 = time_levels[1] # t
349+
time0 = time_levels[0] # t-1
242350

243351
# Convert to wide
352+
wide_index = [self.id_nr, strata_var] if strata_var else self.id_nr
244353
wide_data = data.pivot_table(
245-
index=self.id_nr,
354+
index=wide_index,
246355
columns=time_var,
247356
values=y_var,
248357
aggfunc="first",
249358
).reset_index()
250359
wide_data.columns.name = None
251360

252361
# Check for valid rows
253-
valid_rows = wide_data[(wide_data[x1] > 0) & (wide_data[x2] > 0)]
362+
valid_rows = wide_data[(wide_data[time1] > 0) & (wide_data[time0] > 0)]
254363
if valid_rows.empty:
255364
mes = "No valid rows with y_var > 0 for both time periods."
256365
self.logger.error(mes)
257366

258-
# Calculate the ratio and related metrics
259-
valid_rows["ratio"] = valid_rows[x1] / valid_rows[x2]
260-
med_ratio = valid_rows["ratio"].median()
261-
s_ratio = np.where(
262-
valid_rows["ratio"] >= med_ratio,
263-
valid_rows["ratio"] / med_ratio - 1,
264-
1 - med_ratio / valid_rows["ratio"],
265-
)
266-
267-
max_y = valid_rows[[x1, x2]].max(axis=1)
268-
e_ratio = s_ratio * max_y**pu
367+
# Add in ratio
368+
valid_rows["ratio"] = valid_rows[time1] / valid_rows[time0]
369+
370+
# Apply the HB function to each strata group
371+
if strata_var:
372+
limits = (
373+
valid_rows.groupby(strata_var)
374+
.apply(
375+
lambda group: self._calculate_hb(
376+
group[time1],
377+
group[time0],
378+
pu,
379+
pa,
380+
pc,
381+
percentiles,
382+
),
383+
)
384+
.reset_index(level=strata_var, drop=True)
385+
)
386+
else:
387+
limits = self._calculate_hb(
388+
valid_rows[time1],
389+
valid_rows[time0],
390+
pu,
391+
pa,
392+
pc,
393+
percentiles,
394+
)
269395

270-
# Compute quantiles for e ratio
271-
percentiles = (0.25, 0.75) # Can also be 0.1, 0.9
272-
e_ratio_q = e_ratio.quantile([percentiles[0], 0.5, percentiles[1]]).to_numpy()
273-
q1, q2, q3 = e_ratio_q
396+
# Merge the limits back into the valid_rows
397+
valid_rows = valid_rows.merge(
398+
limits,
399+
left_index=True,
400+
right_index=True,
401+
how="left",
402+
)
274403

275-
if q2 != 0:
276-
ell = q2 - pc * max(q2 - q1, abs(q2 * pa))
277-
eul = q2 + pc * max(q3 - q2, abs(q2 * pa))
278-
else:
279-
ell = q2 - pc * max(q2 - q1, pa)
280-
eul = q2 + pc * max(q3 - q2, pa)
281-
valid_rows["lower_limit"] = med_ratio * max_y**pu / (max_y**pu - ell)
282-
valid_rows["upper_limit"] = med_ratio * (max_y**pu + eul) / max_y**pu
404+
# Add in flag
283405
valid_rows[flag] = np.where(
284406
(valid_rows["ratio"] < valid_rows["lower_limit"])
285407
| (valid_rows["ratio"] > valid_rows["upper_limit"]),
286408
1,
287409
0,
288410
)
411+
412+
# Format in correct output format
289413
if output_format == "wide":
290414
output: pd.DataFrame = valid_rows
291415
elif output_format == "outliers":
@@ -300,10 +424,11 @@ def hb(
300424
var_name=time_var,
301425
value_name=y_var,
302426
)
303-
# Add in NAs for first time period here ...
427+
mask = output[time_var] == time_levels[0]
428+
output.loc[mask, ["lower_limit", "upper_limit", flag]] = np.nan
304429
else:
305-
self.logger.warning(
306-
"output_format is not valid. Use 'wide' or 'outliers' or 'long'",
307-
)
430+
mes = "output_format is not valid. Use 'wide', 'outliers' or 'long'. Wide being returned."
431+
self.logger.warning(mes)
432+
output = valid_rows
308433

309434
return output

tests/test_detect.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,39 @@ def test_hb() -> None:
6767
assert dt_controlled.shape[0] == expected_shape, "Long format returned"
6868

6969

70+
def test_hb_strata() -> None:
    """Check that hb detection with stratification works across output formats."""
    dt = create_test_data(n=50, seed=10)
    # Keep exactly two consecutive periods, as hb requires.
    dt2 = dt.loc[dt.time_period.isin(["2020-04", "2020-05"]), :]

    detect = Detect(dt2, id_nr="id_company")
    dt_controlled = detect.hb(
        y_var="turnover",
        time_var="time_period",
        strata_var="nace",
    )

    assert any(dt_controlled.columns.isin(["flag_hb"])), "Flag variable created"
    expected_shape = 50
    assert dt_controlled.shape[0] == expected_shape, "Wide format returned as default"

    dt_controlled = detect.hb(
        y_var="turnover",
        strata_var="nace",
        time_var="time_period",
        output_format="outliers",
    )
    expected_shape = 2
    # FIX: assertion message typo corrected ("Oulier" -> "Outlier").
    assert dt_controlled.shape[0] == expected_shape, "Outlier format returned"

    dt_controlled = detect.hb(
        y_var="turnover",
        time_var="time_period",
        output_format="long",
    )
    expected_shape = 100
    assert dt_controlled.shape[0] == expected_shape, "Long format returned"
101+
102+
70103
# %%
71104
def test_logger() -> None:
72105
dt = create_test_data(n=5, n_periods=2, freq="monthly", seed=42)

0 commit comments

Comments
 (0)