Skip to content

Commit 73b5ca0

Browse files
committed
Closes #266
Added nerc_invoicing repo for querying invoicing Iceberg table Added functionality to get lifetime costs grouped by project Added pyproject.toml for future publishing Added testing for new functionality
1 parent 4920ec7 commit 73b5ca0

7 files changed

Lines changed: 289 additions & 1 deletion

File tree

.github/ISSUE_TEMPLATE/custom.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ assignees: ''
1212
## Completion Criteria
1313

1414
## Description
15-
- [ ]
15+
- [ ]
1616

1717
## Completion dates
1818
Desired - YYYY-MM-DD

data_tools/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .costs import calculate_lifetime_costs
2+
3+
__all__ = ["calculate_lifetime_costs"]

data_tools/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
from pyiceberg.table import StaticTable
3+
4+
WAREHOUSE_BASE = "s3://nerc-invoicing-iceberg/warehouse"
5+
TABLE_PATH = f"{WAREHOUSE_BASE}/nerc_invoicing_iceberg/nerc_invoicing_iceberg"
6+
7+
8+
def _b2_properties() -> dict[str, str]:
9+
return {
10+
"s3.access-key-id": os.environ["B2_APPLICATION_KEY_ID"],
11+
"s3.secret-access-key": os.environ["B2_APPLICATION_KEY"],
12+
"s3.endpoint": f"https://{os.environ['B2_S3_ENDPOINT']}",
13+
"s3.region": "us-east-005",
14+
}
15+
16+
17+
def _get_table() -> StaticTable:
18+
return StaticTable.from_metadata(
19+
TABLE_PATH,
20+
properties=_b2_properties(),
21+
)

data_tools/costs.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from decimal import Decimal
2+
import logging
3+
4+
import pandas as pd
5+
from pyiceberg.expressions import And, BooleanExpression, EqualTo
6+
7+
import process_report.invoices.invoice as invoice
8+
9+
from .config import _get_table
10+
11+
logger = logging.getLogger(__name__)
12+
13+
_LIFETIME_COLS = [
14+
invoice.PROJECT_ID_FIELD,
15+
invoice.CLUSTER_NAME_FIELD,
16+
invoice.BALANCE_FIELD,
17+
]
18+
19+
20+
def _row_filter(**filters: str | int | float) -> BooleanExpression | None:
21+
"""Build a PyIceberg row filter expression from column=value filters.
22+
23+
Args:
24+
**filters: Column names as keys, values to filter by. Values must be str, int, or float.
25+
26+
Returns:
27+
PyIceberg BooleanExpression like EqualTo(col1, 'x') AND EqualTo(col2, 1),
28+
or None if no filters are given.
29+
"""
30+
if not filters:
31+
return None
32+
expression: BooleanExpression | None = None
33+
for col, val in filters.items():
34+
clause = EqualTo(col, val)
35+
expression = clause if expression is None else And(expression, clause)
36+
return expression
37+
38+
39+
def get_invoice_dataframe(
40+
cols: list[str] | None = None, **filters: str | int | float
41+
) -> pd.DataFrame:
42+
"""Load invoice data from the Iceberg table.
43+
44+
Args:
45+
cols: Column names to select. None selects all columns.
46+
**filters: Column names as keys, values to filter by. Values must be str, int, or float.
47+
48+
Returns:
49+
DataFrame of invoice data from the table.
50+
"""
51+
table = _get_table()
52+
row_filter = _row_filter(**filters)
53+
if row_filter:
54+
scan = table.scan(row_filter=row_filter)
55+
else:
56+
scan = table.scan()
57+
if cols:
58+
scan = scan.select(*cols)
59+
df = scan.to_pandas()
60+
if filters and df.empty:
61+
logger.warning("No invoice rows matched filters: %s", filters)
62+
return df
63+
64+
65+
def select_and_group(
66+
cols: list[str],
67+
group_by: list[str],
68+
*,
69+
agg_col: str,
70+
agg_name: str = "total",
71+
**filters: str | int | float,
72+
) -> pd.DataFrame:
73+
"""Load invoice data, group by the given columns, and aggregate one column with sum.
74+
75+
Args:
76+
cols: Column names to load from the table.
77+
group_by: Column names to group by.
78+
agg_col: Column to sum.
79+
agg_name: Name for the aggregated column in the output. Defaults to "total".
80+
**filters: Column names as keys, values to filter by. Values must be str, int, or float.
81+
82+
Returns:
83+
DataFrame with one row per group and a column containing the sum of agg_col.
84+
"""
85+
all_cols = list(cols)
86+
for c in group_by:
87+
if c not in all_cols:
88+
all_cols.append(c)
89+
df = get_invoice_dataframe(all_cols, **filters)
90+
df[agg_col] = df[agg_col].fillna(0)
91+
agg_spec = {agg_name: (agg_col, "sum")}
92+
grouped_df = df.groupby(list(group_by), as_index=False).agg(**agg_spec)
93+
grouped_df[agg_name] = grouped_df[agg_name].map(
94+
lambda v: Decimal(str(v)).quantize(Decimal("0.01"))
95+
)
96+
return grouped_df
97+
98+
99+
def calculate_lifetime_costs(**filters: str | int | float) -> pd.DataFrame:
100+
"""Group invoice data by project and cluster, summing balance per group.
101+
102+
Args:
103+
**filters: Column names as keys, values to filter by. Values must be str, int, or float.
104+
105+
Returns:
106+
DataFrame with columns: Project - Allocation, Cluster Name, lifetime_allocation_balance.
107+
"""
108+
return select_and_group(
109+
_LIFETIME_COLS,
110+
[invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD],
111+
agg_col=invoice.BALANCE_FIELD,
112+
agg_name="lifetime_allocation_balance",
113+
**filters,
114+
)
115+
116+
117+
if __name__ == "__main__":
118+
print(calculate_lifetime_costs())
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from data_tools import costs
5+
6+
PID = costs.invoice.PROJECT_ID_FIELD
7+
CLUSTER = costs.invoice.CLUSTER_NAME_FIELD
8+
BALANCE = costs.invoice.BALANCE_FIELD
9+
10+
11+
@pytest.fixture
12+
def sample_invoice_dataframe() -> pd.DataFrame:
13+
return pd.DataFrame(
14+
{
15+
PID: ["vllm-test", "vllm-test", "webrca-1b021a"],
16+
CLUSTER: ["ocp-test", "ocp-test", "ocp-prod"],
17+
BALANCE: [1.234, 2.345, None],
18+
}
19+
)
20+
21+
22+
def test_row_filter_empty_returns_none():
23+
assert costs._row_filter() is None
24+
25+
26+
@pytest.mark.parametrize(
27+
"filters",
28+
[
29+
{PID: "vllm-test", CLUSTER: "ocp-test"},
30+
{PID: "vllm-test", CLUSTER: "ocp-prod"},
31+
],
32+
)
33+
def test_row_filter_builds_combined_and_expression(filters: dict[str, str]):
34+
expression = costs._row_filter(**filters)
35+
assert isinstance(expression, costs.And)
36+
assert isinstance(expression.left, costs.EqualTo)
37+
assert isinstance(expression.right, costs.EqualTo)
38+
39+
40+
def test_select_and_group_rounds_and_forwards_filters(
41+
monkeypatch: pytest.MonkeyPatch, sample_invoice_dataframe: pd.DataFrame
42+
):
43+
captured: dict[str, object] = {}
44+
45+
def _fake_loader(cols=None, **filters):
46+
captured["cols"] = cols
47+
captured["filters"] = filters
48+
return sample_invoice_dataframe
49+
50+
monkeypatch.setattr(costs, "get_invoice_dataframe", _fake_loader)
51+
52+
result = costs.select_and_group(
53+
[BALANCE],
54+
[PID, CLUSTER],
55+
agg_col=BALANCE,
56+
agg_name="lifetime_allocation_balance",
57+
**{PID: "vllm-test"},
58+
)
59+
60+
assert captured["filters"] == {PID: "vllm-test"}
61+
assert captured["cols"] == [BALANCE, PID, CLUSTER]
62+
63+
values = sorted(result["lifetime_allocation_balance"].tolist())
64+
assert values == [costs.Decimal("0.00"), costs.Decimal("3.58")]
65+
assert all(v.as_tuple().exponent == -2 for v in values)
66+
67+
68+
@pytest.mark.parametrize(
69+
"invalid_filters",
70+
[
71+
{PID: "does-not-exist"},
72+
{CLUSTER: "not-a-real-cluster"},
73+
{PID: "does-not-exist", CLUSTER: "not-a-real-cluster"},
74+
],
75+
)
76+
def test_calculate_lifetime_costs_invalid_queries_return_empty(
77+
monkeypatch: pytest.MonkeyPatch, invalid_filters: dict[str, str]
78+
):
79+
empty_df = pd.DataFrame(columns=[PID, CLUSTER, BALANCE])
80+
monkeypatch.setattr(costs, "get_invoice_dataframe", lambda cols=None, **f: empty_df)
81+
82+
result = costs.calculate_lifetime_costs(**invalid_filters)
83+
84+
assert result.empty
85+
assert result.columns.tolist() == [PID, CLUSTER, "lifetime_allocation_balance"]
86+
87+
88+
class _FakeIcebergTable:
89+
"""Responds to .scan().select().to_pandas() chains."""
90+
91+
def __init__(self, df: pd.DataFrame):
92+
self._df = df
93+
94+
def scan(self, row_filter=None):
95+
return self
96+
97+
def select(self, *cols):
98+
return self
99+
100+
def to_pandas(self):
101+
return self._df
102+
103+
104+
def test_get_invoice_dataframe_warns_when_no_rows_match(
105+
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
106+
):
107+
table = _FakeIcebergTable(pd.DataFrame(columns=[PID, BALANCE]))
108+
monkeypatch.setattr(costs, "_get_table", lambda: table)
109+
110+
with caplog.at_level("WARNING", logger="data_tools.costs"):
111+
result = costs.get_invoice_dataframe([PID, BALANCE], **{PID: "does-not-exist"})
112+
113+
assert result.empty
114+
assert "No invoice rows matched filters" in caplog.text

pyproject.toml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
[project]
2+
name = "nerc-invoicing"
3+
version = "0.1.0"
4+
description = "Add your description here"
5+
readme = "README.md"
6+
requires-python = ">=3.12"
7+
dependencies = [
8+
"nerc-rates>=1.0.1,<2.0.0",
9+
"pandas>=3.0.0",
10+
"pyarrow",
11+
"pyiceberg[pyarrow]>=0.11.0",
12+
"boto3>=1.42.6,<2.0",
13+
"jinja2",
14+
"validators",
15+
"python-dateutil",
16+
"pydantic-settings",
17+
"pyyaml>=6.0",
18+
"pre-commit>=4.5.1",
19+
]
20+
21+
[build-system]
22+
requires = ["hatchling"]
23+
build-backend = "hatchling.build"
24+
25+
[tool.hatch.build.targets.wheel]
26+
packages = ["data_tools", "process_report"]
27+
28+
[dependency-groups]
29+
dev = [
30+
"pre-commit>=4.5.1",
31+
]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
nerc-rates>=1.0.1,<2.0.0
22
pandas
33
pyarrow
4+
pyiceberg[pyarrow]>=0.11.0
45
boto3>=1.42.6,<2.0
56
Jinja2
67
validators

0 commit comments

Comments
 (0)