1515from typing import TYPE_CHECKING , Any , Callable , ClassVar
1616
1717from dask .typing import Key
18- from dask .utils import funcname , tmpfile
18+ from dask .utils import _deprecated_kwarg , funcname , tmpfile
1919
2020from distributed .protocol .pickle import dumps
2121
@@ -896,36 +896,46 @@ async def setup(self, nanny):
896896 nanny .env .update (self .environ )
897897
898898
899- class UploadDirectory (NannyPlugin ):
900- """A NannyPlugin to upload a local file to workers.
899+ UPLOAD_DIRECTORY_MODES = ["all" , "scheduler" , "workers" ]
900+
901+
902+ class UploadDirectory (SchedulerPlugin ):
903+ """Scheduler to upload a local directory to the cluster.
901904
902905 Parameters
903906 ----------
904- path: str
905- A path to the directory to upload
907+ path:
908+ Path to the directory to upload
909+ scheduler:
910+ Whether to upload the directory to the scheduler
906911
907912 Examples
908913 --------
909914 >>> from distributed.diagnostics.plugin import UploadDirectory
910- >>> client.register_plugin(UploadDirectory("/path/to/directory"), nanny=True ) # doctest: +SKIP
915+ >>> client.register_plugin(UploadDirectory("/path/to/directory")) # doctest: +SKIP
911916 """
912917
918+ @_deprecated_kwarg ("restart" , "restart_workers" )
913919 def __init__ (
914920 self ,
915921 path ,
916- restart = False ,
922+ restart_workers = False ,
917923 update_path = False ,
918924 skip_words = (".git" , ".github" , ".pytest_cache" , "tests" , "docs" ),
919925 skip = (lambda fn : os .path .splitext (fn )[1 ] == ".pyc" ,),
926+ mode = "workers" ,
920927 ):
921- """
922- Initialize the plugin by reading in the data from the given file.
923- """
924928 path = os .path .expanduser (path )
925929 self .path = os .path .split (path )[- 1 ]
926- self .restart = restart
930+ self .restart_workers = restart_workers
927931 self .update_path = update_path
928932
933+ if mode not in UPLOAD_DIRECTORY_MODES :
934+ raise ValueError (
935+ f"{ mode = } not supported, expected one of { UPLOAD_DIRECTORY_MODES } "
936+ )
937+ self .mode = mode
938+
929939 self .name = "upload-directory-" + os .path .split (path )[- 1 ]
930940
931941 with tmpfile (extension = "zip" ) as fn :
@@ -944,26 +954,67 @@ def __init__(
944954 )
945955 z .write (filename , archive_name )
946956
947- with open (fn , "rb" ) as f :
957+ with open (fn , mode = "rb" ) as f :
948958 self .data = f .read ()
949959
950- async def setup (self , nanny ):
951- fn = os .path .join (nanny .local_directory , f"tmp-{ uuid .uuid4 ()} .zip" )
952- with open (fn , "wb" ) as f :
953- f .write (self .data )
960+ async def start (self , scheduler ):
961+ from distributed .core import clean_exception
962+ from distributed .protocol .serialize import Serialized , deserialize
963+
964+ if self .mode in ("all" , "scheduler" ):
965+ _extract_data (
966+ scheduler .local_directory , self .path , self .data , self .update_graph
967+ )
968+
969+ if self .mode in ("all" , "workers" ):
970+ nanny_plugin = _UploadDirectoryNannyPlugin (
971+ self .path , self .data , self .restart_workers , self .update_path , self .name
972+ )
973+ responses = await scheduler .register_nanny_plugin (
974+ comm = None ,
975+ plugin = dumps (nanny_plugin ),
976+ name = self .name ,
977+ idempotent = False ,
978+ )
979+
980+ for response in responses .values ():
981+ if response ["status" ] == "error" :
982+ response = {
983+ k : deserialize (v .header , v .frames )
984+ for k , v in response .items ()
985+ if isinstance (v , Serialized )
986+ }
987+ _ , exc , tb = clean_exception (** response )
988+ raise exc .with_traceback (tb )
989+
990+
991+ class _UploadDirectoryNannyPlugin (NannyPlugin ):
992+ def __init__ (self , path , data , restart , update_path , name ):
993+ self .path = path
994+ self .data = data
995+ self .name = name
996+ self .restart = restart
997+ self .update_path = update_path
998+
999+ def setup (self , nanny ):
1000+ _extract_data (nanny .local_directory , self .path , self .data , self .update_path )
1001+
1002+
1003+ def _extract_data (base_path , path , data , update_path ):
1004+ with tmpfile (extension = "zip" ) as fn :
1005+ with open (fn , mode = "wb" ) as f :
1006+ f .write (data )
9541007
9551008 import zipfile
9561009
9571010 with zipfile .ZipFile (fn ) as z :
958- z .extractall (path = nanny . local_directory )
1011+ z .extractall (path = base_path )
9591012
960- if self . update_path :
961- path = os .path .join (nanny . local_directory , self . path )
1013+ if update_path :
1014+ path = os .path .join (base_path , path )
9621015 if path not in sys .path :
9631016 sys .path .insert (0 , path )
9641017
965- os .remove (fn )
966-
9671018
9681019class forward_stream :
9691020 def __init__ (self , stream , worker ):
0 commit comments