Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
38 changes: 38 additions & 0 deletions process_report/data_tools/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pydantic_settings import BaseSettings


class DataToolsSettings(BaseSettings):
    """Settings for data_tools queries: Iceberg table location and S3 access.

    The three credential fields default to ``None``; they are only validated
    when ``iceberg_s3_properties()`` is called, not when the settings object
    is constructed.
    """

    # Location of the Iceberg table: warehouse base URI plus table sub-path.
    iceberg_warehouse_base: str = "s3://nerc-invoicing-iceberg/warehouse"
    iceberg_table_subpath: str = "nerc_invoicing_iceberg/nerc_invoicing_iceberg"
    # S3 connection details; presumably supplied via the environment variables
    # named in the error message below (pydantic-settings) -- TODO confirm.
    iceberg_s3_access_key_id: str | None = None
    iceberg_s3_secret_access_key: str | None = None
    iceberg_s3_endpoint: str | None = None
    iceberg_s3_region: str = "us-east-005"

    @property
    def table_path(self) -> str:
        """Return the warehouse base and the table sub-path joined by ``/``."""
        return "/".join((self.iceberg_warehouse_base, self.iceberg_table_subpath))

    def iceberg_s3_properties(self) -> dict[str, str]:
        """Build the ``s3.*`` property mapping used to open the Iceberg table.

        Returns:
            Mapping with the keys ``s3.access-key-id``, ``s3.secret-access-key``,
            ``s3.endpoint`` and ``s3.region``.

        Raises:
            ValueError: If the access key id, secret access key, or endpoint
                is unset or empty.
        """
        key_id = self.iceberg_s3_access_key_id
        secret = self.iceberg_s3_secret_access_key
        endpoint = self.iceberg_s3_endpoint

        # Guard clause: all three values must be present and non-empty.
        if not (key_id and secret and endpoint):
            raise ValueError(
                "Iceberg S3 credentials required: "
                "ICEBERG_S3_ACCESS_KEY_ID, ICEBERG_S3_SECRET_ACCESS_KEY, ICEBERG_S3_ENDPOINT"
            )

        return {
            "s3.access-key-id": key_id,
            "s3.secret-access-key": secret,
            "s3.endpoint": endpoint,
            "s3.region": self.iceberg_s3_region,
        }


# Module-level settings instance, created once at import time and imported
# by other data_tools modules.
data_tools_settings = DataToolsSettings()
178 changes: 178 additions & 0 deletions process_report/data_tools/costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import functools
import logging

import pandas as pd
import pyarrow
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, EqualTo
from pyiceberg.table import StaticTable

import process_report.invoices.invoice as invoice
from process_report.data_tools.config import data_tools_settings

# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Value types accepted for column equality filters.
FilterValue = str | int | float

# Columns selected from the invoice table when computing lifetime costs.
_LIFETIME_COLS = [
invoice.PROJECT_ID_FIELD,
invoice.CLUSTER_NAME_FIELD,
invoice.COST_FIELD,
]


def _row_filter(**filters: FilterValue) -> BooleanExpression:
    """Combine column equality checks into a single PyIceberg filter expression.

    Each keyword argument becomes one equality check (column == value).
    The checks are joined with AND, starting from ``AlwaysTrue()``.

    Args:
        **filters: Column names as keys, values to filter by. Values must be str, int, or float.

    Returns:
        A PyIceberg BooleanExpression combining all checks. When no filters
        are given the result is ``AlwaysTrue()`` (never ``None``).
    """
    expression: BooleanExpression = AlwaysTrue()
    for column, value in filters.items():
        expression = expression & EqualTo(column, value)
    return expression


@functools.cache
def get_table() -> StaticTable:
    """Open the configured Iceberg table as a ``StaticTable``.

    The result is memoized with ``functools.cache``, so the table is built
    on the first call and the same object is returned afterwards.

    Raises:
        ValueError: Propagated from ``iceberg_s3_properties()`` when the S3
            credentials are incomplete.
    """
    metadata_location = data_tools_settings.table_path
    s3_properties = data_tools_settings.iceberg_s3_properties()
    return StaticTable.from_metadata(metadata_location, properties=s3_properties)


@functools.cache
def get_invoice_dataframe(
    cols: tuple[str, ...] | None = None, **filters: FilterValue
) -> pd.DataFrame:
    """Load invoice data from the Iceberg table.

    Results are memoized with ``functools.cache``: a repeated call with the
    same arguments returns the same DataFrame object, so callers should not
    modify the result in place.

    Args:
        cols: Column names to select, as a tuple. When omitted (or empty),
            all columns are selected.
        **filters: Column names as keys, values to filter by. Values must be str, int, or float.

    Returns:
        DataFrame of invoice data from the table.
    """
    scan = get_table().scan(row_filter=_row_filter(**filters))
    if cols:
        scan = scan.select(*cols)

    frame = scan.to_pandas()
    # An empty result for a filtered query is logged as a warning, not raised.
    if filters and frame.empty:
        logger.warning("No invoice rows matched filters: %s", filters)
    return frame
Comment thread
QuanMPhm marked this conversation as resolved.


def group_and_sum(
    df: pd.DataFrame,
    group_by: tuple[str, ...],
    *,
    agg_col: str,
    agg_name: str = "total",
) -> pd.DataFrame:
    """Group a dataframe and aggregate one column with sum.

    Missing values in ``agg_col`` are replaced with 0 and the column is cast
    to ``decimal128(21, 2)`` before summing. The input dataframe is copied,
    so the caller's object is not modified.

    Args:
        df: Input dataframe.
        group_by: Column names to group by.
        agg_col: Column to sum.
        agg_name: Name for the aggregated column in the output. Defaults to "total".

    Returns:
        DataFrame with one row per group and a column containing the sum of agg_col.

    Raises:
        ValueError: If agg_col is not present in df.
        TypeError: If agg_col cannot be cast to decimal128. The underlying
            pyarrow error is chained as ``__cause__``.
    """
    # NOTE(review): a reviewer asked whether a custom output name (agg_name)
    # is needed at all, since the result only contains the group columns and
    # the aggregated column.
    if agg_col not in df.columns:
        raise ValueError(
            f"Aggregation column '{agg_col}' not found in dataframe. "
            f"Available columns: {list(df.columns)}"
        )

    decimal_dtype = pd.ArrowDtype(pyarrow.decimal128(21, 2))
    # Work on a copy so the cast below does not touch the caller's dataframe.
    grouped_input = df.copy()

    try:
        grouped_input[agg_col] = grouped_input[agg_col].fillna(0).astype(decimal_dtype)
    except pyarrow.ArrowException as e:
        # Chain explicitly so the original Arrow error stays attached as the cause.
        raise TypeError(f"Unable to cast column {agg_col} to decimal: {e}") from e

    agg_spec = {agg_name: (agg_col, "sum")}
    return grouped_input.groupby(list(group_by), as_index=False).agg(**agg_spec)


def aggregate_by(
    cols: tuple[str, ...],
    group_by: tuple[str, ...],
    *,
    agg_col: str,
    agg_name: str = "total",
    **filters: FilterValue,
) -> pd.DataFrame:
    """Load invoice data and return grouped sum totals.

    Loading is done by ``get_invoice_dataframe`` and the aggregation by
    ``group_and_sum``; the two steps are kept separate so ``group_and_sum``
    can also be used on a DataFrame that is already in hand. Any ``group_by``
    column missing from ``cols`` is appended to the selection before loading.

    Args:
        cols: Columns to select from the invoice table before aggregation.
        group_by: Columns to group rows by in the aggregated output.
        agg_col: Numeric column to sum within each group.
        agg_name: Output column name for the aggregated sum. Defaults to ``"total"``.
        **filters: Column=value equality filters applied while loading invoice data.
            Values must be str, int, or float.

    Returns:
        DataFrame with one row per unique ``group_by`` combination and a summed
        ``agg_name`` column quantized to two decimal places.

    Example:
        >>> df = aggregate_by(
        ...     cols=(invoice.COST_FIELD,),
        ...     group_by=(invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD),
        ...     agg_col=invoice.COST_FIELD,
        ...     agg_name="lifetime_allocation_cost",
        ... )
    """
    # Keep ``cols`` as given, then add each grouping column not already
    # selected (first occurrence only, in ``group_by`` order).
    missing_group_cols = [
        name for name in dict.fromkeys(group_by) if name not in cols
    ]
    selection = (*cols, *missing_group_cols)

    loaded = get_invoice_dataframe(selection, **filters)
    return group_and_sum(
        loaded,
        group_by=group_by,
        agg_col=agg_col,
        agg_name=agg_name,
    )


def calculate_lifetime_costs(**filters: FilterValue) -> pd.DataFrame:
    """Sum the cost column per project and cluster.

    Thin wrapper around ``aggregate_by`` that selects ``_LIFETIME_COLS``,
    groups by ``invoice.PROJECT_ID_FIELD`` and ``invoice.CLUSTER_NAME_FIELD``,
    and sums ``invoice.COST_FIELD``.

    Args:
        **filters: Column names as keys, values to filter by. Values must be str, int, or float.

    Returns:
        DataFrame with the two grouping columns plus a
        ``lifetime_allocation_cost`` column holding the summed cost.

    Example:
        >>> filters = {invoice.PROJECT_ID_FIELD: "vllm-test"}
        >>> df = calculate_lifetime_costs(**filters)
    """
    grouping = (invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD)
    return aggregate_by(
        tuple(_LIFETIME_COLS),
        grouping,
        agg_col=invoice.COST_FIELD,
        agg_name="lifetime_allocation_cost",
        **filters,
    )
164 changes: 164 additions & 0 deletions process_report/tests/unit/data_tools/test_data_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from decimal import Decimal
from unittest import mock

import pandas as pd
import pyarrow
import pytest
from pyiceberg.expressions import AlwaysTrue, EqualTo, BooleanExpression

from process_report.data_tools import costs

# Column names as they appear in the Iceberg table. String literals are used here instead of the invoice module constants so that these tests also check the column names themselves.
PID = "Project - Allocation ID"
CLUSTER = "Cluster Name"
COST = "Cost"


@pytest.fixture(autouse=True)
def clear_dataframe_cache():
    """Clear the ``get_invoice_dataframe`` cache before and after each test."""
    # Setup: start every test with an empty cache.
    costs.get_invoice_dataframe.cache_clear()
    yield
    # Teardown: do not leak cached frames into the next test.
    costs.get_invoice_dataframe.cache_clear()


@pytest.fixture
def sample_invoice_dataframe() -> pd.DataFrame:
    """Three invoice rows: two for one project/cluster pair, one with no cost."""
    rows = {
        PID: ["vllm-test", "vllm-test", "webrca-1b021a"],
        CLUSTER: ["ocp-test", "ocp-test", "ocp-prod"],
        COST: [1.234, 2.345, None],
    }
    return pd.DataFrame(rows)


def test_row_filter_empty():
    """With no filters, ``_row_filter`` returns the ``AlwaysTrue()`` object."""
    expression = costs._row_filter()
    # Identity check: relies on AlwaysTrue() yielding the same object each call.
    assert expression is AlwaysTrue()


@pytest.mark.parametrize(
    "filters",
    [
        {PID: "vllm-test", CLUSTER: "ocp-test"},
        {PID: "vllm-test", CLUSTER: "ocp-prod"},
    ],
)
def test_row_filter_builds_combined_and_expression(filters: dict[str, str]):
    """Two filters produce ``AlwaysTrue() & EqualTo(...) & EqualTo(...)``."""
    expression = costs._row_filter(**filters)
    assert isinstance(expression, BooleanExpression)

    # Each parametrized case carries exactly two column/value pairs.
    (left_col, left_val), (right_col, right_val) = filters.items()
    expected = (
        AlwaysTrue() & EqualTo(left_col, left_val) & EqualTo(right_col, right_val)
    )
    assert expression == expected


def test_aggregate_by_rounds_and_forwards_filters(
    monkeypatch: pytest.MonkeyPatch, sample_invoice_dataframe: pd.DataFrame
):
    """``aggregate_by`` forwards columns and filters to the loader and sums as decimals."""
    mock_loader = mock.MagicMock(return_value=sample_invoice_dataframe)
    monkeypatch.setattr(costs, "get_invoice_dataframe", mock_loader)

    query = {PID: "vllm-test"}
    result = costs.aggregate_by(
        (COST,),
        (PID, CLUSTER),
        agg_col=COST,
        agg_name="lifetime_allocation_cost",
        **query,
    )

    # The loader receives the selected columns (grouping columns appended)
    # as a single positional tuple, and the filters as keyword arguments.
    positional, keyword = mock_loader.call_args
    assert positional == ((COST, PID, CLUSTER),)
    assert keyword == {PID: "vllm-test"}

    decimal_dtype = pd.ArrowDtype(pyarrow.decimal128(21, 2))
    totals = pd.array([Decimal("3.58"), Decimal("0.00")], dtype=decimal_dtype)
    expected = pd.DataFrame(
        {
            PID: ["vllm-test", "webrca-1b021a"],
            CLUSTER: ["ocp-test", "ocp-prod"],
            "lifetime_allocation_cost": totals,
        }
    )
    assert result.equals(expected)


def test_group_and_sum_raises_on_missing_column(sample_invoice_dataframe: pd.DataFrame):
    """An ``agg_col`` that is not in the dataframe raises ``ValueError``."""
    grouping = (PID, CLUSTER)
    with pytest.raises(ValueError, match="not found in dataframe"):
        costs.group_and_sum(
            sample_invoice_dataframe,
            grouping,
            agg_col="non_existent_column",
            agg_name="lifetime_allocation_cost",
        )


def test_group_and_sum_raises_on_cast_error(
    sample_invoice_dataframe: pd.DataFrame,
):
    """Summing a column that cannot be cast to decimal raises ``TypeError``."""
    # PID holds project-name strings, so the decimal cast is expected to fail.
    non_numeric_col = PID
    with pytest.raises(TypeError, match="Unable to cast column"):
        costs.group_and_sum(
            sample_invoice_dataframe,
            (CLUSTER,),
            agg_col=non_numeric_col,
            agg_name="lifetime_allocation_cost",
        )


@pytest.mark.parametrize(
    "invalid_filters",
    [
        {PID: "does-not-exist"},
        {CLUSTER: "not-a-real-cluster"},
        {PID: "does-not-exist", CLUSTER: "not-a-real-cluster"},
    ],
)
def test_calculate_lifetime_costs_invalid_queries_return_empty(
    monkeypatch: pytest.MonkeyPatch, invalid_filters: dict[str, str]
):
    """When the loader yields no rows, the lifetime cost result is empty."""
    empty_df = pd.DataFrame(columns=[PID, CLUSTER, COST])

    def _empty_loader(cols=None, **f):
        # Stand-in for get_invoice_dataframe: ignores its arguments.
        return empty_df

    monkeypatch.setattr(costs, "get_invoice_dataframe", _empty_loader)

    result = costs.calculate_lifetime_costs(**invalid_filters)
    assert result.empty


class _FakeIcebergTable:
    """Test double answering ``.scan().select().to_pandas()`` call chains.

    ``scan`` and ``select`` ignore their arguments and return the fake itself;
    ``to_pandas`` returns the dataframe given at construction.
    """

    def __init__(self, df: pd.DataFrame):
        self._frame = df

    def scan(self, row_filter=None):
        """Ignore the filter and return this object so calls can be chained."""
        return self

    def select(self, *cols):
        """Ignore the column selection and return this object."""
        return self

    def to_pandas(self):
        """Return the canned dataframe."""
        return self._frame


def test_get_invoice_dataframe_warns_when_no_rows_match(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
    """A filtered query that yields no rows logs a warning."""
    empty_frame = pd.DataFrame(columns=[PID, COST])
    table = _FakeIcebergTable(empty_frame)
    monkeypatch.setattr(costs, "get_table", lambda: table)

    query = {PID: "does-not-exist"}
    with caplog.at_level("WARNING", logger=costs.__name__):
        costs.get_invoice_dataframe((PID, COST), **query)

    assert "No invoice rows matched filters" in caplog.text


def test_get_invoice_dataframe_caches_repeated_query(monkeypatch: pytest.MonkeyPatch):
    """Two identical queries reach the table once and return the same object."""
    frame = pd.DataFrame({PID: ["vllm-test"], COST: [1.0]})
    mock_get_table = mock.MagicMock(return_value=_FakeIcebergTable(frame))
    monkeypatch.setattr(costs, "get_table", mock_get_table)

    columns = (PID, COST)
    query = {PID: "vllm-test"}
    first = costs.get_invoice_dataframe(columns, **query)
    second = costs.get_invoice_dataframe(columns, **query)

    # The second call is served from the cache, so the table is fetched once.
    assert mock_get_table.call_count == 1
    assert first is second
Loading
Loading