@@ -107,7 +107,7 @@ def test_set_up_dask_client_default() -> None:
107107 # Shut down a client if it exists
108108 try :
109109 client = get_client ()
110- client .shutdown ()
110+ client .shutdown () # type: ignore[no-untyped-call]
111111 except ValueError :
112112 pass
113113 finally :
@@ -116,43 +116,77 @@ def test_set_up_dask_client_default() -> None:
116116 client = get_client ()
117117
118118 set_up_dask_client ()
119- client = get_client ()
120- assert isinstance (client .cluster , LocalCluster )
121- assert client .cluster .name == "pseudopeople_dask_cluster"
122- workers = client .scheduler_info ()["workers" ] # type: ignore[no-untyped-call]
123- assert len (workers ) == CPU_COUNT
124- assert all (worker ["nthreads" ] == 1 for worker in workers .values ())
119+
125120 if is_on_slurm ():
126121 try :
127122 available_memory = float (os .environ ["SLURM_MEM_PER_NODE" ]) / 1024
128123 except KeyError :
129124 raise RuntimeError (
130- "You are on Slurm but SLURM_MEM_PER_NODE is not set. "
131- "It is likely that you are SSHed onto a node (perhaps using VSCode). "
125+ "NOTE: This RuntimeError is expected if you are using VSCode on the cluster!\n \n "
126+ "You are on Slurm but SLURM_MEM_PER_NODE is not set; it is likely "
127+ "that you are SSHed onto a node (perhaps using VSCode?). "
132128 "In this case, dask will assign the total memory of the node to the "
133129 "cluster instead of the allocated memory from the srun call. "
134130 "Pseudopeople should only be used on Slurm directly on the node "
135131 "assigned via an srun (both for pytests as well as actual work)."
136132 )
137133 else :
138- available_memory = psutil .virtual_memory ().total / (1024 ** 3 )
139- assert np .isclose (sum (worker ["memory_limit" ] / 1024 ** 3 for worker in workers .values ()), available_memory , rtol = 0.01 )
134+ available_memory = psutil .virtual_memory ().total / (1024 ** 3 )
135+
136+ _check_cluster_attrs (
137+ cluster_name = "pseudopeople_dask_cluster" ,
138+ memory_limit = available_memory ,
139+ n_workers = CPU_COUNT ,
140+ threads_per_worker = 1 ,
141+ )
140142
141143
142- def test_set_up_dask_client_custom () -> None :
def test_set_up_dask_client_existing_cluster() -> None:
    """Verify that ``set_up_dask_client`` leaves a pre-existing cluster untouched."""
    cluster_name = "custom"
    per_worker_gb = 1
    n_workers = 3
    threads_per_worker = 2
    # Expected attributes of the cluster, used both before and after the call.
    expected_attrs = dict(
        cluster_name=cluster_name,
        memory_limit=per_worker_gb * n_workers,  # total GB summed across workers
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
    )

    # Spin up a cluster by hand and attach a client to it.
    cluster = LocalCluster(  # type: ignore[no-untyped-call]
        name=cluster_name,
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
        memory_limit=per_worker_gb * 1024**3,
    )
    cluster.get_client()  # type: ignore[no-untyped-call]
    _check_cluster_attrs(**expected_attrs)

    # Call the dask client setup function.
    set_up_dask_client()

    # The active cluster must still be the one we created, with the same config.
    assert get_client().cluster == cluster
    _check_cluster_attrs(**expected_attrs)
176+
177+ ####################
178+ # Helper Functions #
179+ ####################
180+
def _check_cluster_attrs(
    cluster_name: str,
    memory_limit: int | float,
    n_workers: int,
    threads_per_worker: int,
) -> None:
    """Assert that the currently active dask cluster matches the expected setup.

    ``memory_limit`` is the expected total memory in GB summed over all
    workers; it is compared with a 1% relative tolerance.
    """
    active_cluster = get_client().cluster
    assert isinstance(active_cluster, LocalCluster)
    assert active_cluster.name == cluster_name

    worker_infos = list(active_cluster.scheduler_info["workers"].values())
    assert len(worker_infos) == n_workers
    for info in worker_infos:
        assert info["nthreads"] == threads_per_worker

    total_gb = sum(info["memory_limit"] / 1024**3 for info in worker_infos)
    assert np.isclose(total_gb, memory_limit, rtol=0.01)
# NOTE(review): "0 commit comments" was web-page scrape residue, not source code.