Use polars to read design matrix #10047

Draft · wants to merge 4 commits into base: main
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -37,7 +37,6 @@ dependencies = [
"matplotlib",
"netCDF4",
"numpy<2",
"openpyxl", # extra dependency for pandas (excel)
"opentelemetry-api<1.30.0",
"opentelemetry-sdk<1.30.0",
"opentelemetry-instrumentation-fastapi<0.51b0",
@@ -49,24 +48,26 @@ dependencies = [
"pluggy>=1.3.0",
"polars>=1",
"psutil",
"pyarrow", # extra dependency for pandas (parquet)
"pyarrow", # extra dependency for pandas (parquet)
"pydantic > 2",
"python-dateutil",
"python-multipart", # extra dependency for fastapi
"python-multipart", # extra dependency for fastapi
"pyyaml",
"pyzmq",
"pyqt6",
"requests",
"resfo",
"scipy >= 1.10.1, < 1.15",
"seaborn",
"tables", # extra dependency for pandas (hdf5)
"tables", # extra dependency for pandas (hdf5)
"tabulate",
"tqdm>=4.62.0",
"typing_extensions>=4.5",
"uvicorn >= 0.17.0",
"xarray",
"xtgeo >= 3.3.0",
"fastexcel>=0.12.1",
"xlsxwriter>=3.2.2",
]

[project.scripts]
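
Note on the new dependencies: polars' read_excel defaults to the calamine engine, which is provided by the fastexcel package, and DataFrame.write_excel is built on xlsxwriter, which presumably motivates the fastexcel>=0.12.1 and xlsxwriter>=3.2.2 pins now that openpyxl (the pandas Excel extra) is dropped. A minimal round-trip sketch under those assumptions (the file and sheet names are illustrative, not from the PR):

import polars as pl

# Read a sheet without treating the first row as a header, as the PR does for raw cells.
df = pl.read_excel(
    "design_matrix.xlsx",        # hypothetical workbook
    sheet_name="DesignSheet01",  # hypothetical sheet name
    has_header=False,
    raise_if_empty=False,
)

# Writing goes through xlsxwriter under the hood.
df.write_excel("design_matrix_copy.xlsx", worksheet="DesignSheet01")
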
205 changes: 113 additions & 92 deletions src/ert/config/design_matrix.py
@@ -1,12 +1,12 @@
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from pandas.api.types import is_integer_dtype
import polars as pl

from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition

@@ -18,8 +18,6 @@

DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"

from ert.shared.status.utils import convert_to_numeric


@dataclass
class DesignMatrix:
@@ -87,27 +85,32 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
ErrorInfo("Design Matrices don't have the same active realizations!")
)

common_keys = set(self.design_matrix_df.columns) & set(
dm_other.design_matrix_df.columns
)
common_keys = set(
self.design_matrix_df.select(pl.exclude("REAL")).columns
) & set(dm_other.design_matrix_df.select(pl.exclude("REAL")).columns)
if common_keys:
errors.append(
ErrorInfo(f"Design Matrices do not have unique keys {common_keys}!")
)

if errors:
raise ConfigValidationError.from_collected(errors)
try:
self.design_matrix_df = pd.concat(
[self.design_matrix_df, dm_other.design_matrix_df], axis=1
self.design_matrix_df = pl.concat(
[
self.design_matrix_df,
dm_other.design_matrix_df.select(pl.exclude("REAL")),
],
how="horizontal",
)
except ValueError as exc:
errors.append(ErrorInfo(f"Error when merging design matrices {exc}!"))
raise ConfigValidationError.from_info(
ErrorInfo(f"Error when merging design matrices {exc}!")
) from exc

for tfd in dm_other.parameter_configuration.transform_function_definitions:
self.parameter_configuration.transform_function_definitions.append(tfd)

if errors:
raise ConfigValidationError.from_collected(errors)

def merge_with_existing_parameters(
self, existing_parameters: list[ParameterConfig]
) -> tuple[list[ParameterConfig], GenKwConfig]:
@@ -159,49 +162,65 @@ def merge_with_existing_parameters(

def read_design_matrix(
self,
) -> tuple[list[bool], pd.DataFrame, GenKwConfig]:
) -> tuple[list[bool], pl.DataFrame, GenKwConfig]:
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
param_names = (
pd.read_excel(
self.xls_filename,
sheet_name=self.design_sheet,
nrows=1,
header=None,
dtype="string",
)
.iloc[0]
.apply(lambda x: x.strip() if isinstance(x, str) else x)
)
design_matrix_df = DesignMatrix._read_excel(
self.xls_filename,
self.design_sheet,
header=None,
skiprows=1,
)
design_matrix_df.columns = param_names.to_list()

if "REAL" in design_matrix_df.columns:
if not is_integer_dtype(design_matrix_df.dtypes["REAL"]) or any(
design_matrix_df["REAL"] < 0
):
raise ValueError("REAL column must only contain positive integers")
design_matrix_df = design_matrix_df.set_index(
"REAL", drop=True, verify_integrity=True
try:
param_names = (
pl.read_excel(
self.xls_filename,
sheet_name=self.design_sheet,
has_header=False,
read_options={"n_rows": 1, "dtypes": "string"},
)
.select(pl.all().str.strip_chars())
.row(0)
)
except pl.exceptions.NoDataError as err:
raise ValueError("Design sheet is empty.") from err

if error_list := DesignMatrix._validate_design_matrix(design_matrix_df):
design_matrix_df = pl.read_excel(
self.xls_filename,
sheet_name=self.design_sheet,
has_header=False,
drop_empty_cols=True,
drop_empty_rows=True,
raise_if_empty=False,
infer_schema_length=None,
read_options={"skip_rows": 1},
).with_columns(pl.col(pl.Float32, pl.Float64).fill_nan(None))

if error_list := DesignMatrix._validate_design_matrix(
design_matrix_df, param_names
):
error_msg = "\n".join(error_list)
raise ValueError(f"Design matrix is not valid, error(s):\n{error_msg}")
design_matrix_df.columns = list(param_names)
if "REAL" in design_matrix_df.schema:
real_dt = design_matrix_df.schema.get("REAL")
assert real_dt is not None
if (
not real_dt.is_integer()
or design_matrix_df.get_column("REAL").lt(0).any()
or design_matrix_df.get_column("REAL").is_duplicated().any()
):
raise ValueError(
"REAL column must only contain unique positive integers"
)

else:
design_matrix_df = design_matrix_df.with_row_index(name="REAL")

defaults_to_use = DesignMatrix._read_defaultssheet(
self.xls_filename, self.default_sheet, design_matrix_df.columns.to_list()
self.xls_filename, self.default_sheet, design_matrix_df.columns
)
design_matrix_df = design_matrix_df.with_columns(
pl.lit(value).alias(name) for name, value in defaults_to_use.items()
)
design_matrix_df = design_matrix_df.assign(**defaults_to_use)

transform_function_definitions: list[TransformFunctionDefinition] = []
for parameter in design_matrix_df.columns:
for parameter in design_matrix_df.select(pl.exclude("REAL")).columns:
transform_function_definitions.append(
TransformFunctionDefinition(
name=parameter,
@@ -218,60 +237,45 @@ def read_design_matrix(
update=False,
)

design_matrix_df.columns = pd.MultiIndex.from_product(
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
)
reals = design_matrix_df.index.tolist()
reals = design_matrix_df.get_column("REAL").to_list()
return (
[x in reals for x in range(max(reals) + 1)],
design_matrix_df,
parameter_configuration,
)

@staticmethod
def _read_excel(
file_name: Path | str,
sheet_name: str,
usecols: list[int] | None = None,
header: int | None = 0,
skiprows: int | None = None,
dtype: str | None = None,
) -> pd.DataFrame:
"""
Reads an Excel file into a DataFrame, with options to filter columns and rows,
and automatically drops columns that contain only NaN values.
"""
df = pd.read_excel(
io=file_name,
sheet_name=sheet_name,
usecols=usecols,
header=header,
skiprows=skiprows,
dtype=dtype,
)
return df.dropna(axis=1, how="all")

@staticmethod
def _validate_design_matrix(design_matrix: pd.DataFrame) -> list[str]:
def _validate_design_matrix(
design_matrix: pl.DataFrame, column_names: tuple[str, ...]
) -> list[str]:
"""
Validate user inputted design matrix
:raises: ValueError if design matrix contains empty headers or empty cells
"""
if design_matrix.empty:
if design_matrix.is_empty():
return []
errors = []
column_na_mask = design_matrix.columns.isna()
if not design_matrix.columns[~column_na_mask].is_unique:
errors.append("Duplicate parameter names found in design sheet")
param_name_count = Counter(p for p in column_names if p is not None)
duplicate_param_names = [(n, c) for n, c in param_name_count.items() if c > 1]
if duplicate_param_names:
duplicates_formatted = ", ".join(
f"{name}({count})" for name, count in duplicate_param_names
)
errors.append(
f"Duplicate parameter names found in design sheet: {duplicates_formatted}"
)
empties = [
f"Realization {design_matrix.index[i]}, column {design_matrix.columns[j]}"
for i, j in zip(*np.where(pd.isna(design_matrix)), strict=False)
f"Row {i}, column {j}"
for i, j in zip(
*np.where(design_matrix.select(pl.all().is_null())),
strict=False,
)
]
if len(empties) > 0:
errors.append(f"Design matrix contains empty cells {empties}")

for column_num, param_name in enumerate(design_matrix.columns):
if pd.isna(param_name) or len(param_name.split()) == 0:
for column_num, param_name in enumerate(column_names):
if param_name is None or len(param_name.split()) == 0:
errors.append(f"Empty parameter name found in column {column_num}.")
elif len(param_name.split()) > 1:
errors.append(
@@ -283,10 +287,10 @@ def _validate_design_matrix(design_matrix: pd.DataFrame) -> list[str]:

@staticmethod
def _read_defaultssheet(
xls_filename: Path | str,
xls_filename: Path,
defaults_sheetname: str,
existing_parameters: list[str],
) -> dict[str, str | float]:
) -> dict[str, str | float | int]:
"""
Construct a dict of keys and values to be used as defaults from the
first two columns in a spreadsheet. Only returns the keys that are
@@ -296,28 +300,45 @@ def _read_defaultssheet(

:raises: ValueError if defaults sheet is non-empty but non-parsable
"""
default_df = DesignMatrix._read_excel(

default_df = pl.read_excel(
xls_filename,
defaults_sheetname,
header=None,
dtype="string",
sheet_name=defaults_sheetname,
has_header=False,
drop_empty_cols=True,
drop_empty_rows=True,
raise_if_empty=False,
read_options={"dtypes": "string"},
)
if default_df.empty:
if default_df.is_empty():
return {}
if len(default_df.columns) < 2:
raise ValueError("Defaults sheet must have at least two columns")
empty_cells = [
f"Row {default_df.index[i]}, column {default_df.columns[j]}"
for i, j in zip(*np.where(pd.isna(default_df)), strict=False)
f"Row {i}, column {j}"
for i, j in zip(
*np.where(default_df.select(pl.all().is_null())), strict=False
)
]
if len(empty_cells) > 0:
raise ValueError(f"Default sheet contains empty cells {empty_cells}")
default_df[0] = default_df[0].apply(lambda x: x.strip())
if not default_df[0].is_unique:
default_df = default_df.with_columns(pl.nth(0).str.strip_chars())
if default_df.select(pl.nth(0)).is_duplicated().any():
raise ValueError("Default sheet contains duplicate parameter names")

return {
row[0]: convert_to_numeric(row[1])
for _, row in default_df.iterrows()
for row in default_df.iter_rows()
if row[0] not in existing_parameters
}


def convert_to_numeric(x: str) -> str | float | int:
try:
return int(x)
except ValueError:
try:
return float(x)

except ValueError:
return x
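
For context, the new REAL handling in read_design_matrix either trusts an existing REAL column (unique, non-negative integers) or numbers the rows from zero. A minimal sketch of that logic with polars >= 1 (the frame contents are illustrative, not from the PR):

import polars as pl

df = pl.DataFrame({"PARAM_A": [0.1, 0.2, 0.3]})

if "REAL" in df.schema:
    # An existing REAL column must hold unique, non-negative integers.
    real = df.get_column("REAL")
    if not real.dtype.is_integer() or real.lt(0).any() or real.is_duplicated().any():
        raise ValueError("REAL column must only contain unique positive integers")
else:
    # Without a REAL column, realizations are simply numbered from 0.
    df = df.with_row_index(name="REAL")
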
22 changes: 13 additions & 9 deletions src/ert/enkf_main.py
@@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Any

import orjson
import pandas as pd
import polars as pl
import xarray as xr
from numpy.random import SeedSequence

@@ -157,24 +157,28 @@ def _seed_sequence(seed: int | None) -> int:


def save_design_matrix_to_ensemble(
design_matrix_df: pd.DataFrame,
design_matrix_df: pl.DataFrame,
ensemble: Ensemble,
active_realizations: Iterable[int],
design_group_name: str = DESIGN_MATRIX_GROUP,
) -> None:
assert not design_matrix_df.empty
for realization_nr in active_realizations:
row = design_matrix_df.loc[realization_nr][DESIGN_MATRIX_GROUP]
assert not design_matrix_df.is_empty()
param_names = design_matrix_df.select(pl.exclude("REAL")).columns
if not set(design_matrix_df.get_column("REAL")).issubset(set(active_realizations)):
raise KeyError("Active realizations not found in design matrix data frame.")
for row in design_matrix_df.filter(
pl.col("REAL").is_in(pl.Series(active_realizations))
).to_numpy():
ds = xr.Dataset(
{
"values": ("names", list(row.values)),
"transformed_values": ("names", list(row.values)),
"names": list(row.keys()),
"values": ("names", list(row[1:])),
"transformed_values": ("names", list(row[1:])),
"names": param_names,
}
)
ensemble.save_parameters(
design_group_name,
realization_nr,
row[0],
ds,
)

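
The rewritten save_design_matrix_to_ensemble relies on a filter-then-to_numpy pattern: after filtering on REAL, each NumPy row starts with the realization number and the remaining entries line up with the non-REAL column names. A small sketch of that pattern (the frame below is made up; note that to_numpy upcasts mixed integer/float columns to float64):

import polars as pl

df = pl.DataFrame({"REAL": [0, 1, 3], "A": [0.1, 0.2, 0.3], "B": [10, 20, 30]})
param_names = df.select(pl.exclude("REAL")).columns  # ["A", "B"]
active_realizations = [0, 3]

for row in df.filter(pl.col("REAL").is_in(pl.Series(active_realizations))).to_numpy():
    realization_nr = int(row[0])  # first column is REAL
    values = list(row[1:])        # aligned with param_names
    print(realization_nr, dict(zip(param_names, values)))
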