Skip to content

Commit 73a168d

Browse files
committed
vm/dispatcher: make pool.Run cancellable
Make the pool.Run() function take a context.Context to be able to abort the callback passed to it or abort its scheduling if it's not yet running. Otherwise, if the callback is not yet started and the pool's Loop is aborted, we risk waiting for pool.Run() forever. It prevents the normal shutdown of repro.Run() and, consequently, the DiffFuzzer functionality.
1 parent d971f7e commit 73a168d

File tree

5 files changed

+98
-17
lines changed

5 files changed

+98
-17
lines changed

pkg/manager/diff.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package manager
66
import (
77
"context"
88
"encoding/json"
9+
"errors"
910
"fmt"
1011
"math/rand"
1112
"net"
@@ -583,7 +584,7 @@ func (rr *reproRunner) Run(ctx context.Context, r *repro.Result) {
583584
// The third time we leave it as is in case it was important.
584585
opts.Threaded = true
585586
}
586-
pool.Run(func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
587+
runErr := pool.Run(ctx, func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
587588
var ret *instance.ExecProgInstance
588589
ret, err = instance.SetupExecProg(inst, rr.kernel.cfg, rr.kernel.reporter, nil)
589590
if err != nil {
@@ -595,6 +596,9 @@ func (rr *reproRunner) Run(ctx context.Context, r *repro.Result) {
595596
Opts: opts,
596597
})
597598
})
599+
if errors.Is(runErr, context.Canceled) {
600+
break
601+
}
598602
crashed := result != nil && result.Report != nil
599603
log.Logf(1, "attempt #%d to run %q on base: crashed=%v", i, ret.origReport.Title, crashed)
600604
if crashed {

pkg/repro/repro.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ func (pw *poolWrapper) Run(ctx context.Context, params instance.ExecParams,
767767

768768
var result *instance.RunResult
769769
var err error
770-
pw.pool.Run(func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
770+
runErr := pw.pool.Run(ctx, func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
771771
updInfo(func(info *dispatcher.Info) {
772772
typ := "syz"
773773
if params.CProg != nil {
@@ -787,6 +787,9 @@ func (pw *poolWrapper) Run(ctx context.Context, params instance.ExecParams,
787787
result, err = ret.RunSyzProg(params)
788788
}
789789
})
790+
if runErr != nil {
791+
return nil, runErr
792+
}
790793
return result, err
791794
}
792795

pkg/repro/strace.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func RunStrace(result *Result, cfg *mgrconfig.Config, reporter *report.Reporter,
3131
}
3232
var runRes *instance.RunResult
3333
var err error
34-
pool.Run(func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
34+
runErr := pool.Run(context.Background(), func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
3535
updInfo(func(info *dispatcher.Info) {
3636
info.Status = "running strace"
3737
})
@@ -58,7 +58,9 @@ func RunStrace(result *Result, cfg *mgrconfig.Config, reporter *report.Reporter,
5858
runRes, err = ret.RunSyzProg(params)
5959
}
6060
})
61-
if err != nil {
61+
if runErr != nil {
62+
return straceFailed(runErr)
63+
} else if err != nil {
6264
return straceFailed(err)
6365
}
6466
return &StraceResult{

vm/dispatcher/pool.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (p *Pool[T]) Loop(ctx context.Context) {
114114
func (p *Pool[T]) runInstance(ctx context.Context, inst *poolInstance[T]) {
115115
p.waitUnpaused()
116116
ctx, cancel := context.WithCancel(ctx)
117-
117+
defer cancel()
118118
log.Logf(2, "pool: booting instance %d", inst.idx)
119119

120120
inst.reset(cancel)
@@ -187,13 +187,24 @@ func (p *Pool[T]) ReserveForRun(count int) {
187187
}
188188

189189
// Run blocks until it has found an instance to execute job and until job has finished.
190-
func (p *Pool[T]) Run(job Runner[T]) {
191-
done := make(chan struct{})
192-
p.jobs <- func(ctx context.Context, inst T, upd UpdateInfo) {
193-
job(ctx, inst, upd)
194-
close(done)
190+
// Returns an error if the job was aborted by cancelling the context.
191+
func (p *Pool[T]) Run(ctx context.Context, job Runner[T]) error {
192+
done := make(chan error)
193+
// Submit the job.
194+
select {
195+
case p.jobs <- func(jobCtx context.Context, inst T, upd UpdateInfo) {
196+
mergedCtx, cancel := mergeContextCancel(jobCtx, ctx)
197+
defer cancel()
198+
199+
job(mergedCtx, inst, upd)
200+
done <- mergedCtx.Err()
201+
}:
202+
case <-ctx.Done():
203+
// If the loop is aborted, no one is going to pick up the job.
204+
return ctx.Err()
195205
}
196-
<-done
206+
// Await the job.
207+
return <-done
197208
}
198209

199210
func (p *Pool[T]) Total() int {
@@ -311,3 +322,15 @@ func (pi *poolInstance[T]) free(job Runner[T]) {
311322
default:
312323
}
313324
}
325+
326+
func mergeContextCancel(main, monitor context.Context) (context.Context, func()) {
327+
withCancel, cancel := context.WithCancel(main)
328+
go func() {
329+
select {
330+
case <-withCancel.Done():
331+
case <-monitor.Done():
332+
}
333+
cancel()
334+
}()
335+
return withCancel, cancel
336+
}

vm/dispatcher/pool_test.go

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package dispatcher
66
import (
77
"context"
88
"runtime"
9+
"sync"
910
"sync/atomic"
1011
"testing"
1112
"time"
@@ -87,7 +88,7 @@ func TestPoolSplit(t *testing.T) {
8788
case <-stopRuns:
8889
}
8990
}
90-
go mgr.Run(job)
91+
go mgr.Run(ctx, job)
9192

9293
// So far, there are no reserved instances.
9394
for i := 0; i < count; i++ {
@@ -113,7 +114,7 @@ func TestPoolSplit(t *testing.T) {
113114

114115
// Now let's create and finish more jobs.
115116
for i := 0; i < 10; i++ {
116-
go mgr.Run(job)
117+
go mgr.Run(ctx, job)
117118
}
118119
mgr.ReserveForRun(2)
119120
for i := 0; i < 10; i++ {
@@ -150,8 +151,7 @@ func TestPoolStress(t *testing.T) {
150151
}
151152
}()
152153
for i := 0; i < 128; i++ {
153-
go mgr.Run(func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {
154-
})
154+
go mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {})
155155
mgr.ReserveForRun(5 + i%5)
156156
}
157157

@@ -221,7 +221,7 @@ func TestPoolPause(t *testing.T) {
221221
}()
222222

223223
run := make(chan bool, 1)
224-
go mgr.Run(func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {
224+
go mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {
225225
run <- true
226226
})
227227
time.Sleep(10 * time.Millisecond)
@@ -231,12 +231,61 @@ func TestPoolPause(t *testing.T) {
231231
mgr.TogglePause(false)
232232
<-run
233233

234-
mgr.Run(func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {})
234+
mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {})
235235

236236
cancel()
237237
<-done
238238
}
239239

240+
func TestPoolCancelRun(t *testing.T) {
241+
// The test to aid the race detector.
242+
mgr := NewPool[*nilInstance](
243+
10,
244+
func(idx int) (*nilInstance, error) {
245+
return &nilInstance{}, nil
246+
},
247+
func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {
248+
<-ctx.Done()
249+
},
250+
)
251+
var wg sync.WaitGroup
252+
wg.Add(1)
253+
ctx, cancel := context.WithCancel(context.Background())
254+
go func() {
255+
mgr.Loop(ctx)
256+
wg.Done()
257+
}()
258+
259+
mgr.ReserveForRun(2)
260+
261+
started := make(chan struct{})
262+
// Schedule more jobs than could be processed simultaneously.
263+
for i := 0; i < 15; i++ {
264+
wg.Add(1)
265+
go func() {
266+
defer wg.Done()
267+
mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {
268+
select {
269+
case <-ctx.Done():
270+
return
271+
case started <- struct{}{}:
272+
}
273+
<-ctx.Done()
274+
})
275+
}()
276+
}
277+
278+
// Two can be started.
279+
<-started
280+
<-started
281+
282+
// Now stop the loop and the jbos.
283+
cancel()
284+
285+
// Everything must really stop.
286+
wg.Wait()
287+
}
288+
240289
func makePool(count int) []testInstance {
241290
var ret []testInstance
242291
for i := 0; i < count; i++ {

0 commit comments

Comments
 (0)