Skip to content

Commit 818187c

Browse files
author
yli1 user
committed
minor changes on command
1 parent bf490f1 commit 818187c

File tree

4 files changed

+33
-17
lines changed

4 files changed

+33
-17
lines changed

matrix/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .cli import main
2+
3+
if __name__ == "__main__":
4+
main()

matrix/cli.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def start_cluster(
8080
slurm: tp.Dict[str, tp.Union[str, int]] | None = None,
8181
local: tp.Dict[str, tp.Union[str, int]] | None = None,
8282
enable_grafana: bool = False,
83+
force_new_head: bool = False,
8384
):
8485
"""
8586
Starts the Ray cluster with additional keyword arguments. Only do this for new clusters.
@@ -97,11 +98,18 @@ def start_cluster(
9798
slurm (dict, optional): resources for slurm cluster.
9899
local (dict, optional): resources for local cluster.
99100
enable_grafana (bool, optional): If True, enable prometheus and grafana dashboard.
100-
101+
force_new_head (bool): force removal of head.json if 'matrix stop_cluster' was not run.
102+
101103
Returns:
102104
None
103105
"""
104-
self.cluster.start(add_workers, slurm, local, enable_grafana=enable_grafana)
106+
self.cluster.start(
107+
add_workers,
108+
slurm,
109+
local,
110+
enable_grafana=enable_grafana,
111+
force_new_head=force_new_head,
112+
)
105113

106114
def stop_cluster(self):
107115
"""
@@ -370,7 +378,3 @@ def check_health(
370378

371379
def main():
372380
fire.Fire(Cli)
373-
374-
375-
if __name__ == "__main__":
376-
main()

matrix/cluster/ray_cluster.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,11 @@ def start_grafana(self, force: bool):
149149
num_gpus=0,
150150
max_restarts=3, # Allow 3 automatic retries
151151
max_task_retries=-1,
152-
).remote(cluster_info)
152+
).remote(
153+
cluster_info.temp_dir,
154+
cluster_info.prometheus_port,
155+
cluster_info.grafana_port,
156+
)
153157
ray.get(actor.start.remote())
154158
return "Successfully started Grafana dashboard"
155159

@@ -159,6 +163,7 @@ def start(
159163
slurm: tp.Dict[str, tp.Union[str, int]] | None,
160164
local: tp.Dict[str, tp.Union[str, int]] | None,
161165
enable_grafana: bool = False,
166+
force_new_head: bool = False,
162167
):
163168
"""
164169
Starts a Ray cluster on Slurm.
@@ -168,19 +173,24 @@ def start(
168173
169174
Args:
170175
add_workers (int): The number of worker nodes to start.
171-
requirements (dict): Slurm resource requirements for worker nodes.
172-
e.g., {'qos': '...', 'partition': '...', 'gpus-per-node': 8}.
173-
head_requirements (dict): optional to specify head requirements when launching head.
174-
executor (str): Slurm executor to use (default: "slurm").
176+
slurm (dict, optional): resources requirements for slurm cluster.
177+
e.g., {'qos': '...', 'partition': '...', 'gpus-per-node': 8}.
178+
local (dict, optional): resources requirements for local cluster.
175179
enable_grafana (bool): Whether to start Prometheus and Grafana
176180
for monitoring (default: True).
181+
force_new_head (bool): force removal of head.json if 'matrix stop_cluster' was not run.
177182
"""
178183
common_params = {"account", "partition", "qos", "exclusive"}
179184
start_wait_time_seconds = 60
180185
worker_wait_timeout_seconds = 60
181186
requirements = slurm or local or {}
182187
executor = "slurm" if slurm else "local"
183188

189+
if force_new_head:
190+
# remove existing head.json
191+
if self._cluster_json.exists():
192+
self._cluster_json.unlink()
193+
184194
if self._cluster_json.exists():
185195
print(f"Adding workers to existing cluster:\n{self.cluster_info()}")
186196
# todo: check the cluster is alive

matrix/cluster/ray_dashboard_job.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
import ray
1717

18-
from matrix.common.cluster_info import ClusterInfo
19-
2018
# Configure logging
2119
logging.basicConfig(
2220
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -38,14 +36,14 @@ class RayDashboardJob:
3836
properly when the actor is killed.
3937
"""
4038

41-
def __init__(self, cluster_info: ClusterInfo):
39+
def __init__(self, temp_dir: str, prometheus_port: int, grafana_port: int):
4240
"""
4341
Initialize the RayDashboardJob actor.
4442
"""
4543
self.head_env = os.environ.copy()
46-
self.temp_dir = cluster_info.temp_dir
47-
self.prometheus_port = cluster_info.prometheus_port
48-
self.grafana_port = cluster_info.grafana_port
44+
self.temp_dir = temp_dir
45+
self.prometheus_port = prometheus_port
46+
self.grafana_port = grafana_port
4947
self.processes: List[subprocess.Popen[str]] = []
5048
self.monitor_thread: threading.Thread | None = None
5149
self.should_run = True

0 commit comments

Comments
 (0)