-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy pathprod_env.py
More file actions
119 lines (98 loc) · 4.41 KB
/
prod_env.py
File metadata and controls
119 lines (98 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os.path
from typing import Optional
from pydantic import BaseModel, PositiveFloat, model_validator
from nvflare.apis.utils.format_check import name_check
from nvflare.job_config.api import FedJob
from nvflare.recipe.spec import ExecEnv
from nvflare.recipe.utils import _collect_non_local_scripts
from .session_mgr import SessionManager
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Username used for the admin session when the caller does not supply one.
DEFAULT_ADMIN_USER = "admin@nvidia.com"
# Internal — not part of the public API
class _ProdEnvValidator(BaseModel):
    """Pydantic model that validates the constructor arguments of ``ProdEnv``.

    Fields:
        startup_kit_location: Path that must exist on the local filesystem.
        login_timeout: Strictly positive timeout value; defaults to 5.0.
        username: Login name; defaults to ``DEFAULT_ADMIN_USER``.
        project: Optional project name, checked with ``name_check`` when given.
    """

    startup_kit_location: str
    login_timeout: PositiveFloat = 5.0
    username: str = DEFAULT_ADMIN_USER
    project: Optional[str] = None

    @model_validator(mode="after")
    def check_startup_kit_location_exists(self) -> "_ProdEnvValidator":
        """Run the cross-field checks once the individual fields are parsed.

        Returns:
            The validated model instance itself.

        Raises:
            ValueError: If ``startup_kit_location`` does not exist, or if
                ``project`` is set and ``name_check`` reports an error.
        """
        location_found = os.path.exists(self.startup_kit_location)
        if not location_found:
            raise ValueError(f"startup_kit_location path does not exist: {self.startup_kit_location}")

        # A project name is optional; nothing more to verify without one.
        if self.project is None:
            return self

        has_error, message = name_check(self.project, "project")
        if has_error:
            raise ValueError(message)
        return self
class ProdEnv(ExecEnv):
    """Execution environment that submits jobs through a ``SessionManager``.

    The ``SessionManager`` is created lazily on first use, so constructing a
    ``ProdEnv`` only validates its arguments.
    """

    def __init__(
        self,
        startup_kit_location: str,
        login_timeout: float = 5.0,
        username: str = DEFAULT_ADMIN_USER,
        project: Optional[str] = None,
        extra: Optional[dict] = None,
    ):
        """Production execution environment for submitting and monitoring NVFlare jobs.

        This environment uses the startup kit of an NVFlare deployment to submit jobs via the Flare API.

        Args:
            startup_kit_location (str): Path to the admin's startup kit directory.
            login_timeout (float): Timeout (in seconds) for logging into the Flare API session. Must be > 0.
            username (str): Username to log in with.
            project (Optional[str]): Project name to tag submitted/cloned jobs.
            extra: extra env info.

        Raises:
            pydantic.ValidationError: If any argument fails the checks in
                ``_ProdEnvValidator`` (missing path, non-positive timeout,
                invalid project name).
        """
        super().__init__(extra)

        # Validate all arguments in one place; store the validated values.
        v = _ProdEnvValidator(
            startup_kit_location=startup_kit_location,
            login_timeout=login_timeout,
            username=username,
            project=project,
        )

        self.startup_kit_location = v.startup_kit_location
        self.login_timeout = v.login_timeout
        self.username = v.username
        self.project = v.project
        self._session_manager = None  # Lazy initialization

    def get_job_status(self, job_id: str) -> Optional[str]:
        """Return the status of ``job_id`` as reported by the session manager."""
        return self._get_session_manager().get_job_status(job_id)

    def abort_job(self, job_id: str) -> None:
        """Ask the session manager to abort ``job_id``."""
        self._get_session_manager().abort_job(job_id)

    def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
        """Return the result of ``job_id`` from the session manager.

        Args:
            job_id: Identifier of the job.
            timeout: Forwarded unchanged to ``SessionManager.get_job_result``.
        """
        return self._get_session_manager().get_job_result(job_id, timeout)

    def deploy(self, job: FedJob):
        """Deploy a job using SessionManager.

        Args:
            job: The job to submit.

        Returns:
            Whatever ``SessionManager.submit_job`` returns.

        Raises:
            RuntimeError: If creating the session manager or submitting the
                job fails; the original exception is attached as ``__cause__``.
        """
        # Log warnings for non-local scripts (assumed pre-installed on production)
        for script in _collect_non_local_scripts(job):
            logger.warning(
                "Script '%s' not found locally. Assuming it is pre-installed on the production system.",
                script,
            )

        try:
            return self._get_session_manager().submit_job(job)
        except Exception as e:
            # Chain explicitly so the underlying failure stays in the traceback.
            raise RuntimeError(f"Failed to submit job via Flare API: {e}") from e

    def _get_session_manager(self) -> SessionManager:
        """Get or create SessionManager with lazy initialization."""
        if self._session_manager is None:
            session_params = {
                "username": self.username,
                "startup_kit_location": self.startup_kit_location,
                "timeout": self.login_timeout,
                "project": self.project,
            }
            self._session_manager = SessionManager(session_params)
        return self._session_manager