Skip to content

Do not sample dm #10393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pydantic import PositiveFloat, ValidationError

from .analysis_module import ESSettings
from .design_matrix import DesignMatrix
from .parsing import (
ConfigDict,
ConfigKeys,
Expand All @@ -37,7 +36,6 @@ class AnalysisConfig:
es_module: ESSettings = field(default_factory=ESSettings)
observation_settings: UpdateSettings = field(default_factory=UpdateSettings)
num_iterations: int = 1
design_matrix: DesignMatrix | None = None

@no_type_check
@classmethod
Expand Down Expand Up @@ -77,8 +75,6 @@ def from_dict(cls, config_dict: ConfigDict) -> AnalysisConfig:

min_realization = min(min_realization, num_realization)

design_matrix_config_lists = config_dict.get(ConfigKeys.DESIGN_MATRIX, [])

options: dict[str, dict[str, Any]] = {"STD_ENKF": {}}
observation_settings: dict[str, Any] = {
"alpha": config_dict.get(ConfigKeys.ENKF_ALPHA, 3.0),
Expand Down Expand Up @@ -176,21 +172,11 @@ def from_dict(cls, config_dict: ConfigDict) -> AnalysisConfig:
if all_errors:
raise ConfigValidationError.from_collected(all_errors)

design_matrices = [
DesignMatrix.from_config_list(design_matrix_config_list)
for design_matrix_config_list in design_matrix_config_lists
]
design_matrix: DesignMatrix | None = None
if design_matrices:
design_matrix = design_matrices[0]
for dm_other in design_matrices[1:]:
design_matrix.merge_with_other(dm_other)
config = cls(
minimum_required_realizations=min_realization,
update_log_path=config_dict.get(ConfigKeys.UPDATE_LOG_PATH, "update_log"),
observation_settings=obs_settings,
es_module=es_settings,
design_matrix=design_matrix,
)
return config

Expand Down
6 changes: 4 additions & 2 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DesignMatrix:
xls_filename: Path
design_sheet: str
default_sheet: str
group_name: str = DESIGN_MATRIX_GROUP

def __post_init__(self) -> None:
try:
Expand Down Expand Up @@ -110,7 +111,7 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:

def merge_with_existing_parameters(
self, existing_parameters: list[ParameterConfig]
) -> tuple[list[ParameterConfig], GenKwConfig]:
) -> list[ParameterConfig]:
"""
This method merges the design matrix parameters with the existing parameters and
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
Expand Down Expand Up @@ -146,6 +147,7 @@ def merge_with_existing_parameters(
)

design_parameter_group.name = parameter_group.name
self.group_name = parameter_group.name
design_parameter_group.template_file = parameter_group.template_file
design_parameter_group.output_file = parameter_group.output_file
design_group_added = True
Expand All @@ -157,7 +159,7 @@ def merge_with_existing_parameters(
)
else:
new_param_config += [parameter_group]
return new_param_config, design_parameter_group
return [*new_param_config, design_parameter_group]

def read_and_validate_design_matrix(
self,
Expand Down
18 changes: 18 additions & 0 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ert.field_utils import get_shape

from .design_matrix import DesignMatrix
from .ext_param_config import ExtParamConfig
from .field import Field as FieldConfig
from .gen_data_config import GenDataConfig
Expand Down Expand Up @@ -50,6 +51,7 @@ class EnsembleConfig:
str, GenKwConfig | FieldConfig | SurfaceConfig | ExtParamConfig
] = field(default_factory=dict)
refcase: Refcase | None = None
design_matrix: DesignMatrix | None = None

def __post_init__(self) -> None:
self._check_for_duplicate_names(
Expand Down Expand Up @@ -110,6 +112,7 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
gen_kw_list = config_dict.get(ConfigKeys.GEN_KW, [])
surface_list = config_dict.get(ConfigKeys.SURFACE, [])
field_list = config_dict.get(ConfigKeys.FIELD, [])
design_matrix_lists = config_dict.get(ConfigKeys.DESIGN_MATRIX, [])
dims = None
if grid_file_path is not None:
try:
Expand Down Expand Up @@ -139,6 +142,20 @@ def make_field(field_list: list[str]) -> FieldConfig:
+ [make_field(f) for f in field_list]
)

design_matrices = [
DesignMatrix.from_config_list(design_matrix_list)
for design_matrix_list in design_matrix_lists
]
design_matrix: DesignMatrix | None = None
if design_matrices:
design_matrix = design_matrices[0]
for dm_other in design_matrices[1:]:
design_matrix.merge_with_other(dm_other)
if design_matrix is not None:
parameter_configs = design_matrix.merge_with_existing_parameters(
parameter_configs
)

response_configs: list[ResponseConfig] = []

for config_cls in _KNOWN_RESPONSE_TYPES:
Expand All @@ -156,6 +173,7 @@ def make_field(field_list: list[str]) -> FieldConfig:
parameter.name: parameter for parameter in parameter_configs
},
refcase=refcase,
design_matrix=design_matrix,
)

def __getitem__(self, key: str) -> ParameterConfig | ResponseConfig:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def from_dict(cls, config_dict) -> Self:
if errors:
raise ConfigValidationError.from_collected(errors)

if dm := analysis_config.design_matrix:
if dm := ensemble_config.design_matrix:
dm_params = [
x.name
for x in dm.parameter_configuration.transform_function_definitions
Expand Down
5 changes: 2 additions & 3 deletions src/ert/gui/simulation/ensemble_experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from PyQt6.QtCore import pyqtSlot as Slot
from PyQt6.QtWidgets import QFormLayout, QLabel, QWidget

from ert.config import AnalysisConfig
from ert.config import DesignMatrix
from ert.gui.ertnotifier import ErtNotifier
from ert.gui.ertwidgets import (
ActiveRealizationsModel,
Expand All @@ -31,10 +31,10 @@ class Arguments:
class EnsembleExperimentPanel(ExperimentConfigPanel):
def __init__(
self,
analysis_config: AnalysisConfig,
ensemble_size: int,
run_path: str,
notifier: ErtNotifier,
design_matrix: DesignMatrix | None,
):
super().__init__(EnsembleExperiment)
self.notifier = notifier
Expand Down Expand Up @@ -82,7 +82,6 @@ def __init__(
)
layout.addRow("Active realizations", self._active_realizations_field)

design_matrix = analysis_config.design_matrix
if design_matrix is not None:
layout.addRow(
"Design Matrix",
Expand Down
4 changes: 2 additions & 2 deletions src/ert/gui/simulation/ensemble_smoother_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .experiment_config_panel import ExperimentConfigPanel

if TYPE_CHECKING:
from ert.config import AnalysisConfig
from ert.config import AnalysisConfig, DesignMatrix


@dataclass
Expand All @@ -45,6 +45,7 @@ def __init__(
run_path: str,
notifier: ErtNotifier,
ensemble_size: int,
design_matrix: DesignMatrix | None,
) -> None:
super().__init__(EnsembleSmoother)
self.notifier = notifier
Expand Down Expand Up @@ -94,7 +95,6 @@ def __init__(
self._active_realizations_field.setValidator(RangeStringArgument(ensemble_size))
layout.addRow("Active realizations", self._active_realizations_field)

design_matrix = analysis_config.design_matrix
if design_matrix is not None:
layout.addRow(
"Design Matrix",
Expand Down
9 changes: 6 additions & 3 deletions src/ert/gui/simulation/experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def __init__(
True,
)
analysis_config = config.analysis_config
design_matrix = config.ensemble_config.design_matrix
self.addExperimentConfigPanel(
EnsembleExperimentPanel(analysis_config, ensemble_size, run_path, notifier),
EnsembleExperimentPanel(ensemble_size, run_path, notifier, design_matrix),
True,
)
self.addExperimentConfigPanel(
Expand All @@ -148,12 +149,14 @@ def __init__(

self.addExperimentConfigPanel(
MultipleDataAssimilationPanel(
analysis_config, run_path, notifier, ensemble_size
analysis_config, run_path, notifier, ensemble_size, design_matrix
),
experiment_type_valid,
)
self.addExperimentConfigPanel(
EnsembleSmootherPanel(analysis_config, run_path, notifier, ensemble_size),
EnsembleSmootherPanel(
analysis_config, run_path, notifier, ensemble_size, design_matrix
),
experiment_type_valid,
)
self.addExperimentConfigPanel(
Expand Down
4 changes: 2 additions & 2 deletions src/ert/gui/simulation/multiple_data_assimilation_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .experiment_config_panel import ExperimentConfigPanel

if TYPE_CHECKING:
from ert.config import AnalysisConfig
from ert.config import AnalysisConfig, DesignMatrix
from ert.gui.ertwidgets import ValueModel


Expand All @@ -53,6 +53,7 @@ def __init__(
run_path: str,
notifier: ErtNotifier,
ensemble_size: int,
design_matrix: DesignMatrix | None,
) -> None:
super().__init__(MultipleDataAssimilation)
self.notifier = notifier
Expand Down Expand Up @@ -136,7 +137,6 @@ def __init__(
self.simulationConfigurationChanged
)

design_matrix = analysis_config.design_matrix
if design_matrix is not None:
layout.addRow(
"Design Matrix",
Expand Down
37 changes: 17 additions & 20 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

import numpy as np

from ert.config import ConfigValidationError, HookRuntime
from ert.config import HookRuntime
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage
from ert.trace import tracer

from ..run_arg import create_run_arguments
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
from ert.config import ErtConfig, QueueConfig
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
self.experiment: Experiment | None = None
self.ensemble: Ensemble | None = None

self._design_matrix = config.analysis_config.design_matrix
self._design_matrix = config.ensemble_config.design_matrix
self._observations = config.observations
self._parameter_configuration = config.ensemble_config.parameter_configuration
self._response_configuration = config.ensemble_config.response_configuration
Expand Down Expand Up @@ -80,18 +80,18 @@ def run_experiment(
) -> None:
self.log_at_startup()
self.restart = restart
# If design matrix is present, we try to merge design matrix parameters
# to the experiment parameters and set new active realizations

parameters_config = self._parameter_configuration
design_matrix = self._design_matrix
design_matrix_group = None
if design_matrix is not None:
try:
parameters_config, design_matrix_group = (
design_matrix.merge_with_existing_parameters(parameters_config)
)
except ConfigValidationError as exc:
raise ErtRunError(str(exc)) from exc
params_to_sample = (
[
param.name
for param in parameters_config
if param.name != design_matrix.group_name
]
if design_matrix is not None
else None
)

if not restart:
self.run_workflows(
Expand All @@ -100,11 +100,7 @@ def run_experiment(
)
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
parameters=(
[*parameters_config, design_matrix_group]
if design_matrix_group is not None
else parameters_config
),
parameters=parameters_config,
observations=self._observations,
responses=self._response_configuration,
)
Expand All @@ -131,15 +127,16 @@ def run_experiment(
sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
parameters=params_to_sample,
random_seed=self.random_seed,
)

if design_matrix_group is not None and design_matrix is not None:
if design_matrix is not None:
save_design_matrix_to_ensemble(
design_matrix.design_matrix_df,
self.ensemble,
np.where(self.active_realizations)[0],
design_matrix_group.name,
design_matrix.group_name,
)

self._evaluate_and_postprocess(
Expand Down
Loading