Skip to content

Commit 76d63b1

Browse files
Add keyword to allow disabling config forwarding in SSHCluster
1 parent 49f5e74 commit 76d63b1

File tree

1 file changed

+39
-24
lines changed

1 file changed

+39
-24
lines changed

distributed/deploy/ssh.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__( # type: ignore[no-untyped-def]
7474
worker_module="deprecated",
7575
worker_class="distributed.Nanny",
7676
remote_python=None,
77+
forward_config=True,
7778
loop=None,
7879
name=None,
7980
):
@@ -92,6 +93,7 @@ def __init__( # type: ignore[no-untyped-def]
9293
self.kwargs = copy.copy(kwargs)
9394
self.name = name
9495
self.remote_python = remote_python
96+
self.forward_config = forward_config
9597
if kwargs.get("nprocs") is not None and kwargs.get("n_workers") is not None:
9698
raise ValueError(
9799
"Both nprocs and n_workers were specified. Use n_workers only."
@@ -135,21 +137,24 @@ async def start(self):
135137

136138
self.connection = await asyncssh.connect(self.address, **self.connect_options)
137139

138-
result = await self.connection.run("uname")
139-
if result.exit_status == 0:
140-
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
141-
dask.config.serialize(dask.config.global_config)
142-
)
143-
else:
144-
result = await self.connection.run("cmd /c ver")
140+
if self.forward_config:
141+
result = await self.connection.run("uname")
145142
if result.exit_status == 0:
146-
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
143+
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
147144
dask.config.serialize(dask.config.global_config)
148145
)
149146
else:
150-
raise Exception(
151-
"Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
152-
)
147+
result = await self.connection.run("cmd /c ver")
148+
if result.exit_status == 0:
149+
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
150+
dask.config.serialize(dask.config.global_config)
151+
)
152+
else:
153+
raise Exception(
154+
"Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
155+
)
156+
else:
157+
set_env = ""
153158

154159
if not self.remote_python:
155160
self.remote_python = sys.executable
@@ -175,7 +180,7 @@ async def start(self):
175180
}
176181
),
177182
]
178-
)
183+
).strip()
179184

180185
self.proc = await self.connection.create_process(cmd)
181186

@@ -214,13 +219,15 @@ def __init__(
214219
connect_options: dict,
215220
kwargs: dict,
216221
remote_python: str | None = None,
222+
forward_config: bool = True,
217223
):
218224
super().__init__()
219225

220226
self.address = address
221227
self.kwargs = kwargs
222228
self.connect_options = connect_options
223229
self.remote_python = remote_python or sys.executable
230+
self.forward_config = forward_config
224231

225232
async def start(self):
226233
try:
@@ -235,21 +242,24 @@ async def start(self):
235242

236243
self.connection = await asyncssh.connect(self.address, **self.connect_options)
237244

238-
result = await self.connection.run("uname")
239-
if result.exit_status == 0:
240-
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
241-
dask.config.serialize(dask.config.global_config)
242-
)
243-
else:
244-
result = await self.connection.run("cmd /c ver")
245+
if self.forward_config:
246+
result = await self.connection.run("uname")
245247
if result.exit_status == 0:
246-
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
248+
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
247249
dask.config.serialize(dask.config.global_config)
248250
)
249251
else:
250-
raise Exception(
251-
"Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
252-
)
252+
result = await self.connection.run("cmd /c ver")
253+
if result.exit_status == 0:
254+
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
255+
dask.config.serialize(dask.config.global_config)
256+
)
257+
else:
258+
raise Exception(
259+
"Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
260+
)
261+
else:
262+
set_env = ""
253263

254264
cmd = " ".join(
255265
[
@@ -260,7 +270,7 @@ async def start(self):
260270
"--spec",
261271
"'%s'" % dumps({"cls": "distributed.Scheduler", "opts": self.kwargs}),
262272
]
263-
)
273+
).strip()
264274
self.proc = await self.connection.create_process(cmd)
265275

266276
# We watch stderr in order to get the address, then we return
@@ -304,6 +314,7 @@ def SSHCluster(
304314
worker_module: str = "deprecated",
305315
worker_class: str = "distributed.Nanny",
306316
remote_python: str | list[str] | None = None,
317+
forward_config: bool = True,
307318
**kwargs: Any,
308319
) -> SpecCluster:
309320
"""Deploy a Dask cluster using SSH
@@ -344,6 +355,8 @@ def SSHCluster(
344355
The python class to use to create the worker(s).
345356
remote_python
346357
Path to Python on remote nodes.
358+
forward_config
359+
Forward the local Dask configuration to the remote nodes.
347360
348361
Examples
349362
--------
@@ -443,6 +456,7 @@ def SSHCluster(
443456
"remote_python": (
444457
remote_python[0] if isinstance(remote_python, list) else remote_python
445458
),
459+
"forward_config": forward_config,
446460
},
447461
}
448462
workers = {
@@ -462,6 +476,7 @@ def SSHCluster(
462476
if isinstance(remote_python, list)
463477
else remote_python
464478
),
479+
"forward_config": forward_config,
465480
},
466481
}
467482
for i, host in enumerate(hosts[1:])

0 commit comments

Comments
 (0)