-
Notifications
You must be signed in to change notification settings - Fork 6
Added nerc_invoicing repo with functions for querying invoicing Iceberg table #267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| from pydantic_settings import BaseSettings | ||
|
|
||
|
|
||
| class DataToolsSettings(BaseSettings): | ||
| """Iceberg warehouse path and S3 credentials for data_tools queries.""" | ||
|
|
||
| iceberg_warehouse_base: str = "s3://nerc-invoicing-iceberg/warehouse" | ||
| iceberg_table_subpath: str = "nerc_invoicing_iceberg/nerc_invoicing_iceberg" | ||
| iceberg_s3_access_key_id: str | None = None | ||
| iceberg_s3_secret_access_key: str | None = None | ||
| iceberg_s3_endpoint: str | None = None | ||
| iceberg_s3_region: str = "us-east-005" | ||
|
|
||
| @property | ||
| def table_path(self) -> str: | ||
| return f"{self.iceberg_warehouse_base}/{self.iceberg_table_subpath}" | ||
|
|
||
| def iceberg_s3_properties(self) -> dict[str, str]: | ||
| if not all( | ||
| [ | ||
| self.iceberg_s3_access_key_id, | ||
| self.iceberg_s3_secret_access_key, | ||
| self.iceberg_s3_endpoint, | ||
| ] | ||
| ): | ||
| raise ValueError( | ||
| "Iceberg S3 credentials required: " | ||
| "ICEBERG_S3_ACCESS_KEY_ID, ICEBERG_S3_SECRET_ACCESS_KEY, ICEBERG_S3_ENDPOINT" | ||
| ) | ||
| return { | ||
| "s3.access-key-id": self.iceberg_s3_access_key_id, | ||
| "s3.secret-access-key": self.iceberg_s3_secret_access_key, | ||
| "s3.endpoint": self.iceberg_s3_endpoint, | ||
| "s3.region": self.iceberg_s3_region, | ||
| } | ||
|
|
||
|
|
||
| data_tools_settings = DataToolsSettings() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| import functools | ||
| import logging | ||
|
|
||
| import pandas as pd | ||
| import pyarrow | ||
| from pyiceberg.expressions import AlwaysTrue, BooleanExpression, EqualTo | ||
| from pyiceberg.table import StaticTable | ||
|
|
||
| import process_report.invoices.invoice as invoice | ||
| from process_report.data_tools.config import data_tools_settings | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| FilterValue = str | int | float | ||
|
|
||
| _LIFETIME_COLS = [ | ||
| invoice.PROJECT_ID_FIELD, | ||
| invoice.CLUSTER_NAME_FIELD, | ||
| invoice.COST_FIELD, | ||
| ] | ||
|
|
||
|
|
||
| def _row_filter(**filters: FilterValue) -> BooleanExpression: | ||
| """Combine column equality checks into a single PyIceberg filter expression. | ||
|
|
||
| Each keyword argument becomes one equality check (column == value). | ||
| Multiple checks are joined with AND. | ||
|
|
||
| Args: | ||
| **filters: Column names as keys, values to filter by. Values must be str, int, or float. | ||
|
|
||
| Returns: | ||
| A PyIceberg BooleanExpression combining all checks, or None if no filters were given. | ||
| """ | ||
| expression: BooleanExpression = AlwaysTrue() | ||
| for col, val in filters.items(): | ||
| expression = expression & EqualTo(col, val) | ||
| return expression | ||
|
|
||
|
|
||
| @functools.cache | ||
| def get_table() -> StaticTable: | ||
| return StaticTable.from_metadata( | ||
| data_tools_settings.table_path, | ||
| properties=data_tools_settings.iceberg_s3_properties(), | ||
| ) | ||
|
|
||
|
|
||
| @functools.cache | ||
| def get_invoice_dataframe( | ||
| cols: tuple[str, ...] | None = None, **filters: FilterValue | ||
| ) -> pd.DataFrame: | ||
| """Load invoice data from the Iceberg table. | ||
|
|
||
| Args: | ||
| cols: Column names to select as a tuple. Defaults to selects all columns. | ||
| **filters: Column names as keys, values to filter by. Values must be str, int, or float. | ||
|
|
||
| Returns: | ||
| DataFrame of invoice data from the table. | ||
| """ | ||
| table = get_table() | ||
| scan = table.scan(row_filter=_row_filter(**filters)) | ||
| if cols: | ||
| scan = scan.select(*cols) | ||
| df = scan.to_pandas() | ||
| if filters and df.empty: | ||
| logger.warning("No invoice rows matched filters: %s", filters) | ||
| return df | ||
|
|
||
|
|
||
| def group_and_sum( | ||
| df: pd.DataFrame, | ||
| group_by: tuple[str, ...], | ||
| *, | ||
| agg_col: str, | ||
| agg_name: str = "total", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @knikolla Is there a strong need for the output column to have a custom name? This function already only outputs the grouped and output columns, so without a_summed = a.groupby([INVOICE_DATE_FIELD, PROJECT_FIELD], dropna=False, as_index=False)["Balance"].agg("sum")
print(a_summed.head())
Invoice Month Project - Allocation Balance
0 2024-03 False 6.43
1 2024-03 True 403.30
2 2024-03 NaN 3.30
3 2024-04 False 1000.00
4 2024-04 True 510.00I just don't want unneeded features for now. |
||
| ) -> pd.DataFrame: | ||
| """Group a dataframe and aggregate one column with sum. | ||
|
|
||
| Args: | ||
| df: Input dataframe. | ||
| group_by: Column names to group by. | ||
| agg_col: Column to sum. | ||
| agg_name: Name for the aggregated column in the output. Defaults to "total". | ||
|
|
||
| Returns: | ||
| DataFrame with one row per group and a column containing the sum of agg_col. | ||
|
|
||
| Raises: | ||
| ValueError: If agg_col is not present in df. | ||
| TypeError: If unable to cast agg_col to decimal128 | ||
| """ | ||
| if agg_col not in df.columns: | ||
| raise ValueError( | ||
| f"Aggregation column '{agg_col}' not found in dataframe. " | ||
| f"Available columns: {list(df.columns)}" | ||
| ) | ||
|
|
||
| decimal_dtype = pd.ArrowDtype(pyarrow.decimal128(21, 2)) | ||
| grouped_input = df.copy() | ||
|
|
||
| try: | ||
| grouped_input[agg_col] = grouped_input[agg_col].fillna(0).astype(decimal_dtype) | ||
| except pyarrow.ArrowException as e: | ||
| raise TypeError(f"Unable to cast column {agg_col} to decimal: {e}") | ||
|
|
||
| agg_spec = {agg_name: (agg_col, "sum")} | ||
| return grouped_input.groupby(list(group_by), as_index=False).agg(**agg_spec) | ||
|
|
||
|
|
||
| def aggregate_by( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've split the previous |
||
| cols: tuple[str, ...], | ||
| group_by: tuple[str, ...], | ||
| *, | ||
| agg_col: str, | ||
| agg_name: str = "total", | ||
| **filters: FilterValue, | ||
| ) -> pd.DataFrame: | ||
| """Load invoice data and return grouped sum totals. | ||
|
|
||
| This helper fetches invoice rows using the provided selected columns and filters, | ||
| ensures grouping columns are included in the selection, then returns a grouped sum | ||
| aggregation over ``agg_col``. | ||
|
|
||
| Args: | ||
| cols: Columns to select from the invoice table before aggregation. | ||
| group_by: Columns to group rows by in the aggregated output. | ||
| agg_col: Numeric column to sum within each group. | ||
| agg_name: Output column name for the aggregated sum. Defaults to ``"total"``. | ||
| **filters: Column=value equality filters applied while loading invoice data. | ||
| Values must be str, int, or float. | ||
|
|
||
| Returns: | ||
| DataFrame with one row per unique ``group_by`` combination and a summed | ||
| ``agg_name`` column quantized to two decimal places. | ||
|
|
||
| Example: | ||
| >>> df = aggregate_by( | ||
| ... cols=(invoice.COST_FIELD,), | ||
| ... group_by=(invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD), | ||
| ... agg_col=invoice.COST_FIELD, | ||
| ... agg_name="lifetime_allocation_cost", | ||
| ... ) | ||
| """ | ||
| all_cols = list(cols) | ||
| for col in group_by: | ||
| if col not in all_cols: | ||
| all_cols.append(col) | ||
| df = get_invoice_dataframe(tuple(all_cols), **filters) | ||
| return group_and_sum( | ||
| df, | ||
| group_by=group_by, | ||
| agg_col=agg_col, | ||
| agg_name=agg_name, | ||
| ) | ||
|
|
||
|
|
||
| def calculate_lifetime_costs(**filters: FilterValue) -> pd.DataFrame: | ||
| """Group invoice data by project and cluster, summing the COST column per group. | ||
|
|
||
| Args: | ||
| **filters: Column names as keys, values to filter by. Values must be str, int, or float. | ||
|
|
||
| Returns: | ||
| DataFrame with columns: Project - Allocation, Cluster Name, lifetime_allocation_cost. | ||
|
|
||
| Example: | ||
| >>> filters = {invoice.PROJECT_ID_FIELD: "vllm-test"} | ||
| >>> df = calculate_lifetime_costs(**filters) | ||
| """ | ||
|
|
||
| return aggregate_by( | ||
| tuple(_LIFETIME_COLS), | ||
| (invoice.PROJECT_ID_FIELD, invoice.CLUSTER_NAME_FIELD), | ||
| agg_col=invoice.COST_FIELD, | ||
| agg_name="lifetime_allocation_cost", | ||
| **filters, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| from decimal import Decimal | ||
| from unittest import mock | ||
|
|
||
| import pandas as pd | ||
| import pyarrow | ||
| import pytest | ||
| from pyiceberg.expressions import AlwaysTrue, EqualTo, BooleanExpression | ||
|
|
||
| from process_report.data_tools import costs | ||
|
|
||
| # These are the column names in the iceberg table using string literals instead of the invoice module to test column name correctness | ||
| PID = "Project - Allocation ID" | ||
| CLUSTER = "Cluster Name" | ||
| COST = "Cost" | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def clear_dataframe_cache(): | ||
| costs.get_invoice_dataframe.cache_clear() | ||
| yield | ||
| costs.get_invoice_dataframe.cache_clear() | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sample_invoice_dataframe() -> pd.DataFrame: | ||
| return pd.DataFrame( | ||
| { | ||
| PID: ["vllm-test", "vllm-test", "webrca-1b021a"], | ||
| CLUSTER: ["ocp-test", "ocp-test", "ocp-prod"], | ||
| COST: [1.234, 2.345, None], | ||
| } | ||
| ) | ||
|
|
||
|
|
||
| def test_row_filter_empty(): | ||
| assert costs._row_filter() is AlwaysTrue() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "filters", | ||
| [ | ||
| {PID: "vllm-test", CLUSTER: "ocp-test"}, | ||
| {PID: "vllm-test", CLUSTER: "ocp-prod"}, | ||
| ], | ||
| ) | ||
| def test_row_filter_builds_combined_and_expression(filters: dict[str, str]): | ||
| expression = costs._row_filter(**filters) | ||
| assert isinstance(expression, BooleanExpression) | ||
|
|
||
| (left_col, left_val), (right_col, right_val) = filters.items() | ||
| assert expression == AlwaysTrue() & EqualTo(left_col, left_val) & EqualTo( | ||
| right_col, right_val | ||
| ) | ||
|
|
||
|
|
||
| def test_aggregate_by_rounds_and_forwards_filters( | ||
| monkeypatch: pytest.MonkeyPatch, sample_invoice_dataframe: pd.DataFrame | ||
| ): | ||
| mock_loader = mock.MagicMock(return_value=sample_invoice_dataframe) | ||
| monkeypatch.setattr(costs, "get_invoice_dataframe", mock_loader) | ||
|
|
||
| result = costs.aggregate_by( | ||
| (COST,), | ||
| (PID, CLUSTER), | ||
| agg_col=COST, | ||
| agg_name="lifetime_allocation_cost", | ||
| **{PID: "vllm-test"}, | ||
| ) | ||
|
|
||
| args, kwargs = mock_loader.call_args | ||
| assert args == ((COST, PID, CLUSTER),) | ||
| assert kwargs == {PID: "vllm-test"} | ||
|
|
||
| decimal_dtype = pd.ArrowDtype(pyarrow.decimal128(21, 2)) | ||
| expected = pd.DataFrame( | ||
| { | ||
| PID: ["vllm-test", "webrca-1b021a"], | ||
| CLUSTER: ["ocp-test", "ocp-prod"], | ||
| "lifetime_allocation_cost": pd.array( | ||
| [Decimal("3.58"), Decimal("0.00")], dtype=decimal_dtype | ||
| ), | ||
| } | ||
| ) | ||
| assert result.equals(expected) | ||
|
|
||
|
|
||
| def test_group_and_sum_raises_on_missing_column(sample_invoice_dataframe: pd.DataFrame): | ||
| with pytest.raises(ValueError, match="not found in dataframe"): | ||
| costs.group_and_sum( | ||
| sample_invoice_dataframe, | ||
| (PID, CLUSTER), | ||
| agg_col="non_existent_column", | ||
| agg_name="lifetime_allocation_cost", | ||
| ) | ||
|
|
||
|
|
||
| def test_group_and_sum_raises_on_cast_error( | ||
| sample_invoice_dataframe: pd.DataFrame, | ||
| ): | ||
| with pytest.raises(TypeError, match="Unable to cast column"): | ||
| costs.group_and_sum( | ||
| sample_invoice_dataframe, | ||
| (CLUSTER,), | ||
| agg_col=PID, | ||
| agg_name="lifetime_allocation_cost", | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "invalid_filters", | ||
| [ | ||
| {PID: "does-not-exist"}, | ||
| {CLUSTER: "not-a-real-cluster"}, | ||
| {PID: "does-not-exist", CLUSTER: "not-a-real-cluster"}, | ||
| ], | ||
| ) | ||
| def test_calculate_lifetime_costs_invalid_queries_return_empty( | ||
| monkeypatch: pytest.MonkeyPatch, invalid_filters: dict[str, str] | ||
| ): | ||
| empty_df = pd.DataFrame(columns=[PID, CLUSTER, COST]) | ||
| monkeypatch.setattr(costs, "get_invoice_dataframe", lambda cols=None, **f: empty_df) | ||
|
|
||
| result = costs.calculate_lifetime_costs(**invalid_filters) | ||
| assert result.empty | ||
|
|
||
|
|
||
| class _FakeIcebergTable: | ||
| """Responds to .scan().select().to_pandas() chains.""" | ||
|
|
||
| def __init__(self, df: pd.DataFrame): | ||
| self._df = df | ||
|
|
||
| def scan(self, row_filter=None): | ||
| return self | ||
|
|
||
| def select(self, *cols): | ||
| return self | ||
|
|
||
| def to_pandas(self): | ||
| return self._df | ||
|
|
||
|
|
||
| def test_get_invoice_dataframe_warns_when_no_rows_match( | ||
| monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture | ||
| ): | ||
| table = _FakeIcebergTable(pd.DataFrame(columns=[PID, COST])) | ||
| monkeypatch.setattr(costs, "get_table", lambda: table) | ||
|
|
||
| with caplog.at_level("WARNING", logger=costs.__name__): | ||
| costs.get_invoice_dataframe((PID, COST), **{PID: "does-not-exist"}) | ||
|
|
||
| assert "No invoice rows matched filters" in caplog.text | ||
|
|
||
|
|
||
| def test_get_invoice_dataframe_caches_repeated_query(monkeypatch: pytest.MonkeyPatch): | ||
| table = _FakeIcebergTable(pd.DataFrame({PID: ["vllm-test"], COST: [1.0]})) | ||
| mock_get_table = mock.MagicMock(return_value=table) | ||
| monkeypatch.setattr(costs, "get_table", mock_get_table) | ||
|
|
||
| first = costs.get_invoice_dataframe((PID, COST), **{PID: "vllm-test"}) | ||
| second = costs.get_invoice_dataframe((PID, COST), **{PID: "vllm-test"}) | ||
|
|
||
| assert mock_get_table.call_count == 1 | ||
| assert first is second |
Uh oh!
There was an error while loading. Please reload this page.