Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions forecasttools/daily_to_epiweekly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,50 @@
to epiweekly dataframes.
"""

from datetime import datetime
import datetime

import epiweeks
import polars as pl

from forecasttools.utils import ensure_listlike


def calculate_epi_week_and_year(date: str):
"""
Converts an ISO8601 formatted
date into an epiweek and epiyear.
def calculate_epi_week_and_year(date_input: datetime.date) -> dict[str, int]:
"""Convert a date to epidemiological week and year."""
if not isinstance(date_input, datetime.date):
raise TypeError(
f"date_input must be a datetime.date, got {type(date_input).__name__}"
)

date
An ISO8601 date.
"""
epiweek = epiweeks.Week.fromdate(datetime.strptime(date, "%Y-%m-%d"))
epiweek = epiweeks.Week.fromdate(date_input)
epiweek_df_struct = {
"epiweek": epiweek.week,
"epiyear": epiweek.year,
}
return epiweek_df_struct


def calculate_epiweek_enddate(epiyear: int, epiweek: int) -> datetime.date:
"""
Given an epiweek and epiyear, return
the enddate (Saturday) of that epiweek.

epiyear
Epidemiological year.
epiweek
Epidemiological week number.
"""
return epiweeks.Week(epiyear, epiweek).enddate()


def df_aggregate_to_epiweekly(
df: pl.DataFrame,
value_col: str = "value",
date_col: str = "date",
id_cols: list[str] = None,
weekly_value_name: str = "weekly_value",
with_epiweek_end_date: bool = False,
epiweek_end_date_name: str = "epiweek_end_date",
strict: bool = True,
) -> pl.DataFrame:
"""
Expand Down Expand Up @@ -60,6 +74,12 @@ def df_aggregate_to_epiweekly(
The name to use for the output column
containing weekly trajectory values.
Defaults to ``"weekly_value"``.
with_epiweek_end_date
Whether to annotate output with the last date
of each epiweek. Defaults to ``False``.
epiweek_end_date_name
Name for the output epiweek end-date column.
Defaults to ``"epiweek_end_date"``.
strict
Whether to aggregate to epiweekly only
for weeks in which all seven days have
Expand Down Expand Up @@ -125,4 +145,18 @@ def df_aggregate_to_epiweekly(
.agg(pl.col(value_col).sum().alias(weekly_value_name))
.sort(group_cols)
)

if with_epiweek_end_date:
df = df.with_columns(
pl.struct(["epiyear", "epiweek"])
.map_elements(
lambda elt: calculate_epiweek_enddate(
epiyear=elt["epiyear"],
epiweek=elt["epiweek"],
),
return_dtype=pl.Date,
)
.alias(epiweek_end_date_name)
)

return df
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test = [
Repository = "https://github.com/CDCgov/forecasttools-py/"
"Repository Issues" = "https://github.com/CDCgov/forecasttools-py/issues"
"CDCgov Repositories" = "https://github.com/CDCgov"
"Package That Will Use Forecasttools" = "https://github.com/CDCgov/pyrenew-hew"
"Package That Will Use Forecasttools" = "https://github.com/CDCgov/cfa-stf-routine-forecasting"
"Poetry Pyproject Page" = "https://python-poetry.org/docs/pyproject/"

[tool.ruff.lint.mccabe]
Expand Down
38 changes: 38 additions & 0 deletions tests/test_daily_to_epiweekly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Tests for daily_to_epiweekly.py functions.
"""

import datetime

import polars as pl

import forecasttools


def test_df_aggregate_to_epiweekly_adds_epiweek_end_date():
"""
Test aggregated output includes Saturday end date
for each epiweek.
"""
df = pl.DataFrame(
{
"location": ["TX"] * 8,
"date": [
datetime.date(2023, 10, 8) + datetime.timedelta(days=i)
for i in range(8)
],
"value": [1, 1, 1, 1, 1, 1, 1, 4],
}
)

out = forecasttools.df_aggregate_to_epiweekly(
df=df,
id_cols=["location"],
with_epiweek_end_date=True,
)

assert "epiweek_end_date" in out.columns

assert out.get_column("epiweek_end_date").item() == datetime.date(2023, 10, 14)

assert out.get_column("weekly_value").item() == 7