@@ -8,8 +8,10 @@ import (
88 "os/exec"
99 "strconv"
1010 "strings"
11+ "sync"
1112
1213 "github.com/ipfs/go-cid"
14+ logging "github.com/ipfs/go-log/v2"
1315 "github.com/samber/lo"
1416 "golang.org/x/xerrors"
1517
@@ -31,26 +33,43 @@ func WithLogCtx(ctx context.Context, kvs ...any) context.Context {
3133 return context .WithValue (ctx , logCtxKey , kvs )
3234}
3335
36+ var logger = logging .Logger ("ffiselect" )
37+
3438var IsTest = false
3539var IsCuda = build .IsOpencl != "1"
3640
3741// Get all devices from ffi
38- var ch chan string
39-
42+ var gpuSlots []byte
43+ var gpuSlotsMx sync.Mutex
44+
45+ // getDeviceOrdinal returns the ordinal of the GPU with the least workload.
46+ func getDeviceOrdinal () int {
47+ gpuSlotsMx .Lock ()
48+ defer gpuSlotsMx .Unlock ()
49+ max , maxIdx := byte (0 ), 0
50+ for i , w := range gpuSlots {
51+ if w > max {
52+ max , maxIdx = w , i
53+ }
54+ }
55+ if max == 0 {
56+ logger .Errorf ("no GPUs available. Something went wrong in the scheduler." )
57+ return - 1
58+ }
59+ gpuSlots [maxIdx ]--
60+ return maxIdx
61+ }
4062func init () {
4163 devices , err := ffi .GetGPUDevices ()
4264 if err != nil {
4365 panic (err )
4466 }
4567 if len (devices ) == 0 {
46- ch = make (chan string , 1 )
47- ch <- "0"
68+ gpuSlots = []byte {1 }
4869 } else {
49- nSlots := len (devices ) * resources .GpuOverprovisionFactor
50-
51- ch = make (chan string , nSlots )
52- for i := 0 ; i < nSlots ; i ++ {
53- ch <- strconv .Itoa (i / resources .GpuOverprovisionFactor )
70+ gpuSlots = make ([]byte , len (devices ))
71+ for i := range gpuSlots {
72+ gpuSlots [i ] = byte (resources .GpuOverprovisionFactor )
5473 }
5574 }
5675}
@@ -76,11 +95,17 @@ func call(ctx context.Context, body []byte) (io.ReadCloser, error) {
7695 }
7796
7897 // get dOrdinal
79- dOrdinal := <- ch
98+ dOrdinal := getDeviceOrdinal ()
8099 defer func () {
81- ch <- dOrdinal
100+ gpuSlotsMx .Lock ()
101+ gpuSlots [dOrdinal ]++
102+ gpuSlotsMx .Unlock ()
82103 }()
83104
105+ if dOrdinal == - 1 {
106+ return nil , xerrors .Errorf ("no GPUs available. Something went wrong in the scheduler." )
107+ }
108+
84109 p , err := os .Executable ()
85110 if err != nil {
86111 return nil , err
@@ -92,10 +117,11 @@ func call(ctx context.Context, body []byte) (io.ReadCloser, error) {
92117 // Set Visible Devices for CUDA and OpenCL
93118 cmd .Env = append (os .Environ (),
94119 func (isCuda bool ) string {
120+ ordinal := strconv .Itoa (dOrdinal )
95121 if isCuda {
96- return "CUDA_VISIBLE_DEVICES=" + dOrdinal
122+ return "CUDA_VISIBLE_DEVICES=" + ordinal
97123 }
98- return "GPU_DEVICE_ORDINAL=" + dOrdinal
124+ return "GPU_DEVICE_ORDINAL=" + ordinal
99125 }(IsCuda ))
100126 tmpDir , err := os .MkdirTemp ("" , "rust-fil-proofs" )
101127 if err != nil {
0 commit comments