Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/rra_population_model/cli_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ def with_version[**P, T]() -> Callable[[Callable[P, T]], Callable[P, T]]:
)


def with_copy_from_version[**P, T]() -> Callable[[Callable[P, T]], Callable[P, T]]:
return click.option(
"--copy-from-version",
type=click.STRING,
help="Version of the model to copy predictions from. Used if we're "
"raking a set of predictions to multiple raking targets.",
)


def with_block_key[**P, T]() -> Callable[[Callable[P, T]], Callable[P, T]]:
return click.option(
"--block-key",
Expand Down
14 changes: 12 additions & 2 deletions src/rra_population_model/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import warnings
from enum import StrEnum
from pathlib import Path
Expand Down Expand Up @@ -26,7 +27,7 @@ def to_list(cls) -> list[str]:

class BuiltVersion(BaseModel):
provider: Literal["ghsl", "microsoft"]
version: Literal["v4", "r2023a"]
version: Literal["v4", "v6", "r2023a"]
time_points: list[str]
measures: list[str]
denominators: list[str]
Expand All @@ -47,7 +48,7 @@ def time_points_float(self) -> list[float]:
@model_validator(mode="after")
def validate_version(self) -> Self:
version_map = {
"microsoft": ["v4"],
"microsoft": ["v4", "v6"],
"ghsl": ["r2023a"],
}
allowed_version = version_map[self.provider]
Expand Down Expand Up @@ -103,6 +104,15 @@ def validate_measures(self) -> Self:
measures=["density"],
denominators=["density"],
),
"microsoft_v6": BuiltVersion(
provider="microsoft",
version="v6",
time_points=[
f"{y}q{q}" for y, q in itertools.product(range(2020, 2024), range(1, 5))
][1:],
measures=["density"],
denominators=["density"],
),
}

DENOMINATORS = []
Expand Down
96 changes: 90 additions & 6 deletions src/rra_population_model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,9 @@ def model_root(self, resolution: str) -> Path:
def model_version_root(self, resolution: str, version: str) -> Path:
return self.model_root(resolution) / version

def make_model_version_root(self, resolution: str, version: str) -> Path:
def make_model_version_root(
self, resolution: str, version: str, *, exist_ok: bool = False
) -> Path:
dirs = [
self.model_version_root(resolution, version),
self.raw_predictions_root(resolution, version),
Expand All @@ -591,16 +593,95 @@ def make_model_version_root(self, resolution: str, version: str) -> Path:
self.compiled_predictions_root(resolution, version),
]
for path in dirs:
mkdir(path)
mkdir(path, exist_ok=exist_ok)

return self.model_version_root(resolution, version)

def maybe_copy_version(
self,
resolution: str,
version: str,
copy_from_version: str | None,
) -> None:
if copy_from_version is None:
# Nothing to do
return

version_root = self.model_version_root(resolution, version)
version_spec_path = self.model_specification_path(resolution, version)
version_ckpt_path = self.model_checkpoint_path(resolution, version)
version_preds_root = self.raw_predictions_root(resolution, version)

copy_from_root = self.model_version_root(resolution, copy_from_version)
copy_spec_path = self.model_specification_path(resolution, copy_from_version)
copy_ckpt_path = self.model_checkpoint_path(resolution, copy_from_version)
copy_preds_root = self.raw_predictions_root(resolution, copy_from_version)

#################
# Preconditions #
#################
# Basic checks
if copy_from_version >= version:
msg = "Cannot copy from a version that is the same or newer."
raise ValueError(msg)

return path
if not copy_from_root.exists():
msg = f"Cannot copy from non-existent version {copy_from_version}"
raise ValueError(msg)

if copy_preds_root.is_symlink():
msg = f"Cannot copy from symlinked raw predictions root {copy_preds_root}"
raise ValueError(msg)

if version_root.exists() and not version_spec_path.exists():
msg = f"Version {version} exists but has no specification file. This is an invalid directory state."
raise ValueError(msg)

# Have we already copied this version?
if version_spec_path.exists():
model_matches = version_ckpt_path.exists() and version_ckpt_path.samefile(
copy_ckpt_path
)
predictions_match = (
version_preds_root.exists()
and version_preds_root.samefile(copy_preds_root)
)
if model_matches and predictions_match:
# We've already copied this version, we'll make this a no-op
return
else:
msg = f"Version {version} already exists but does not match copy-from version {copy_from_version}"
raise ValueError(msg)

# If we're here, everything should be safe for copying
# Generate the new version directory
self.make_model_version_root(resolution, version)

# Copy the model spec
copy_spec = yaml.safe_load(copy_spec_path.read_text())
copy_spec["model_version"] = version
copy_spec["output_root"] = str(version_root)
with version_spec_path.open("w") as f:
yaml.safe_dump(copy_spec, f)

# Symlink the model checkpoint.
version_ckpt_path.symlink_to(copy_ckpt_path.resolve())

# Symlink the raw predictions root
version_preds_root.rmdir() # We've just made this as an empty directory
version_preds_root.symlink_to(copy_preds_root)

def model_specification_path(self, resolution: str, version: str) -> Path:
return self.model_version_root(resolution, version) / "specification.yaml"

def save_model_specification(
self,
model_spec: "ModelSpecification",
) -> None:
self.make_model_version_root(model_spec.resolution, model_spec.model_version)
path = Path(model_spec.output_root) / "specification.yaml"
path = self.model_specification_path(
model_spec.resolution, model_spec.model_version
)
touch(path)
with path.open("w") as f:
yaml.safe_dump(model_spec.model_dump(mode="json"), f)
Expand All @@ -610,15 +691,18 @@ def load_model_specification(
) -> "ModelSpecification":
from rra_population_model.model.modeling.datamodel import ModelSpecification

path = self.model_version_root(resolution, version) / "specification.yaml"
path = self.model_specification_path(resolution, version)
with path.open() as f:
spec = yaml.safe_load(f)
return ModelSpecification.model_validate(spec)

def model_checkpoint_path(self, resolution: str, version: str) -> Path:
return self.model_version_root(resolution, version) / "best_model.ckpt"

def load_model(self, resolution: str, version: str) -> "PPSModel":
from rra_population_model.model.modeling.model import PPSModel

ckpt_path = self.model_version_root(resolution, version) / "best_model.ckpt"
ckpt_path = self.model_checkpoint_path(resolution, version)
return PPSModel.load_from_checkpoint(ckpt_path)

def raw_predictions_root(self, resolution: str, version: str) -> Path:
Expand Down
7 changes: 5 additions & 2 deletions src/rra_population_model/model/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,23 @@ def inference_main(
for block_key in block_keys
if not pm_data.raw_prediction_path(block_key, time_point, model_spec).exists()
]
if not block_keys:
print("All blocks have been predicted.")
return

datamodule = InferenceDataModule(
model_spec.model_dump(),
block_keys,
time_point,
num_workers=4,
num_workers=0,
)
pred_writer = CustomWriter(
pm_data, model.specification, time_point, write_interval="batch"
)
trainer = Trainer(
callbacks=[pred_writer],
enable_progress_bar=progress_bar,
devices=2,
devices=1,
)
trainer.predict(model, datamodule, return_predictions=False)

Expand Down
3 changes: 3 additions & 0 deletions src/rra_population_model/model_prep/features/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def features(
modeling_frame = pm_data.load_modeling_frame(resolution)
block_keys = modeling_frame.block_key.unique().tolist()

njobs = len(block_keys) * len(time_point)
print(f"Submitting {njobs} jobs to process features")

jobmon.run_parallel(
runner="pmtask model_prep",
task_name="features",
Expand Down
Loading
Loading