@@ -74,6 +74,7 @@ def __init__( # type: ignore[no-untyped-def]
7474 worker_module = "deprecated" ,
7575 worker_class = "distributed.Nanny" ,
7676 remote_python = None ,
77+ forward_config = True ,
7778 loop = None ,
7879 name = None ,
7980 ):
@@ -92,6 +93,7 @@ def __init__( # type: ignore[no-untyped-def]
9293 self .kwargs = copy .copy (kwargs )
9394 self .name = name
9495 self .remote_python = remote_python
96+ self .forward_config = forward_config
9597 if kwargs .get ("nprocs" ) is not None and kwargs .get ("n_workers" ) is not None :
9698 raise ValueError (
9799 "Both nprocs and n_workers were specified. Use n_workers only."
@@ -135,21 +137,24 @@ async def start(self):
135137
136138 self .connection = await asyncssh .connect (self .address , ** self .connect_options )
137139
138- result = await self .connection .run ("uname" )
139- if result .exit_status == 0 :
140- set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"' .format (
141- dask .config .serialize (dask .config .global_config )
142- )
143- else :
144- result = await self .connection .run ("cmd /c ver" )
140+ if self .forward_config :
141+ result = await self .connection .run ("uname" )
145142 if result .exit_status == 0 :
146- set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&" .format (
143+ set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"' .format (
147144 dask .config .serialize (dask .config .global_config )
148145 )
149146 else :
150- raise Exception (
151- "Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
152- )
147+ result = await self .connection .run ("cmd /c ver" )
148+ if result .exit_status == 0 :
149+ set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&" .format (
150+ dask .config .serialize (dask .config .global_config )
151+ )
152+ else :
153+ raise Exception (
154+ "Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
155+ )
156+ else :
157+ set_env = ""
153158
154159 if not self .remote_python :
155160 self .remote_python = sys .executable
@@ -175,7 +180,7 @@ async def start(self):
175180 }
176181 ),
177182 ]
178- )
183+ ). strip ()
179184
180185 self .proc = await self .connection .create_process (cmd )
181186
@@ -214,13 +219,15 @@ def __init__(
214219 connect_options : dict ,
215220 kwargs : dict ,
216221 remote_python : str | None = None ,
222+ forward_config : bool = True ,
217223 ):
218224 super ().__init__ ()
219225
220226 self .address = address
221227 self .kwargs = kwargs
222228 self .connect_options = connect_options
223229 self .remote_python = remote_python or sys .executable
230+ self .forward_config = forward_config
224231
225232 async def start (self ):
226233 try :
@@ -235,21 +242,24 @@ async def start(self):
235242
236243 self .connection = await asyncssh .connect (self .address , ** self .connect_options )
237244
238- result = await self .connection .run ("uname" )
239- if result .exit_status == 0 :
240- set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"' .format (
241- dask .config .serialize (dask .config .global_config )
242- )
243- else :
244- result = await self .connection .run ("cmd /c ver" )
245+ if self .forward_config :
246+ result = await self .connection .run ("uname" )
245247 if result .exit_status == 0 :
246- set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&" .format (
248+ set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"' .format (
247249 dask .config .serialize (dask .config .global_config )
248250 )
249251 else :
250- raise Exception (
251- "Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
252- )
252+ result = await self .connection .run ("cmd /c ver" )
253+ if result .exit_status == 0 :
254+ set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&" .format (
255+ dask .config .serialize (dask .config .global_config )
256+ )
257+ else :
258+ raise Exception (
259+ "Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
260+ )
261+ else :
262+ set_env = ""
253263
254264 cmd = " " .join (
255265 [
@@ -260,7 +270,7 @@ async def start(self):
260270 "--spec" ,
261271 "'%s'" % dumps ({"cls" : "distributed.Scheduler" , "opts" : self .kwargs }),
262272 ]
263- )
273+ ). strip ()
264274 self .proc = await self .connection .create_process (cmd )
265275
266276 # We watch stderr in order to get the address, then we return
@@ -304,6 +314,7 @@ def SSHCluster(
304314 worker_module : str = "deprecated" ,
305315 worker_class : str = "distributed.Nanny" ,
306316 remote_python : str | list [str ] | None = None ,
317+ forward_config : bool = True ,
307318 ** kwargs : Any ,
308319) -> SpecCluster :
309320 """Deploy a Dask cluster using SSH
@@ -344,6 +355,8 @@ def SSHCluster(
344355 The python class to use to create the worker(s).
345356 remote_python
346357 Path to Python on remote nodes.
358+ forward_config
359+ Forward the local Dask configuration to the remote nodes.
347360
348361 Examples
349362 --------
@@ -443,6 +456,7 @@ def SSHCluster(
443456 "remote_python" : (
444457 remote_python [0 ] if isinstance (remote_python , list ) else remote_python
445458 ),
459+ "forward_config" : forward_config ,
446460 },
447461 }
448462 workers = {
@@ -462,6 +476,7 @@ def SSHCluster(
462476 if isinstance (remote_python , list )
463477 else remote_python
464478 ),
479+ "forward_config" : forward_config ,
465480 },
466481 }
467482 for i , host in enumerate (hosts [1 :])
0 commit comments