Skip to content

Commit e85b038

Browse files
authored
Merge pull request #50 from cseptesting/47-consolidate-docker-environment-manager-for-models
Added DockerManager for executing models
2 parents 8ed04e6 + 77f85f8 commit e85b038

File tree

20 files changed

+7854
-140
lines changed

20 files changed

+7854
-140
lines changed

floatcsep/infrastructure/environments.py

+113-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from abc import ABC, abstractmethod
1010
from typing import Union
1111

12+
import docker
13+
from docker.errors import ImageNotFound, NotFound, APIError
1214
from packaging.specifiers import SpecifierSet
1315

1416
log = logging.getLogger("floatLogger")
@@ -387,20 +389,123 @@ class DockerManager(EnvironmentManager):
387389
"""
388390

389391
def __init__(self, base_name: str, model_directory: str) -> None:
390-
self.base_name = base_name
392+
self.base_name = base_name.replace(" ", "_")
391393
self.model_directory = model_directory
392394

393-
def create_environment(self, force=False) -> None:
394-
pass
395+
# use a lower-case slug for tags
396+
slug = self.base_name.lower()
397+
self.image_tag = f"{slug}_image"
398+
self.container_name = f"{slug}_container"
395399

396-
def env_exists(self) -> None:
397-
pass
400+
# Docker SDK client
401+
self.client = docker.from_env()
398402

399-
def run_command(self, command) -> None:
400-
pass
403+
def create_environment(self, force: bool = False) -> None:
404+
"""
405+
Build (or rebuild) the Docker image for this model.
406+
"""
407+
408+
# If forced, remove the existing image
409+
if force and self.env_exists():
410+
log.info(f"[{self.base_name}] Removing existing image '{self.image_tag}'")
411+
try:
412+
self.client.images.remove(self.image_tag, force=True)
413+
except APIError as e:
414+
log.warning(f"[{self.base_name}] Could not remove image: {e}")
415+
416+
# If image is missing or rebuild was requested, build it now
417+
if force or not self.env_exists():
418+
build_path = os.path.abspath(self.model_directory)
419+
uid, gid = os.getuid(), os.getgid()
420+
build_args = {
421+
"USER_UID": str(uid),
422+
"USER_GID": str(gid),
423+
}
424+
log.info(f"[{self.base_name}] Building image '{self.image_tag}' from {build_path}")
425+
426+
build_logs = self.client.api.build(
427+
path=build_path,
428+
tag=self.image_tag,
429+
rm=True,
430+
decode=True,
431+
buildargs=build_args,
432+
nocache=False # todo: create model arg for --no-cache
433+
)
434+
435+
# Stream each chunk
436+
for chunk in build_logs:
437+
if "stream" in chunk:
438+
for line in chunk["stream"].splitlines():
439+
log.debug(f"[{self.base_name}][build] {line}")
440+
elif "errorDetail" in chunk:
441+
msg = chunk["errorDetail"].get("message", "").strip()
442+
log.error(f"[{self.base_name}][build error] {msg}")
443+
raise RuntimeError(f"Docker build error: {msg}")
444+
log.info(f"[{self.base_name}] Successfully built '{self.image_tag}'")
445+
446+
def env_exists(self) -> bool:
447+
"""
448+
Checks if the Docker image with the given tag already exists.
449+
450+
Returns:
451+
bool: True if the Docker image exists, False otherwise.
452+
"""
453+
"""
454+
Returns True if an image with our tag already exists locally.
455+
"""
456+
try:
457+
self.client.images.get(self.image_tag)
458+
return True
459+
except ImageNotFound:
460+
return False
461+
462+
def run_command(self, command=None) -> None:
463+
"""
464+
Runs the model’s Docker container with input/ and forecasts/ mounted.
465+
Streams logs and checks for non-zero exit codes.
466+
"""
467+
model_root = os.path.abspath(self.model_directory)
468+
mounts = {
469+
os.path.join(model_root, "input"): {'bind': '/app/input', 'mode': 'rw'},
470+
os.path.join(model_root, "forecasts"): {'bind': '/app/forecasts', 'mode': 'rw'},
471+
}
472+
473+
uid, gid = os.getuid(), os.getgid()
474+
475+
log.info(f"[{self.base_name}] Launching container {self.container_name}")
476+
477+
try:
478+
container = self.client.containers.run(
479+
self.image_tag,
480+
remove=False,
481+
volumes=mounts,
482+
detach=True,
483+
user=f"{uid}:{gid}",
484+
)
485+
except docker.errors.APIError as e:
486+
raise RuntimeError(f"[{self.base_name}] Failed to start container: {e}")
487+
488+
# Log output live
489+
for line in container.logs(stream=True):
490+
log.info(f"[{self.base_name}] {line.decode().rstrip()}")
491+
492+
# Wait for exit
493+
exit_code = container.wait().get("StatusCode", 1)
494+
495+
# Clean up
496+
container.remove(force=True)
497+
498+
if exit_code != 0:
499+
raise RuntimeError(f"[{self.base_name}] Container exited with code {exit_code}")
500+
501+
log.info(f"[{self.base_name}] Container finished successfully.")
401502

402503
def install_dependencies(self) -> None:
403-
pass
504+
"""
505+
Installs dependencies for Docker-based models. This is typically handled by the Dockerfile,
506+
so no additional action is needed here.
507+
"""
508+
log.info("No additional dependency installation required for Docker environments.")
404509

405510

406511
class EnvironmentFactory:

floatcsep/infrastructure/registries.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
database: str = None,
9494
args_file: str = None,
9595
input_cat: str = None,
96+
fmt: str = None,
9697
) -> None:
9798
"""
9899
@@ -111,6 +112,8 @@ def __init__(
111112
self.input_cat = input_cat
112113
self.forecasts = {}
113114

115+
self._fmt = fmt
116+
114117
def get(self, *args: Sequence[str]) -> str:
115118
"""
116119
Args:
@@ -161,7 +164,11 @@ def fmt(self) -> str:
161164
if self.database:
162165
return os.path.splitext(self.database)[1][1:]
163166
else:
164-
return os.path.splitext(self.path)[1][1:]
167+
ext = os.path.splitext(self.path)[1][1:]
168+
if ext:
169+
return ext
170+
else:
171+
return self._fmt
165172

166173
def as_dict(self) -> dict:
167174
"""
@@ -199,7 +206,7 @@ def build_tree(
199206
model_class: str = "TimeIndependentModel",
200207
prefix: str = None,
201208
args_file: str = None,
202-
input_cat: str = None,
209+
input_cat: str = None
203210
) -> None:
204211
"""
205212
Creates the run directory, and reads the file structure inside.
@@ -210,6 +217,7 @@ def build_tree(
210217
prefix (str): prefix of the model forecast filenames if TD
211218
args_file (str, bool): input arguments path of the model if TD
212219
input_cat (str, bool): input catalog path of the model if TD
220+
fmt (str, bool): for time dependent mdoels
213221
214222
"""
215223

@@ -235,7 +243,7 @@ def build_tree(
235243

236244
# set forecast names
237245
self.forecasts = {
238-
win: join(dirtree["forecasts"], f"{prefix}_{win}.csv") for win in windows
246+
win: join(dirtree["forecasts"], f"{prefix}_{win}.{self.fmt}") for win in windows
239247
}
240248

241249
def log_tree(self) -> None:

floatcsep/model.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import List, Callable, Union, Sequence
77

88
import git
9+
import yaml
910
from csep.core.forecasts import GriddedForecast, CatalogForecast
1011

1112
from floatcsep.utils.accessors import from_zenodo, from_git
@@ -106,6 +107,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non
106107
from_git(
107108
giturl,
108109
self.registry.dir if self.registry.fmt else self.registry.path,
110+
force=self.force_stage,
109111
**kwargs,
110112
)
111113
except (git.NoSuchPathError, git.CommandError) as msg:
@@ -122,7 +124,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non
122124
f"structure"
123125
)
124126

125-
def as_dict(self, excluded=("name", "repository", "workdir")):
127+
def as_dict(self, excluded=("name", "repository", "workdir", "environment")):
126128
"""
127129
Returns:
128130
Dictionary with relevant attributes. Model can be re-instantiated from this dict
@@ -295,6 +297,7 @@ def __init__(
295297
model_path: str,
296298
func: Union[str, Callable] = None,
297299
func_kwargs: dict = None,
300+
fmt: str = 'csv',
298301
**kwargs,
299302
) -> None:
300303
"""
@@ -317,7 +320,9 @@ def __init__(
317320
self.func = func
318321
self.func_kwargs = func_kwargs or {}
319322

320-
self.registry = ForecastRegistry(kwargs.get("workdir", os.getcwd()), model_path)
323+
self.registry = ForecastRegistry(workdir=kwargs.get("workdir", os.getcwd()),
324+
path=model_path,
325+
fmt=fmt)
321326
self.repository = ForecastRepository.factory(
322327
self.registry, model_class=self.__class__.__name__, **kwargs
323328
)
@@ -451,3 +456,32 @@ def replace_arg(arg, val, fp):
451456

452457
with open(filepath, "w") as file_:
453458
json.dump(args, file_, indent=2)
459+
460+
elif fmt == ".yml" or fmt == ".yaml":
461+
462+
def nested_update(dest: dict, src: dict, max_depth: int = 3, _level: int = 1):
463+
"""
464+
Recursively update dest with values from src down to max_depth levels.
465+
- If dest[k] and src[k] are both dicts, recurse (until max_depth).
466+
- Otherwise overwrite dest[k] with src[k].
467+
"""
468+
for key, val in src.items():
469+
if (
470+
_level < max_depth
471+
and key in dest
472+
and isinstance(dest[key], dict)
473+
and isinstance(val, dict)
474+
):
475+
nested_update(dest[key], val, max_depth, _level + 1)
476+
else:
477+
dest[key] = val
478+
479+
480+
with open(filepath, "r") as file_:
481+
args = yaml.safe_load(file_)
482+
args["start_date"] = start.isoformat()
483+
args["end_date"] = end.isoformat()
484+
485+
nested_update(args, self.func_kwargs)
486+
with open(filepath, "w") as file_:
487+
yaml.safe_dump(args, file_, indent=2)

floatcsep/utils/accessors.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55
import sys
66
import shutil
77

8-
HOST_CATALOG = "https://service.iris.edu/fdsnws/event/1/query?"
9-
TIMEOUT = 180
10-
11-
128
def from_zenodo(record_id, folder, force=False):
139
"""
1410
Download data from a Zenodo repository.
@@ -50,7 +46,7 @@ def from_zenodo(record_id, folder, force=False):
5046
sys.exit(-1)
5147

5248

53-
def from_git(url, path, branch=None, depth=1, **kwargs):
49+
def from_git(url, path, branch=None, depth=1, force=False, **kwargs):
5450
"""
5551
Clones a shallow repository from a git url.
5652
@@ -59,6 +55,7 @@ def from_git(url, path, branch=None, depth=1, **kwargs):
5955
path (str): path/folder where to clone the repo
6056
branch (str): repository's branch to clone (default: main)
6157
depth (int): depth history of commits
58+
force (bool): If True, deletes existing path before cloning
6259
**kwargs: keyword args passed to Repo.clone_from
6360
6461
Returns:
@@ -68,6 +65,13 @@ def from_git(url, path, branch=None, depth=1, **kwargs):
6865
kwargs.update({"depth": depth})
6966
git.refresh()
7067

68+
if os.path.exists(path):
69+
if force:
70+
shutil.rmtree(path)
71+
elif os.listdir(path):
72+
raise ValueError(f"Cannot clone into non-empty directory: {path}")
73+
os.makedirs(path, exist_ok=True)
74+
7175
try:
7276
repo = git.Repo(path)
7377
except (git.NoSuchPathError, git.InvalidGitRepositoryError):

floatcsep/utils/helpers.py

-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ def timewindows_ti(
318318
timelimits = pandas.date_range(
319319
start=start_date, end=end_date, periods=periods, freq=frequency
320320
)
321-
print(timelimits)
322321
timelimits = timelimits.to_pydatetime()
323322
except ValueError as e_:
324323
raise ValueError(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
FROM ubuntu:22.04
2+
CMD ["/bin/false"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
FROM nonexistingimage:latest
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
FROM ubuntu:22.04
2+
CMD ["asd"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
FROM ubuntu:22.04
2+
CMD ["touch", "/root/forbidden"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
FROM ubuntu:22.04
2+
CMD ["echo", "Hello from valid container"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
FROM ubuntu:22.04
2+
RUN useradd -u 1234 -m modeler
3+
USER modeler
4+
CMD ["id"]

0 commit comments

Comments
 (0)