Add new forward model configuration style #10597
base: main
Conversation
CodSpeed Performance Report: Merging #10597 will not alter performance.
Force-pushed d436308 to c08e84f
Force-pushed f4edd63 to 7ccd960
Force-pushed 89fd199 to 2e08023
Force-pushed 2e08023 to 3d84617
Force-pushed 3943605 to bfae075
Looking good! Some comments which might simplify it a bit
src/everest/config/everest_config.py (Outdated)
```python
def consolidate_forward_model_formats(
    cls, values: dict[str, Any]
) -> dict[str, Any]:
    def format_fm(fm: str | dict[str, Any]) -> dict[str, Any]:
        if isinstance(fm, dict):
            return fm

        return {"job": fm, "results": None}

    if "forward_model" in values:
        values["forward_model"] = [format_fm(fm) for fm in values["forward_model"]]

    return values
```
Can be done with a field validator on `forward_model`:
```python
from pydantic import BaseModel, field_validator


class B(BaseModel):
    a: str
    b: int


class A(BaseModel):
    a: B

    @field_validator("a", mode="before")
    @classmethod
    def transform(cls, raw: str | dict[str, int]):
        if isinstance(raw, str):
            a, b = raw.split(",")
            return {"a": a, "b": b}
        return raw
```
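Applied to this PR, the same pattern might look roughly like the sketch below; `ConfigSketch` and the loose `dict` typing are illustrative stand-ins, not the actual everest schema:

```python
from typing import Any

from pydantic import BaseModel, field_validator


class ConfigSketch(BaseModel):
    # Stand-in for EverestConfig; the real model validates each entry
    # into a ForwardModelStepConfig rather than a plain dict.
    forward_model: list[dict[str, Any]] | None = None

    @field_validator("forward_model", mode="before")
    @classmethod
    def normalize_steps(cls, raw: Any) -> Any:
        # Normalize plain job-name strings into the dict form before
        # per-item validation runs.
        if isinstance(raw, list):
            return [
                {"job": fm, "results": None} if isinstance(fm, str) else fm
                for fm in raw
            ]
        return raw
```

With this, `ConfigSketch(forward_model=["eclipse100", {"job": "flow", "results": None}])` normalizes both entries to the dict form, so downstream code never sees the bare-string style.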
src/everest/config/everest_config.py (Outdated)
```python
            if job_name not in installed_jobs_name:
                errors.append(f"unknown job {job_name}")

        if len(errors) > 0:  # Note: python3.11 ExceptionGroup will solve this nicely
            raise ValueError(errors)
        return self

    def get_forward_model_steps(
        self, result_type: Literal["gen_data", "summary"]
    ) -> list[ForwardModelStepConfigWithResults]:
```
This will also be simplified if you add a field validator on `forward_model`, I think.
Removed this and moved the filtering to the two call sites; the behavior relative to the field validator approach remains the same, though.
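(For illustration, the call-site filtering could look something like this; the `config` name and exact attribute shapes are assumed:)

```python
# Assumed shape of one of the two call sites: pick out only the steps
# whose results match the wanted type.
gen_data_steps = [
    fm
    for fm in config.forward_model or []
    if fm.results is not None and fm.results.type == "gen_data"
]
```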
```python
@model_validator(mode="after")
def validate_at_most_one_summary_forward_model(self, _: ValidationInfo) -> Self:
    if self.forward_model is None:
        return self

    summary_fms = [
        fm
        for fm in self.forward_model
        if isinstance(fm, ForwardModelStepConfig)
        and fm.results is not None
        and fm.results.type == "summary"
    ]
    if len(summary_fms) > 1:
        raise ValueError(
            f"Found ({len(summary_fms)}) "
            f"forward model steps producing summary data. "
            f"Only one summary-producing forward model step is supported."
        )

    return self
```
Think this might be cleaner if it is located with `ForwardModelStepConfig`?
```python
class ForwardModelStepConfig(BaseModelWithContextSupport):
    job: str = Field(
        description="Name of the forward model step",
    )
    results: ForwardModelStepResultsConfig | None = Field(
        default=None, description="Result file produced by the forward model"
    )
    nr_summary: ClassVar[int] = 0

    @field_validator("results")
    @classmethod
    def check_nr_summary(cls, v: Any) -> Any:
        if v is not None and v.type == "summary":
            if cls.nr_summary > 0:
                raise ValueError("Can only have one summary result")
            cls.nr_summary += 1
        return v
```
Hm, I can see it would work to have a classvar counting the number of summary initializations to mimic a validator on the list, but I'm not sure it is cleaner or easier to work with. If we directly or indirectly initialize multiple configs, the counter will keep going up, so it would have to be forced to be scoped within each EverestConfig instance.
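A self-contained toy reproducing that concern (`Step` here is a stand-in for `ForwardModelStepConfig`, not the real class):

```python
from typing import Any, ClassVar

from pydantic import BaseModel, field_validator


class Step(BaseModel):
    result_type: str | None = None
    nr_summary: ClassVar[int] = 0  # shared by every instance of Step

    @field_validator("result_type")
    @classmethod
    def check_nr_summary(cls, v: Any) -> Any:
        if v == "summary":
            if cls.nr_summary > 0:
                raise ValueError("Can only have one summary result")
            cls.nr_summary += 1
        return v


Step(result_type="summary")  # ok: the class-level counter goes from 0 to 1
Step(result_type="summary")  # raises, even if it belongs to an unrelated config
```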
src/everest/config/everest_config.py (Outdated)
```python
@model_validator(mode="before")
@classmethod
def validate_no_data_file(cls, values: dict[str, Any]) -> dict[str, Any]:
    if "model" in values and "data_file" in values["model"]:
```
Should probably be on `ModelConfig`:

```python
@model_validator(mode="before")
@classmethod
def deprecate_key(cls, values: Any) -> Any:
    if "data_file" in values:
        raise ...
```
Though we might want cross-validation between `eclbase` and `model.data_file` 🤔
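One possible shape for that cross-validation, as a sketch; the field names mirror the discussion, but the direction of the rule is an assumption:

```python
from typing import Self

from pydantic import BaseModel, model_validator


class ModelSketch(BaseModel):
    data_file: str | None = None


class ConfigSketch(BaseModel):
    # Toy stand-in; the real EverestConfig has many more fields.
    eclbase: str | None = None
    model: ModelSketch | None = None

    @model_validator(mode="after")
    def check_data_file_needs_eclbase(self) -> Self:
        # Illustrative rule only: reject model.data_file unless eclbase is set.
        if (
            self.model is not None
            and self.model.data_file is not None
            and self.eclbase is None
        ):
            raise ValueError("model.data_file requires eclbase to be set")
        return self
```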
I think now that it only looks at data_file it should be moved, but I also think some cross-validation makes sense, so I added that here.
Force-pushed 38fe743 to cee48fb
Force-pushed cee48fb to 50275f0
Issue
Resolves #9615