Skip to content

Commit a277bb1

Browse files
committed
Added nerc_invoicing repo for querying invoicing Iceberg table
Added functionality to get lifetime costs grouped by project. Added pyproject.toml for future publishing. Added testing for the new functionality.
1 parent 4920ec7 commit a277bb1

6 files changed

Lines changed: 390 additions & 0 deletions

File tree

process_report/data_tools/__init__.py

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import functools
2+
3+
from pydantic_settings import BaseSettings
4+
from pyiceberg.table import StaticTable
5+
6+
7+
class DataToolsSettings(BaseSettings):
    """Iceberg warehouse path and S3 credentials for data_tools queries.

    The three credential fields default to ``None``; they are only checked
    when :meth:`iceberg_s3_properties` is called, so the settings object can
    be instantiated without credentials being configured.
    """

    iceberg_warehouse_base: str = "s3://nerc-invoicing-iceberg/warehouse"
    iceberg_table_subpath: str = "nerc_invoicing_iceberg/nerc_invoicing_iceberg"
    iceberg_s3_access_key_id: str | None = None
    iceberg_s3_secret_access_key: str | None = None
    iceberg_s3_endpoint: str | None = None
    iceberg_s3_region: str = "us-east-005"

    @property
    def table_path(self) -> str:
        """Return the warehouse base and table sub-path joined by one ``/``."""
        # Strip slashes at the seam so an override such as "s3://bucket/wh/"
        # does not yield a "//" in the resulting object path.
        base = self.iceberg_warehouse_base.rstrip("/")
        subpath = self.iceberg_table_subpath.lstrip("/")
        return f"{base}/{subpath}"

    def iceberg_s3_properties(self) -> dict[str, str]:
        """Build the S3 property mapping passed to PyIceberg.

        Returns:
            Mapping with the keys ``s3.access-key-id``, ``s3.secret-access-key``,
            ``s3.endpoint`` and ``s3.region``.

        Raises:
            ValueError: If the access key id, secret access key, or endpoint
                is unset or empty.
        """
        required = (
            self.iceberg_s3_access_key_id,
            self.iceberg_s3_secret_access_key,
            self.iceberg_s3_endpoint,
        )
        if not all(required):
            raise ValueError(
                "Iceberg S3 credentials required: "
                "ICEBERG_S3_ACCESS_KEY_ID, ICEBERG_S3_SECRET_ACCESS_KEY, ICEBERG_S3_ENDPOINT"
            )
        endpoint = self.iceberg_s3_endpoint
        # Bare hosts keep the historical "https://" default; an endpoint that
        # already names a scheme is used as given instead of being prefixed
        # into "https://https://...".
        if "://" not in endpoint:
            endpoint = f"https://{endpoint}"
        return {
            "s3.access-key-id": self.iceberg_s3_access_key_id,
            "s3.secret-access-key": self.iceberg_s3_secret_access_key,
            "s3.endpoint": endpoint,
            "s3.region": self.iceberg_s3_region,
        }


# Module-wide settings instance read by get_table().
data_tools_settings = DataToolsSettings()
42+
43+
44+
@functools.cache
def get_table() -> StaticTable:
    """Return the ``StaticTable`` for the configured Iceberg table path.

    The result is memoized by ``functools.cache``, so the table is built at
    most once per process after the first successful call.

    Raises:
        ValueError: Propagated from ``iceberg_s3_properties()`` when the S3
            credentials are not configured.
    """
    location = data_tools_settings.table_path
    s3_properties = data_tools_settings.iceberg_s3_properties()
    return StaticTable.from_metadata(location, properties=s3_properties)

process_report/data_tools/costs.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import functools
2+
from decimal import Decimal
3+
import logging
4+
5+
import pandas as pd
6+
from pyiceberg.expressions import And, BooleanExpression, EqualTo
7+
8+
import process_report.invoices.invoice as invoice
9+
from process_report.data_tools.config import get_table
10+
11+
logger = logging.getLogger(__name__)

# Scalar types accepted as values for column=value equality filters.
FilterValue = str | int | float

# Columns loaded for the lifetime-cost aggregation: the two grouping keys
# followed by the balance column that gets summed.
_LIFETIME_COLS = [
    invoice.PROJECT_ID_FIELD,
    invoice.CLUSTER_NAME_FIELD,
    invoice.BALANCE_FIELD,
]
19+
20+
21+
def _row_filter(**filters: FilterValue) -> BooleanExpression | None:
    """Combine column=value filters into one PyIceberg row-filter expression.

    Args:
        **filters: Column names mapped to the value each column must equal.
            Values must be str, int, or float.

    Returns:
        ``None`` when no filters are given; a single ``EqualTo`` for one
        filter; otherwise the ``EqualTo`` clauses folded left-to-right with
        ``And``, in keyword order.
    """
    clauses = [EqualTo(column, value) for column, value in filters.items()]
    if not clauses:
        return None
    # reduce(And, [a, b, c]) nests as And(And(a, b), c).
    return functools.reduce(And, clauses)
38+
39+
40+
@functools.cache
def get_invoice_dataframe(
    cols: tuple[str, ...] | None = None, **filters: FilterValue
) -> pd.DataFrame:
    """Scan the invoice Iceberg table into a pandas DataFrame.

    Results are memoized by ``functools.cache`` on the exact arguments, so
    ``cols`` must be hashable (a tuple, not a list) and a repeated call
    returns the *same* DataFrame object. Callers should copy before mutating.

    Args:
        cols: Column names to select as a tuple. ``None`` (or an empty tuple)
            selects all columns.
        **filters: Column names mapped to the value each column must equal.
            Values must be str, int, or float.

    Returns:
        DataFrame of the matching invoice rows. A warning is logged when
        filters were supplied and no rows matched.
    """
    table = get_table()
    predicate = _row_filter(**filters)

    scan = table.scan(row_filter=predicate) if predicate else table.scan()
    if cols:
        scan = scan.select(*cols)

    frame = scan.to_pandas()
    if filters and frame.empty:
        logger.warning("No invoice rows matched filters: %s", filters)
    return frame
65+
66+
67+
def group_and_sum(
    df: pd.DataFrame,
    group_by: tuple[str, ...],
    *,
    agg_col: str,
    agg_name: str = "total",
) -> pd.DataFrame:
    """Sum one column per group and return the totals as 2-place Decimals.

    Missing values in ``agg_col`` count as 0. The input frame is not modified.

    Args:
        df: Input dataframe.
        group_by: Column names to group by.
        agg_col: Column to sum.
        agg_name: Name of the summed column in the output. Defaults to "total".

    Returns:
        DataFrame with one row per group, the ``group_by`` columns, and an
        ``agg_name`` column of ``Decimal`` values quantized to ``0.01``.
    """
    cent = Decimal("0.01")

    def _to_cents(value) -> Decimal:
        # Go through str() so the Decimal is built from the value's printed
        # form rather than its binary floating-point expansion.
        return Decimal(str(value)).quantize(cent)

    # assign() operates on a copy, leaving the caller's frame untouched.
    filled = df.assign(**{agg_col: df[agg_col].fillna(0)})
    totals = filled.groupby(list(group_by), as_index=False).agg(
        **{agg_name: (agg_col, "sum")}
    )
    totals[agg_name] = totals[agg_name].map(_to_cents)
    return totals
93+
94+
95+
def aggregate_by(
    cols: tuple[str, ...],
    group_by: tuple[str, ...],
    *,
    agg_col: str,
    agg_name: str = "total",
    **filters: FilterValue,
) -> pd.DataFrame:
    """Load invoice data and return grouped sum totals.

    Fetches invoice rows using the selected columns and filters, then returns
    a grouped sum over ``agg_col``. Every ``group_by`` column and ``agg_col``
    itself are appended to the selection when missing, so the aggregation
    never fails because a column it needs was not loaded.

    Args:
        cols: Columns to select from the invoice table before aggregation.
        group_by: Columns to group rows by in the aggregated output.
        agg_col: Numeric column to sum within each group.
        agg_name: Output column name for the aggregated sum. Defaults to ``"total"``.
        **filters: Column=value equality filters applied while loading invoice data.
            Values must be str, int, or float.

    Returns:
        DataFrame with one row per unique ``group_by`` combination and a summed
        ``agg_name`` column quantized to two decimal places.

    Example:
        >>> df = aggregate_by(
        ...     cols=(invoice.BALANCE_FIELD,),
        ...     group_by=(invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD),
        ...     agg_col=invoice.BALANCE_FIELD,
        ...     agg_name="lifetime_allocation_balance",
        ... )
    """
    selected = list(cols)
    # Keep the caller's column order; append the grouping keys first and the
    # summed column last, each only if it is not selected already.
    for required in (*group_by, agg_col):
        if required not in selected:
            selected.append(required)

    # The loader is cached, so the selection must be passed as a hashable tuple.
    df = get_invoice_dataframe(tuple(selected), **filters)
    return group_and_sum(
        df,
        group_by=group_by,
        agg_col=agg_col,
        agg_name=agg_name,
    )
140+
141+
142+
def calculate_lifetime_costs(**filters: FilterValue) -> pd.DataFrame:
    """Sum the invoice balance per (project, cluster) pair.

    Args:
        **filters: Column names mapped to the value each column must equal.
            Values must be str, int, or float.

    Returns:
        DataFrame with the project-id column, the cluster-name column, and a
        ``lifetime_allocation_balance`` column holding the summed balance.

    Example:
        >>> filters = {invoice.PROJECT_ID_FIELD: "vllm-test"}
        >>> df = calculate_lifetime_costs(**filters)
    """
    selection = tuple(_LIFETIME_COLS)
    grouping = (invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD)
    return aggregate_by(
        selection,
        grouping,
        agg_col=invoice.BALANCE_FIELD,
        agg_name="lifetime_allocation_balance",
        **filters,
    )
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from process_report.data_tools import costs
5+
6+
PID = costs.invoice.PROJECT_ID_FIELD
7+
CLUSTER = costs.invoice.CLUSTER_NAME_FIELD
8+
BALANCE = costs.invoice.BALANCE_FIELD
9+
10+
11+
@pytest.fixture(autouse=True)
def clear_dataframe_cache():
    """Reset the ``get_invoice_dataframe`` cache around every test.

    Clearing before and after keeps memoized frames from leaking between
    tests in either direction.
    """
    reset_cache = costs.get_invoice_dataframe.cache_clear
    reset_cache()
    yield
    reset_cache()
16+
17+
18+
@pytest.fixture
def sample_invoice_dataframe() -> pd.DataFrame:
    """Three invoice rows: two for one project/cluster, one with no balance."""
    project_ids = ["vllm-test", "vllm-test", "webrca-1b021a"]
    clusters = ["ocp-test", "ocp-test", "ocp-prod"]
    # The trailing None exercises the "missing balance counts as 0" path.
    balances = [1.234, 2.345, None]
    return pd.DataFrame({PID: project_ids, CLUSTER: clusters, BALANCE: balances})
27+
28+
29+
def test_row_filter_empty_returns_none():
    """With no filters there is no expression to build."""
    expression = costs._row_filter()
    assert expression is None
31+
32+
33+
@pytest.mark.parametrize(
    "filters",
    [
        {PID: "vllm-test", CLUSTER: "ocp-test"},
        {PID: "vllm-test", CLUSTER: "ocp-prod"},
    ],
)
def test_row_filter_builds_combined_and_expression(filters: dict[str, str]):
    """Two filters become an ``And`` whose operands are both ``EqualTo``."""
    expression = costs._row_filter(**filters)

    assert isinstance(expression, costs.And)
    for operand in (expression.left, expression.right):
        assert isinstance(operand, costs.EqualTo)
45+
46+
47+
def test_aggregate_by_rounds_and_forwards_filters(
    monkeypatch: pytest.MonkeyPatch, sample_invoice_dataframe: pd.DataFrame
):
    """aggregate_by forwards columns/filters to the loader and rounds totals."""
    captured: dict[str, object] = {}

    def _fake_loader(cols=None, **filters):
        # Record what aggregate_by asked the loader for.
        captured.update(cols=cols, filters=filters)
        return sample_invoice_dataframe

    monkeypatch.setattr(costs, "get_invoice_dataframe", _fake_loader)

    project_filter = {PID: "vllm-test"}
    result = costs.aggregate_by(
        (BALANCE,),
        (PID, CLUSTER),
        agg_col=BALANCE,
        agg_name="lifetime_allocation_balance",
        **project_filter,
    )

    assert captured["filters"] == project_filter
    # Grouping columns are appended after the explicitly requested ones.
    assert captured["cols"] == (BALANCE, PID, CLUSTER)

    values = sorted(result["lifetime_allocation_balance"].tolist())
    assert values == [costs.Decimal("0.00"), costs.Decimal("3.58")]
    assert all(v.as_tuple().exponent == -2 for v in values)
73+
74+
75+
def test_group_and_sum_is_pure_transform(sample_invoice_dataframe: pd.DataFrame):
    """group_and_sum returns rounded totals and leaves its input untouched."""
    # Snapshot the input so the purity promised by the test name is checked.
    snapshot = sample_invoice_dataframe.copy(deep=True)

    result = costs.group_and_sum(
        sample_invoice_dataframe,
        (PID, CLUSTER),
        agg_col=BALANCE,
        agg_name="lifetime_allocation_balance",
    )

    values = sorted(result["lifetime_allocation_balance"].tolist())
    assert values == [costs.Decimal("0.00"), costs.Decimal("3.58")]
    assert all(v.as_tuple().exponent == -2 for v in values)

    # The missing balance must still be missing and no column may be added.
    pd.testing.assert_frame_equal(sample_invoice_dataframe, snapshot)
86+
87+
88+
@pytest.mark.parametrize(
    "invalid_filters",
    [
        {PID: "does-not-exist"},
        {CLUSTER: "not-a-real-cluster"},
        {PID: "does-not-exist", CLUSTER: "not-a-real-cluster"},
    ],
)
def test_calculate_lifetime_costs_invalid_queries_return_empty(
    monkeypatch: pytest.MonkeyPatch, invalid_filters: dict[str, str]
):
    """Filters matching nothing yield an empty frame with the output columns."""
    empty_df = pd.DataFrame(columns=[PID, CLUSTER, BALANCE])

    def _empty_loader(cols=None, **_filters):
        return empty_df

    monkeypatch.setattr(costs, "get_invoice_dataframe", _empty_loader)

    result = costs.calculate_lifetime_costs(**invalid_filters)

    expected_columns = [PID, CLUSTER, "lifetime_allocation_balance"]
    assert result.empty
    assert result.columns.tolist() == expected_columns
106+
107+
108+
class _FakeIcebergTable:
    """Stand-in table that answers ``.scan().select().to_pandas()`` chains.

    ``scan`` and ``select`` ignore their arguments and return the fake itself,
    so any chain ends at ``to_pandas``, which hands back the stored frame.
    """

    def __init__(self, df: pd.DataFrame):
        self._frame = df

    def scan(self, row_filter=None):
        """Accept and ignore a row filter; return ``self`` for chaining."""
        return self

    def select(self, *cols):
        """Accept and ignore a column selection; return ``self`` for chaining."""
        return self

    def to_pandas(self):
        """Return the exact DataFrame object given at construction."""
        return self._frame
122+
123+
124+
def test_get_invoice_dataframe_warns_when_no_rows_match(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
    """A filtered query that matches nothing logs a warning and returns empty."""
    empty_frame = pd.DataFrame(columns=[PID, BALANCE])
    fake_table = _FakeIcebergTable(empty_frame)
    monkeypatch.setattr(costs, "get_table", lambda: fake_table)

    unmatched = {PID: "does-not-exist"}
    with caplog.at_level("WARNING", logger=costs.__name__):
        result = costs.get_invoice_dataframe((PID, BALANCE), **unmatched)

    assert result.empty
    assert "No invoice rows matched filters" in caplog.text
135+
136+
137+
def test_get_invoice_dataframe_caches_repeated_query(monkeypatch: pytest.MonkeyPatch):
    """An identical second query is served from cache without a new table load."""
    table = _FakeIcebergTable(pd.DataFrame({PID: ["vllm-test"], BALANCE: [1.0]}))
    table_loads: list[None] = []

    def _fake_get_table():
        # One entry per call, so the list length is the load count.
        table_loads.append(None)
        return table

    monkeypatch.setattr(costs, "get_table", _fake_get_table)

    query = {PID: "vllm-test"}
    first = costs.get_invoice_dataframe((PID, BALANCE), **query)
    second = costs.get_invoice_dataframe((PID, BALANCE), **query)

    assert len(table_loads) == 1
    # The cache returns the very same object, not an equal copy.
    assert first is second

pyproject.toml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[project]
2+
name = "nerc-invoicing"
3+
version = "0.1.0"
4+
description = "Add your description here"
5+
readme = "README.md"
6+
requires-python = ">=3.12"
7+
dependencies = [
8+
"nerc-rates>=1.0.1,<2.0.0",
9+
"pandas>=3.0.0",
10+
"pyarrow",
11+
"pyiceberg[pyarrow]>=0.11.0",
12+
"boto3>=1.42.6,<2.0",
13+
"jinja2",
14+
"validators",
15+
"python-dateutil",
16+
"pydantic-settings",
17+
"pyyaml>=6.0",
18+
"pre-commit>=4.5.1",
19+
]
20+
21+
[build-system]
22+
requires = ["uv_build>=0.10.4,<0.11.0"]
23+
build-backend = "uv_build"
24+
25+
[tool.uv.build-backend]
26+
module-name = "process_report"
27+
module-root = ""

0 commit comments

Comments
 (0)