
Commit 618d2a3

support multi runner with same model
1 parent f2cf97d commit 618d2a3

File tree

2 files changed: +94 -53 lines changed


server/routes.go (+20 -18)

@@ -1142,26 +1142,28 @@ func streamResponse(c *gin.Context, ch chan any) {
 func (s *Server) ProcessHandler(c *gin.Context) {
 	models := []api.ModelResponse{}
 
-	for _, v := range s.sched.loaded {
-		model := v.model
-		modelDetails := api.ModelDetails{
-			Format:            model.Config.ModelFormat,
-			Family:            model.Config.ModelFamily,
-			Families:          model.Config.ModelFamilies,
-			ParameterSize:     model.Config.ModelType,
-			QuantizationLevel: model.Config.FileType,
-		}
+	for _, runners := range s.sched.loaded {
+		for _, v := range runners {
+			model := v.model
+			modelDetails := api.ModelDetails{
+				Format:            model.Config.ModelFormat,
+				Family:            model.Config.ModelFamily,
+				Families:          model.Config.ModelFamilies,
+				ParameterSize:     model.Config.ModelType,
+				QuantizationLevel: model.Config.FileType,
+			}
 
-		mr := api.ModelResponse{
-			Model:     model.ShortName,
-			Name:      model.ShortName,
-			Size:      int64(v.estimatedTotal),
-			SizeVRAM:  int64(v.estimatedVRAM),
-			Digest:    model.Digest,
-			Details:   modelDetails,
-			ExpiresAt: v.expiresAt,
+			mr := api.ModelResponse{
+				Model:     model.ShortName,
+				Name:      model.ShortName,
+				Size:      int64(v.estimatedTotal),
+				SizeVRAM:  int64(v.estimatedVRAM),
+				Digest:    model.Digest,
+				Details:   modelDetails,
+				ExpiresAt: v.expiresAt,
+			}
+			models = append(models, mr)
 		}
-		models = append(models, mr)
 	}
 
 	c.JSON(http.StatusOK, api.ListResponse{Models: models})
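
With this change, s.sched.loaded maps a model path to a slice of runners rather than a single runner, so ProcessHandler needs the extra inner loop above to report every loaded runner. A minimal, self-contained sketch of the same flatten-a-map-of-slices pattern (the runnerRef type and the model paths here are stand-ins, not the real struct from the diff):

package main

import "fmt"

// Stand-in for the scheduler's runnerRef; the real type also carries
// the loaded model, VRAM estimates, expiry time, and more.
type runnerRef struct {
	modelPath string
}

func main() {
	// Each model path can now have several runners serving it concurrently.
	loaded := map[string][]*runnerRef{
		"models/a.gguf": {{modelPath: "models/a.gguf"}, {modelPath: "models/a.gguf"}},
		"models/b.gguf": {{modelPath: "models/b.gguf"}},
	}

	// Flattening mirrors what ProcessHandler (and findRunnerToUnload) now do.
	flat := make([]*runnerRef, 0)
	for _, runners := range loaded {
		flat = append(flat, runners...)
	}
	fmt.Println("loaded runners:", len(flat)) // loaded runners: 3
}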

server/sched.go (+74 -35)

@@ -31,11 +31,11 @@ type LlmRequest struct {
 
 type Scheduler struct {
 	pendingReqCh  chan *LlmRequest
-	finishedReqCh chan *LlmRequest
+	finishedReqCh chan *runnerRef
 	expiredCh     chan *runnerRef
 	unloadedCh    chan interface{}
 
-	loaded   map[string]*runnerRef
+	loaded   map[string][]*runnerRef
 	loadedMu sync.Mutex
 
 	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
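
The field changes above are the heart of the commit: loaded goes from at most one runner per model path to a list per path, and finishedReqCh switches from carrying the request to carrying the runner that finished. A stripped-down sketch of the new shape, with a placeholder runnerRef and all unrelated fields omitted:

package sched

import "sync"

// Placeholder for the real runnerRef defined in server/sched.go.
type runnerRef struct{}

// Sketch of the scheduler's bookkeeping after this commit: the same
// model path may map to several concurrently loaded runners, and a
// completion signal identifies the runner directly.
type scheduler struct {
	finishedReqCh chan *runnerRef
	loaded        map[string][]*runnerRef
	loadedMu      sync.Mutex
}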
@@ -48,10 +48,10 @@ var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending re
 func InitScheduler(ctx context.Context) *Scheduler {
 	sched := &Scheduler{
 		pendingReqCh:  make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
+		finishedReqCh: make(chan *runnerRef, envconfig.MaxQueuedRequests),
 		expiredCh:     make(chan *runnerRef, envconfig.MaxQueuedRequests),
 		unloadedCh:    make(chan interface{}, envconfig.MaxQueuedRequests),
-		loaded:        make(map[string]*runnerRef),
+		loaded:        make(map[string][]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
 	}
@@ -114,9 +114,21 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			for {
 				var runnerToExpire *runnerRef
 				s.loadedMu.Lock()
-				runner := s.loaded[pending.model.ModelPath]
-				loadedCount := len(s.loaded)
+				runners := s.loaded[pending.model.ModelPath]
+				loadedCount := 0
+				for _, runnerList := range s.loaded {
+					loadedCount += len(runnerList)
+				}
 				s.loadedMu.Unlock()
+				var runner *runnerRef = nil
+				if len(runners) > 0 {
+					for _, r := range runners {
+						if !r.isAtCapacity() {
+							runner = r
+							break
+						}
+					}
+				}
 				if runner != nil {
 					if runner.needsReload(ctx, pending) {
 						runnerToExpire = runner
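
processPending now looks for an already-loaded runner of the requested model that still has spare capacity and reuses it; only when every runner is busy (or none exists) does it fall back to loading another runner or expiring one. A simplified, standalone sketch of that selection step; isAtCapacity exists on runnerRef in the real code, but the inFlight/maxInFlight fields used to model it here are made up for the example:

package main

import "fmt"

type runnerRef struct {
	inFlight, maxInFlight int
}

// isAtCapacity mirrors the check the scheduler relies on; the real method
// lives on runnerRef in server/sched.go and its criteria may differ.
func (r *runnerRef) isAtCapacity() bool {
	return r.inFlight >= r.maxInFlight
}

// pickAvailableRunner returns the first runner with spare capacity, or nil.
func pickAvailableRunner(runners []*runnerRef) *runnerRef {
	for _, r := range runners {
		if !r.isAtCapacity() {
			return r
		}
	}
	return nil
}

func main() {
	runners := []*runnerRef{
		{inFlight: 4, maxInFlight: 4}, // busy
		{inFlight: 1, maxInFlight: 4}, // has room
	}
	if r := pickAvailableRunner(runners); r != nil {
		fmt.Println("reusing runner with", r.inFlight, "requests in flight")
	} else {
		fmt.Println("all runners at capacity; a new one would be loaded")
	}
}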
@@ -215,12 +227,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 		case <-ctx.Done():
 			slog.Debug("shutting down scheduler completed loop")
 			return
-		case finished := <-s.finishedReqCh:
+		case finishedRunner := <-s.finishedReqCh:
 			s.loadedMu.Lock()
-			runner := s.loaded[finished.model.ModelPath]
+			runner := finishedRunner
 			s.loadedMu.Unlock()
 			if runner == nil {
-				slog.Error("finished requeset signal received after model unloaded", "modelPath", finished.model.ModelPath)
+				slog.Error("finished requeset signal received after model unloaded", "modelPath", finishedRunner.model.ModelPath)
 				continue
 			}
 			runner.refMu.Lock()
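
Sending the *runnerRef itself over finishedReqCh is what keeps completion handling unambiguous: with several runners per model path, the old lookup by finished.model.ModelPath could no longer tell which runner the finished request belonged to. A toy sketch of the new signaling shape (the id field and buffer size are illustrative only):

package main

import "fmt"

// Minimal stand-in for the scheduler's runnerRef.
type runnerRef struct{ id int }

func main() {
	// The channel now carries the runner that completed a request,
	// not the request keyed by model path.
	finishedReqCh := make(chan *runnerRef, 1)

	done := &runnerRef{id: 2}
	finishedReqCh <- done

	runner := <-finishedReqCh
	fmt.Println("request finished on runner", runner.id)
}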
@@ -274,7 +286,21 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			slog.Debug("got lock to unload", "modelPath", runner.modelPath)
 			finished := runner.waitForVRAMRecovery()
 			runner.unload()
-			delete(s.loaded, runner.modelPath)
+
+			modelPath := runner.modelPath
+			// Find the index of the runner in the slice
+			for i, r := range s.loaded[modelPath] {
+				if r == runner {
+					// Remove the runner from the slice
+					s.loaded[modelPath] = append(s.loaded[modelPath][:i], s.loaded[modelPath][i+1:]...)
+					break
+				}
+			}
+
+			// If the slice is now empty, delete the entry from the map
+			if len(s.loaded[modelPath]) == 0 {
+				delete(s.loaded, modelPath)
+			}
 			s.loadedMu.Unlock()
 			slog.Debug("runner released", "modelPath", runner.modelPath)
 			runner.refMu.Unlock()
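
Unloading a runner now means removing that specific entry from its model's slice (by pointer identity) and deleting the map key only once the slice is empty, as the hunk above does with the append-based delete idiom. A self-contained sketch of the same bookkeeping; the model path strings are placeholders:

package main

import "fmt"

type runnerRef struct{ id int }

// removeRunner deletes runner (compared by pointer identity) from
// loaded[modelPath], and drops the map entry once the slice is empty.
func removeRunner(loaded map[string][]*runnerRef, modelPath string, runner *runnerRef) {
	for i, r := range loaded[modelPath] {
		if r == runner {
			loaded[modelPath] = append(loaded[modelPath][:i], loaded[modelPath][i+1:]...)
			break
		}
	}
	if len(loaded[modelPath]) == 0 {
		delete(loaded, modelPath)
	}
}

func main() {
	a, b := &runnerRef{id: 1}, &runnerRef{id: 2}
	loaded := map[string][]*runnerRef{"models/a.gguf": {a, b}}

	removeRunner(loaded, "models/a.gguf", a)
	fmt.Println(len(loaded["models/a.gguf"])) // 1

	removeRunner(loaded, "models/a.gguf", b)
	_, ok := loaded["models/a.gguf"]
	fmt.Println(ok) // false: last runner gone, key deleted
}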
@@ -334,8 +360,16 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
 	runner.refMu.Lock()
 
 	s.loadedMu.Lock()
-	s.loaded[req.model.ModelPath] = runner
-	slog.Info("loaded runners", "count", len(s.loaded))
+	if _, ok := s.loaded[req.model.ModelPath]; !ok {
+		s.loaded[req.model.ModelPath] = make([]*runnerRef, 0)
+	}
+	s.loaded[req.model.ModelPath] = append(s.loaded[req.model.ModelPath], runner)
+
+	runnerCount := 0
+	for _, runners := range s.loaded {
+		runnerCount += len(runners)
+	}
+	slog.Info("loaded runners", "count", runnerCount)
 	s.loadedMu.Unlock()
 
 	go func() {
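
Registering a freshly loaded runner appends it to the slice for its model path, and the "loaded runners" log line now counts entries across every slice instead of taking len(s.loaded). A small sketch of that step; note that append on a missing map key sees a nil slice and allocates on demand, so the make guard in the diff is defensive rather than strictly required:

package main

import "fmt"

type runnerRef struct{}

// addRunner registers a runner for modelPath and returns the total
// number of runners loaded across all models (what the scheduler logs).
func addRunner(loaded map[string][]*runnerRef, modelPath string, r *runnerRef) int {
	// append on a nil slice allocates it, so no explicit make is needed here.
	loaded[modelPath] = append(loaded[modelPath], r)

	total := 0
	for _, runners := range loaded {
		total += len(runners)
	}
	return total
}

func main() {
	loaded := map[string][]*runnerRef{}
	fmt.Println(addRunner(loaded, "models/a.gguf", &runnerRef{})) // 1
	fmt.Println(addRunner(loaded, "models/a.gguf", &runnerRef{})) // 2: two runners, same model
	fmt.Println(addRunner(loaded, "models/b.gguf", &runnerRef{})) // 3
}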
@@ -366,26 +400,29 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
 	}
 	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
 	s.loadedMu.Lock()
-	for _, r := range s.loaded {
-		r.refMu.Lock()
-		gpuIDs := make([]string, 0, len(r.gpus))
-		if r.llama != nil {
-
-			// TODO this should be broken down by GPU instead of assuming uniform spread
-			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
-			for _, gpu := range r.gpus {
-				gpuIDs = append(gpuIDs, gpu.ID)
-			}
-			for _, gpu := range allGpus {
-				if slices.Contains(gpuIDs, gpu.ID) {
-					predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
+	for _, runners := range s.loaded {
+		for _, r := range runners {
+			r.refMu.Lock()
+			gpuIDs := make([]string, 0, len(r.gpus))
+			if r.llama != nil {
+
+				// TODO this should be broken down by GPU instead of assuming uniform spread
+				estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
+				for _, gpu := range r.gpus {
+					gpuIDs = append(gpuIDs, gpu.ID)
+				}
+				for _, gpu := range allGpus {
+					if slices.Contains(gpuIDs, gpu.ID) {
+						predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
+					}
 				}
+			} else {
+				slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
 			}
-		} else {
-			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
+			r.refMu.Unlock()
 		}
-		r.refMu.Unlock()
 	}
+
 	s.loadedMu.Unlock()
 
 	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
@@ -583,9 +620,9 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
 // findRunnerToUnload finds a runner to unload to make room for a new model
 func (s *Scheduler) findRunnerToUnload() *runnerRef {
 	s.loadedMu.Lock()
-	runnerList := make([]*runnerRef, 0, len(s.loaded))
-	for _, r := range s.loaded {
-		runnerList = append(runnerList, r)
+	runnerList := make([]*runnerRef, 0)
+	for _, runners := range s.loaded {
+		runnerList = append(runnerList, runners...)
 	}
 	s.loadedMu.Unlock()
 
@@ -611,10 +648,12 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
 func (s *Scheduler) unloadAllRunners() {
 	s.loadedMu.Lock()
 	defer s.loadedMu.Unlock()
-	for model, runner := range s.loaded {
-		if runner.llama != nil {
-			slog.Debug("shutting down runner", "model", model)
-			runner.llama.Close()
+	for model, runners := range s.loaded {
+		for _, runner := range runners {
+			if runner.llama != nil {
+				slog.Debug("shutting down runner", "model", model)
+				runner.llama.Close()
+			}
 		}
 	}
 }
