Skip to content

Commit e244c0d

Browse files
authored
Merge pull request #732 from NVIDIA/am/system-inheritance
Convert base System into pydantic model
2 parents 84b7161 + 0bbccf4 commit e244c0d

File tree

6 files changed

+24
-131
lines changed

6 files changed

+24
-131
lines changed

src/cloudai/_core/system.py

Lines changed: 12 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,67 +18,27 @@
1818
import logging
1919
from abc import ABC, abstractmethod
2020
from pathlib import Path
21-
from typing import TYPE_CHECKING, Any, Dict, Optional
21+
from typing import TYPE_CHECKING, Any
22+
23+
from pydantic import BaseModel, ConfigDict, Field
2224

2325
from .installables import Installable
2426

2527
if TYPE_CHECKING:
2628
from .base_job import BaseJob
2729

2830

29-
class System(ABC):
30-
"""
31-
Base class representing a generic system.
32-
33-
Attributes
34-
name (str): Unique name of the system.
35-
scheduler (str): Type of scheduler used by the system, determining the specific subclass of System to be used.
36-
install_path (Path): Installation path of CloudAI software.
37-
output_path (Path): Path to the output directory.
38-
global_env_vars (Optional[Dict[str, Any]]): Dictionary containing additional configuration settings for the
39-
system.
40-
monitor_interval (int): Interval in seconds for monitoring jobs.
41-
"""
42-
43-
def __init__(
44-
self,
45-
name: str,
46-
scheduler: str,
47-
install_path: Path,
48-
output_path: Path,
49-
global_env_vars: Optional[Dict[str, Any]] = None,
50-
monitor_interval: int = 1,
51-
) -> None:
52-
"""
53-
Initialize a System instance.
31+
class System(ABC, BaseModel):
32+
"""Base class representing a generic system."""
5433

55-
Args:
56-
name (str): Name of the system.
57-
scheduler (str): Type of scheduler used by the system.
58-
install_path (Path): The installation path of CloudAI.
59-
output_path (Path): Path to the output directory.
60-
global_env_vars (Optional[Dict[str, Any]]): Dictionary containing additional configuration settings for
61-
the system.
62-
monitor_interval (int): Interval in seconds for monitoring jobs.
63-
"""
64-
self.name = name
65-
self.scheduler = scheduler
66-
self.install_path = install_path
67-
self.output_path = output_path
68-
self.global_env_vars = global_env_vars if global_env_vars is not None else {}
69-
self.monitor_interval = monitor_interval
70-
71-
def __repr__(self) -> str:
72-
"""
73-
Provide a detailed string representation of the System instance, including all its attributes.
34+
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
7435

75-
Returns
76-
str: String representation of the system including name, scheduler, output_path, and monitor_interval.
77-
"""
78-
return (
79-
f"System(name='{self.name}', scheduler='{self.scheduler}', output_path='{self.output_path}', "
80-
f"monitor_interval={self.monitor_interval})"
81-
)
36+
name: str
37+
scheduler: str
38+
install_path: Path
39+
output_path: Path
40+
global_env_vars: dict[str, Any] = Field(default_factory=dict)
41+
monitor_interval: int = 1
8242

8343
@abstractmethod
8444
def update(self) -> None:

src/cloudai/systems/kubernetes/kubernetes_system.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,18 @@
2727
import kubernetes as k8s
2828

2929

30-
from pydantic import BaseModel, ConfigDict
31-
3230
from cloudai.core import BaseJob, System
3331
from cloudai.util.lazy_imports import lazy
3432

3533
from .kubernetes_job import KubernetesJob
3634

3735

38-
class KubernetesSystem(BaseModel, System):
39-
"""
40-
Represents a Kubernetes system.
41-
42-
Attributes
43-
name (str): The name of the Kubernetes system.
44-
install_path (Path): Path to the installation directory.
45-
output_path (Path): Path to the output directory.
46-
kube_config_path (Path): Path to the Kubernetes config file.
47-
default_namespace (str): The default Kubernetes namespace for jobs.
48-
scheduler (str): The scheduler type, default is "kubernetes".
49-
global_env_vars (Dict[str, Any]): Global environment variables to be passed to jobs.
50-
monitor_interval (int): Time interval to monitor jobs, in seconds.
51-
_core_v1 (client.CoreV1Api): Kubernetes Core V1 API client instance.
52-
_batch_v1 (client.BatchV1Api): Kubernetes Batch V1 API client instance.
53-
_custom_objects_api (CustomObjectsApi): Kubernetes Custom Objects API client instance.
54-
"""
55-
56-
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
57-
58-
name: str
59-
install_path: Path
60-
output_path: Path
36+
class KubernetesSystem(System):
37+
"""Represents a Kubernetes system."""
38+
6139
kube_config_path: Path
6240
default_namespace: str
6341
scheduler: str = "kubernetes"
64-
global_env_vars: Dict[str, Any] = {}
6542
monitor_interval: int = 1
6643
_core_v1: Optional[k8s.client.CoreV1Api] = None
6744
_batch_v1: Optional[k8s.client.BatchV1Api] = None

src/cloudai/systems/lsf/lsf_system.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818
from pathlib import Path
19-
from typing import Any, Dict, List, Optional, Tuple
19+
from typing import Dict, List, Optional, Tuple
2020

2121
from pydantic import BaseModel, ConfigDict, Field, field_serializer
2222

@@ -49,7 +49,7 @@ class LSFQueue(BaseModel):
4949
lsf_nodes: List[LSFNodeObj] = Field(default_factory=list, exclude=True)
5050

5151

52-
class LSFSystem(BaseModel, System):
52+
class LSFSystem(System):
5353
"""
5454
Represents an LSF system.
5555
@@ -69,14 +69,8 @@ class LSFSystem(BaseModel, System):
6969
cmd_shell (CommandShell): Command shell for executing system commands.
7070
"""
7171

72-
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
73-
74-
name: str
75-
install_path: Path
76-
output_path: Path
7772
queues: List[LSFQueue] = Field(default_factory=list)
7873
account: Optional[str] = None
79-
global_env_vars: Dict[str, Any] = {}
8074
scheduler: str = "lsf"
8175
project_name: Optional[str] = None
8276
default_queue: Optional[str] = None

src/cloudai/systems/runai/runai_system.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pathlib import Path
1919
from typing import Any, Dict, List, Optional
2020

21-
from pydantic import BaseModel, ConfigDict, PrivateAttr
21+
from pydantic import Field, PrivateAttr
2222

2323
from cloudai.core import BaseJob, System
2424

@@ -31,7 +31,7 @@
3131
from .runai_training import ActualPhase, RunAITraining
3232

3333

34-
class RunAISystem(BaseModel, System):
34+
class RunAISystem(System):
3535
"""
3636
RunAISystem integrates with the RunAI platform to manage and monitor jobs and nodes.
3737
@@ -47,21 +47,15 @@ class RunAISystem(BaseModel, System):
4747
nodes (List[RunAINode]): List of nodes in the RunAI cluster.
4848
"""
4949

50-
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
51-
52-
name: str
5350
scheduler: str = "runai"
54-
install_path: Path
55-
output_path: Path
56-
global_env_vars: Dict[str, Any] = {}
5751
monitor_interval: int = 60
5852
base_url: str
5953
user_email: str
6054
app_id: str
6155
app_secret: str
6256
project_id: str
6357
cluster_id: str
64-
nodes: List[RunAINode] = []
58+
nodes: List[RunAINode] = Field(default_factory=list)
6559
_api_client: Optional[RunAIRestClient] = PrivateAttr(default=None)
6660

6761
@property

src/cloudai/systems/slurm/slurm_system.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -93,32 +93,9 @@ class SlurmPartition(BaseModel):
9393
slurm_nodes: list[SlurmNode] = Field(default_factory=list[SlurmNode], exclude=True)
9494

9595

96-
class SlurmSystem(BaseModel, System):
97-
"""
98-
Represents a Slurm system.
99-
100-
Attributes
101-
output_path (Path): Path to the output directory.
102-
default_partition (str): The default partition for job submission.
103-
partitions (Dict[str, List[SlurmNode]]): Mapping of partition names to lists of SlurmNodes.
104-
account (Optional[str]): Account name for charging resources used by this job.
105-
distribution (Optional[str]): Specifies alternate distribution methods for remote processes.
106-
mpi (Optional[str]): Indicates the Process Management Interface (PMI) implementation to be used for
107-
inter-process communication.
108-
gpus_per_node (Optional[int]): Specifies the number of GPUs available per node.
109-
ntasks_per_node (Optional[int]): Specifies the number of tasks that can run concurrently on a single node.
110-
cache_docker_images_locally (bool): Whether to cache Docker images locally for the Slurm system.
111-
groups (Dict[str, Dict[str, List[SlurmNode]]]): Nested mapping where the key is the partition name and the
112-
value is another dictionary with group names as keys and lists of SlurmNodes as values, representing the
113-
group composition within each partition.
114-
cmd_shell (CommandShell): An instance of CommandShell for executing system commands.
115-
"""
96+
class SlurmSystem(System):
97+
"""Represents a Slurm system."""
11698

117-
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
118-
119-
name: str
120-
install_path: Path
121-
output_path: Path
12299
default_partition: str
123100
partitions: List[SlurmPartition]
124101
account: Optional[str] = None
@@ -127,12 +104,11 @@ class SlurmSystem(BaseModel, System):
127104
gpus_per_node: Optional[int] = None
128105
ntasks_per_node: Optional[int] = None
129106
cache_docker_images_locally: bool = False
130-
global_env_vars: Dict[str, Any] = {}
131107
scheduler: str = "slurm"
132108
monitor_interval: int = 60
133109
cmd_shell: CommandShell = Field(default=CommandShell(), exclude=True)
134110
extra_srun_args: Optional[str] = None
135-
extra_sbatch_args: list[str] = []
111+
extra_sbatch_args: list[str] = Field(default_factory=list)
136112
supports_gpu_directives_cache: Optional[bool] = Field(default=None, exclude=True)
137113
container_mount_home: bool = False
138114

src/cloudai/systems/standalone/standalone_system.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,18 @@
1515
# limitations under the License.
1616

1717
import logging
18-
from pathlib import Path
19-
20-
from pydantic import BaseModel, ConfigDict
2118

2219
from cloudai.core import BaseJob, System
2320
from cloudai.util import CommandShell
2421

2522

26-
class StandaloneSystem(BaseModel, System):
23+
class StandaloneSystem(System):
2724
"""
2825
Class representing a Standalone system.
2926
3027
This class is used for systems that execute commands directly without a job scheduler.
3128
"""
3229

33-
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
34-
35-
name: str
36-
install_path: Path
37-
output_path: Path
3830
scheduler: str = "standalone"
3931
monitor_interval: int = 1
4032
cmd_shell: CommandShell = CommandShell()

0 commit comments

Comments
 (0)