Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import yaml
import copy
from typing import Any, Dict, Optional
import logging

logger = logging.getLogger(__name__)


def load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML file and return its contents as a dictionary.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed contents as a dict. An empty (or comments-only) YAML
        file parses to ``None``; that is normalized to ``{}`` so the
        declared return type always holds for callers.
    """
    # Explicit encoding so behavior does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return data if data is not None else {}


def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new dict combining `base` and `overrides`, recursing into
    nested dicts. Where both sides hold a dict under the same key the two
    are merged; otherwise the override value wins. Neither input is
    mutated, and no values are shared with the inputs (deep copies).

    Example:
        base = {"a": 1, "b": {"c": 2, "d": 3}}
        overrides = {"b": {"c": 99}, "e": 5}
        result = {"a": 1, "b": {"c": 99, "d": 3}, "e": 5}
    """
    merged = copy.deepcopy(base)
    for key, override_value in overrides.items():
        current = merged.get(key)
        # Recurse only when both sides are dicts; anything else is replaced.
        if isinstance(current, dict) and isinstance(override_value, dict):
            merged[key] = deep_merge(current, override_value)
        else:
            merged[key] = copy.deepcopy(override_value)
    return merged


def resolve_layer_config(base_config_path: str, overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Load a layer's base YAML config file and apply any experiment-level overrides.

    Args:
        base_config_path: Path to the layer's own config YAML.
        overrides: Dictionary of keys to override from the experiment config.

    Returns:
        The merged configuration dictionary.
    """
    base_config = load_yaml(base_config_path)
    if overrides:
        # Guard against an empty YAML file parsing to None so deep_merge
        # always receives a dict instead of raising TypeError.
        merged = deep_merge(base_config or {}, overrides)
        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info("Applied %d override(s) to %s", len(overrides), base_config_path)
        return merged
    return base_config
25 changes: 17 additions & 8 deletions data_layer/data_module.py → data_layer/data_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
import yaml
from typing import Dict, Tuple, List, Any
from typing import Dict, Optional, Tuple, List, Any

class DataModule:
class DataObject:
"""
A unified data ingestion and preprocessing pipeline for algorithmic recourse tasks.

Expand All @@ -16,8 +16,8 @@ class DataModule:

NOTE: this module will essentially take the place of the existing data module and dataset classes,
and all the functionality in the loadData method will be transferred here as member functions.
The "get_preprocessing()" acts like a controller that, based on confgs, will call appropriate util
funtions. (think large if else block).
The "get_preprocessing()" acts like a controller that, based on configs, will call appropriate util
functions. (think large if-else block).

The attributes and util member methods can be expanded on an as-needed basis.

Expand All @@ -32,19 +32,28 @@ class DataModule:
metadata (Dict[str, Any]): Generated bounds, constraints, and structural info for features.
"""

def __init__(self, data_path: str, config_path: str):
def __init__(self, data_path: str, config_path: str = None, config_override: Optional[Dict[str, Any]] = None):
"""
Initializes the DataModule by loading the raw data and configuration.
Initializes the DataObject by loading the raw data and configuration.

Args:
data_path (str): The file path to the raw CSV dataset.
config_path (str): The file path to the YAML configuration file.
config_override (Optional[Dict[str, Any]]): Optional dictionary of config overrides.
"""
self._metadata = {}
self._raw_df = pd.read_csv(data_path)
self._processed_df = self._raw_df.copy() # This will be transformed in place through the preprocessing pipeline.
with open(config_path, 'r') as file:
self._config = yaml.safe_load(file)

if config_path is not None:
with open(config_path, 'r') as file:
self._config = yaml.safe_load(file)
else:
self._config = {}

# If a pre-merged config is given, use it entirely (it already contains overrides)
if config_override is not None:
self._config = config_override

# drop columns not defined in the config
columns_to_drop = [col for col in self._raw_df.columns if col not in self._config['features'].keys()]
Expand Down
11 changes: 6 additions & 5 deletions evaluation_layer/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import numpy as np
import pandas as pd

from evaluation_layer.evaluation_module import EvaluationModule
from evaluation_layer.evaluation_factory import register_evaluation
from evaluation_layer.evaluation_object import EvaluationObject
from evaluation_layer.utils import remove_nans
from data_layer.data_module import DataModule
from data_layer.data_object import DataObject


def l0_distance(delta: np.ndarray) -> List[float]:
Expand Down Expand Up @@ -146,13 +147,13 @@ def _get_distances(

return [[d1[i], d2[i], d3[i], d4[i]] for i in range(len(d1))]


class Distance(EvaluationModule):
@register_evaluation("Distance")
class Distance(EvaluationObject):
"""
Calculates the L0, L1, L2, and L-infty distance measures.
"""

def __init__(self, data: DataModule, hyperparameters: dict = None):
def __init__(self, data: DataObject, hyperparameters: dict = None):
super().__init__(data, hyperparameters)
self.columns = ["L0_distance", "L1_distance", "L2_distance", "Linf_distance"]

Expand Down
37 changes: 37 additions & 0 deletions evaluation_layer/evaluation_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from data_layer.data_object import DataObject
from evaluation_layer.evaluation_object import EvaluationObject
from typing import Dict, Any, List, Optional

_EVAL_REGISTRY = {}


def register_evaluation(name: str):
"""Decorator to register an evaluation metric class by name."""
def decorator(cls):
_EVAL_REGISTRY[name] = cls
return cls
return decorator


def create_evaluations(metrics_config: List[Dict[str, Any]],
                       data: DataObject) -> List[EvaluationObject]:
    """
    Instantiate all requested evaluation modules from the experiment config.

    Args:
        metrics_config: List of dicts, each with "name" and optional "hyperparameters".
        data: The DataObject instance.

    Returns:
        List of EvaluationObject instances.

    Raises:
        ValueError: If a requested metric name has not been registered.
    """
    instances = []
    for spec in metrics_config:
        metric_name = spec["name"]
        eval_cls = _EVAL_REGISTRY.get(metric_name)
        if eval_cls is None:
            raise ValueError(
                f"Evaluation '{metric_name}' is not registered. Available: {list(_EVAL_REGISTRY.keys())}"
            )
        instances.append(eval_cls(data, spec.get("hyperparameters")))
    return instances
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from abc import ABC, abstractmethod
import pandas as pd
from data_layer.data_module import DataModule
from data_layer.data_object import DataObject


class EvaluationModule(ABC):
def __init__(self, data: DataModule, hyperparameters: dict = None):
class EvaluationObject(ABC):
def __init__(self, data: DataObject, hyperparameters: dict = None):
"""

Parameters
----------
model:
Classification model. (optional)
data: DataObject
The data object containing the processed data and metadata.
hyperparameters:
Dictionary with hyperparameters, could be used to pass other things. (optional)
"""
Expand Down
10 changes: 5 additions & 5 deletions evaluation_layer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pandas as pd
import numpy as np
from data_layer.data_module import DataModule
from model_layer.model_module import ModelModule
from data_layer.data_object import DataObject
from model_layer.model_object import ModelObject
import logging


def check_counterfactuals(model: ModelModule,
data: DataModule,
def check_counterfactuals(model: ModelObject,
data: DataObject,
counterfactuals: pd.DataFrame,
factual_indices: pd.Index) -> pd.DataFrame:
"""
Expand All @@ -19,7 +19,7 @@ def check_counterfactuals(model: ModelModule,

Parameters
----------
model: ModelModule
model: ModelObject
The model module containing the trained model and its configuration.
counterfactuals: pd.DataFrame
The generated counterfactuals to be checked.
Expand Down
Loading