
Commit 6cbfaec

Merge pull request #174 from CliMA/ne/addprocs
Add default function to add workers
2 parents c747b37 + beb317b commit 6cbfaec

File tree: 6 files changed, +119 -12 lines

docs/src/api.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ ClimaCalibrate.postprocess_g_ensemble
 
 ## Worker Interface
 ```@docs
+ClimaCalibrate.add_workers
 ClimaCalibrate.WorkerBackend
 ClimaCalibrate.SlurmManager
 ClimaCalibrate.PBSManager

src/backends.jl

Lines changed: 10 additions & 0 deletions

@@ -34,6 +34,8 @@ Used for CliMA's private GPU server.
 """
 struct ClimaGPUBackend <: SlurmBackend end
 
+struct GCPBackend <: SlurmBackend end
+
 """
     DerechoBackend
 
@@ -61,6 +63,8 @@ function get_backend()
         (r"^clima.gps.caltech.edu$", ClimaGPUBackend),
         (r"^login[1-4].cm.cluster$", CaltechHPCBackend),
         (r"^hpc-(\d\d)-(\d\d).cm.cluster$", CaltechHPCBackend),
+        (r"^hpc\d+-slurm-login-\d+$", GCPBackend),
+        (r"^hpc\d+-a\d+nodeset-\d+$", GCPBackend),
         (r"derecho([1-8])$", DerechoBackend),
         (r"deg(\d\d\d\d)$", DerechoBackend), # This should be more specific
     ]
@@ -434,3 +438,9 @@ function model_run(
     end
     return job_id
 end
+
+backend_worker_kwargs(::Type{DerechoBackend}) = (; q = "main", A = "UCIT0011")
+
+backend_worker_kwargs(::Type{GCPBackend}) = (; partition = "a3")
+
+backend_worker_kwargs(::Type{<:AbstractBackend}) = (;)
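
With these additions, backend-specific scheduler defaults are resolved by dispatch: `get_backend()` matches the hostname against the regex table, and `backend_worker_kwargs` supplies the extra `addprocs` keywords for that backend. A minimal sketch of the resulting behavior, assuming the names are accessed via the `ClimaCalibrate` module (the hostname shown is hypothetical):

```julia
using ClimaCalibrate

# On a GCP login node (e.g. hostname "hpc1-slurm-login-001"), the new
# regexes map the host to GCPBackend:
backend = ClimaCalibrate.get_backend()

# Dispatch then picks up the scheduler defaults added above:
ClimaCalibrate.backend_worker_kwargs(backend)
# => (partition = "a3",)

# Derecho jobs get a queue and account by default:
ClimaCalibrate.backend_worker_kwargs(ClimaCalibrate.DerechoBackend)
# => (q = "main", A = "UCIT0011")

# Any other backend contributes no extra keywords via the fallback method:
ClimaCalibrate.backend_worker_kwargs(ClimaCalibrate.ClimaGPUBackend)
# => NamedTuple()
```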

src/workers.jl

Lines changed: 102 additions & 1 deletion

@@ -1,7 +1,8 @@
 using Distributed
 using Logging
 
-export SlurmManager, PBSManager, default_worker_pool, set_worker_loggers
+export add_workers,
+    SlurmManager, PBSManager, default_worker_pool, set_worker_loggers
 
 # Set the time limit for the Julia worker to be contacted by the main process, default = "60.0s"
 # https://docs.julialang.org/en/v1/manual/environment-variables/#JULIA_WORKER_TIMEOUT
@@ -444,3 +445,103 @@ function set_worker_loggers(workers = workers())
         end
     end
 end
+
+
+function is_pbs_available()
+    return all([
+        !isnothing(Sys.which("qstat")),
+        !isnothing(Sys.which("pbsnodes")),
+        !isnothing(Sys.which("qsub")),
+    ])
+end
+
+
+function is_slurm_available()
+    return all([
+        !isnothing(Sys.which("sinfo")),
+        !isnothing(Sys.which("srun")),
+        !isnothing(Sys.which("sbatch")),
+    ])
+end
+
+function is_cluster_environment()
+    return is_pbs_available() || is_slurm_available()
+end
+
+const DEFAULT_WALLTIME = 60
+
+default_cpu_kwargs(::SlurmManager) = (;
+    cpus_per_task = 1,
+    time = format_slurm_time(DEFAULT_WALLTIME),
+    backend_worker_kwargs(get_backend())...,
+)
+default_cpu_kwargs(::PBSManager) = (;
+    l_select = "ncpus=1",
+    l_walltime = format_pbs_time(DEFAULT_WALLTIME),
+    backend_worker_kwargs(get_backend())...,
+)
+
+default_gpu_kwargs(::SlurmManager) = (;
+    gpus_per_task = 1,
+    cpus_per_task = 4,
+    time = format_slurm_time(DEFAULT_WALLTIME),
+    backend_worker_kwargs(get_backend())...,
+)
+default_gpu_kwargs(::PBSManager) = (;
+    l_select = "ngpus=1:ncpus=4",
+    l_walltime = format_pbs_time(DEFAULT_WALLTIME),
+    backend_worker_kwargs(get_backend())...,
+)
+
+function get_manager(cluster = :auto, nworkers = 1)
+    if cluster == :slurm || (cluster == :auto && is_slurm_available())
+        SlurmManager(nworkers)
+    elseif cluster == :pbs || (cluster == :auto && is_pbs_available())
+        PBSManager(nworkers)
+    else
+        error(
+            "Unknown cluster type: $cluster. Valid options are :auto, :pbs, :slurm, or :local",
+        )
+    end
+end
+
+"""
+    add_workers(
+        nworkers;
+        device = :gpu,
+        cluster = :auto,
+        kwargs...
+    )
+
+Add `nworkers` worker processes to the current Julia session, automatically detecting and configuring defaults for the available computing environment.
+
+# Arguments
+- `nworkers::Int`: The number of worker processes to add.
+- `device::Symbol = :gpu`: The target compute device type, either `:gpu` (1 GPU, 4 CPU cores) or `:cpu` (1 CPU core).
+- `cluster::Symbol = :auto`: The cluster management system to use. Options:
+    * `:auto`: Auto-detect the available cluster environment (SLURM, PBS, or local)
+    * `:slurm`: Force use of the SLURM scheduler
+    * `:pbs`: Force use of the PBS scheduler
+    * `:local`: Force local processing (standard `addprocs`)
+- `kwargs`: Additional keyword arguments are passed directly through to `addprocs`.
+"""
+function add_workers(nworkers::Int; device = :gpu, cluster = :auto, kwargs...)
+    if cluster == :local || (cluster == :auto && !is_cluster_environment())
+        # Use standard addprocs for local computation
+        @info "Using local processing mode, adding $nworkers worker$(nworkers == 1 ? "" : "s")"
+        return addprocs(nworkers)
+    else
+        # Select the manager based on environment or explicit selection
+        manager = get_manager(cluster, nworkers)
+        @info "Using $(nameof(typeof(manager))) to add $nworkers workers"
+
+        default_kwargs =
+            device == :gpu ? default_gpu_kwargs(manager) :
+            default_cpu_kwargs(manager)
+
+        # Merge the default kwargs with the user-provided kwargs; user kwargs take precedence
+        merged_kwargs = merge(default_kwargs, kwargs)
+
+        return addprocs(manager; merged_kwargs...)
+    end
+end
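
A usage sketch of the new entry point, based on the docstring above; the keyword values are illustrative, not prescribed defaults:

```julia
using Distributed, ClimaCalibrate

# On a Slurm cluster, request two GPU workers; `time` overrides the
# default walltime because user kwargs take precedence in the merge:
procs = add_workers(2; device = :gpu, cluster = :slurm, time = "02:00:00")

# With `cluster = :auto` (the default), a machine with neither Slurm nor
# PBS binaries on the PATH falls back to plain `addprocs`:
procs = add_workers(2)

# CPU-only workers on PBS, overriding the default select statement:
procs = add_workers(4; device = :cpu, cluster = :pbs, l_select = "ncpus=1:mem=4gb")

# Release the workers when done:
rmprocs(procs)
```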

test/hpc_backend.jl

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ if backend == DerechoBackend
 end
 
 original_model_interface = model_interface
-interruption_model_interface, io = mktemp()
+interruption_model_interface, io = mktemp(@__DIR__)
 model_interface_str = """
 import ClimaCalibrate
 ClimaCalibrate.forward_model(iter, member) = member == 1 && exit()
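
For context on the one-line change: `Base.mktemp(parent)` creates the temporary file under `parent` instead of the system temp directory, so `mktemp(@__DIR__)` keeps the generated interface file next to the test, presumably so it lives on a filesystem visible to compute nodes. A small standalone sketch (the written content is illustrative):

```julia
# mktemp(parent) returns a (path, io) pair for a fresh file under `parent`;
# with the default cleanup=true, the file is removed at process exit.
path, io = mktemp(@__DIR__)
write(io, "import ClimaCalibrate\n")
close(io)
@assert startswith(path, @__DIR__)
```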

test/pbs_manager_unit_tests.jl

Lines changed: 2 additions & 7 deletions

@@ -1,13 +1,8 @@
 using Test, ClimaCalibrate, Distributed, Logging
 
 @testset "PBSManager Unit Tests" begin
-    p = addprocs(
-        PBSManager(1),
-        q = "main",
-        A = "UCIT0011",
-        l_select = "ngpus=1",
-        l_walltime = "00:05:00",
-    )
+    @test ClimaCalibrate.get_manager() == PBSManager(1)
+    p = add_workers(1; l_walltime = "00:05:00")
     @test nprocs() == length(p) + 1
     @test workers() == p
     @test remotecall_fetch(myid, 2) == 2

test/slurm_manager_unit_tests.jl

Lines changed: 3 additions & 3 deletions

@@ -1,8 +1,9 @@
 using Test, ClimaCalibrate, Distributed, Logging
 
 @testset "SlurmManager Unit Tests" begin
-    out_file = "my_slurm_job.out"
-    p = addprocs(SlurmManager(1); o = out_file)
+    @test ClimaCalibrate.get_manager() == SlurmManager(1)
+    out_file = tempname()
+    p = add_workers(1; device = :cpu, o = out_file)
     @test nprocs() == 2
     @test workers() == p
     @test fetch(@spawnat :any myid()) == p[1]
@@ -16,7 +17,6 @@ using Test, ClimaCalibrate, Distributed, Logging
     @test workers() == [1]
     # Check output file creation
     @test isfile(out_file)
-    rm(out_file)
 
     # Test incorrect generic arguments
     @test_throws TaskFailedException p = addprocs(SlurmManager(1), time = "w")
