Skip to content

Commit add1ac4

Browse files
committed
save
1 parent 1c84d75 commit add1ac4

File tree

1 file changed

+36
-53
lines changed

1 file changed

+36
-53
lines changed

db/seg/seg_paged_rw.go

Lines changed: 36 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,6 @@ func putPageResult(r *pageResult) {
5656
pageResultPool.Put(r)
5757
}
5858

59-
func drainWorkCh(ch chan *pageWorkItem) {
60-
for {
61-
select {
62-
case item := <-ch:
63-
putPageWorkItem(item)
64-
default:
65-
return
66-
}
67-
}
68-
}
69-
7059
func drainResultCh(ch chan *pageResult) {
7160
for {
7261
select {
@@ -336,7 +325,9 @@ type PagedWriter struct {
336325
numWorkers int
337326
workCh chan *pageWorkItem
338327
resultCh chan *pageResult
339-
eg *errgroup.Group // tracks live worker goroutines; cancels all on first error
328+
eg *errgroup.Group // tracks workers + reducer; cancels all on first error
329+
egCtx context.Context // cancelled on first worker/reducer error or on Flush exit
330+
egCancel context.CancelFunc // cancels egCtx to unblock goroutines on shutdown
340331
seqIn int // next seq to assign to work item
341332
seqOut int // next seq to write to parent
342333
workersShutdown bool // tracks if workers have been shut down
@@ -352,11 +343,20 @@ func (c *PagedWriter) initWorkers() {
352343
c.workCh = make(chan *pageWorkItem, queueDepth)
353344
c.resultCh = make(chan *pageResult, queueDepth)
354345
c.pendingResults = make(map[int]*pageResult, queueDepth)
355-
var egCtx context.Context
356-
c.eg, egCtx = errgroup.WithContext(c.ctx)
346+
cancelCtx, cancel := context.WithCancel(c.ctx)
347+
c.egCancel = cancel
348+
c.eg, c.egCtx = errgroup.WithContext(cancelCtx)
349+
350+
var workerWg sync.WaitGroup
351+
workerWg.Add(c.numWorkers)
357352
for range c.numWorkers {
358-
c.eg.Go(func() error { return c.compressionWorker(egCtx) })
353+
c.eg.Go(func() error {
354+
defer workerWg.Done()
355+
return c.compressionWorker(c.egCtx)
356+
})
359357
}
358+
go func() { workerWg.Wait(); close(c.resultCh) }()
359+
c.eg.Go(c.reducer)
360360
}
361361

362362
func (c *PagedWriter) compressionWorker(ctx context.Context) error {
@@ -391,6 +391,18 @@ func (c *PagedWriter) compressionWorker(ctx context.Context) error {
391391
}
392392
}
393393

394+
func (c *PagedWriter) reducer() error {
395+
for r := range c.resultCh {
396+
c.pendingResults[r.seq] = r
397+
if err := c.writeInOrder(); err != nil {
398+
drainResultCh(c.resultCh)
399+
drainPendingResults(c.pendingResults)
400+
return err
401+
}
402+
}
403+
return nil
404+
}
405+
394406
func (c *PagedWriter) writeInOrder() error {
395407
for {
396408
r, ok := c.pendingResults[c.seqOut]
@@ -482,26 +494,13 @@ func (c *PagedWriter) writePage() error {
482494
item.seq = c.seqIn
483495
c.seqIn++
484496

485-
// Send to workers; if workCh is full, drain resultCh to unblock a worker first
486-
for {
487-
select {
488-
case c.workCh <- item:
489-
sent = true
490-
// Non-blocking drain of any ready results
491-
for {
492-
select {
493-
case r := <-c.resultCh:
494-
c.pendingResults[r.seq] = r
495-
default:
496-
return c.writeInOrder()
497-
}
498-
}
499-
case r := <-c.resultCh:
500-
c.pendingResults[r.seq] = r
501-
case <-c.ctx.Done():
502-
return c.ctx.Err()
503-
}
497+
select {
498+
case c.workCh <- item:
499+
sent = true
500+
case <-c.egCtx.Done():
501+
return c.egCtx.Err()
504502
}
503+
return nil
505504
}
506505

507506
func (c *PagedWriter) Add(k, v []byte) (err error) {
@@ -538,32 +537,16 @@ func (c *PagedWriter) Flush() (err error) {
538537
c.resetPage()
539538
return nil
540539
}
541-
// Signal workers to exit, then drain all pending results in order
540+
// Signal workers to stop; reducer drains resultCh and writes in order
542541
if !c.workersShutdown {
543542
close(c.workCh)
544543
c.workersShutdown = true
545544
}
546545
defer func() {
547-
c.eg.Wait() //nolint:errcheck // ensure all worker goroutines have exited, even on error
548-
if err != nil {
549-
drainWorkCh(c.workCh)
550-
drainResultCh(c.resultCh)
551-
drainPendingResults(c.pendingResults)
552-
}
546+
c.egCancel()
547+
c.eg.Wait() //nolint:errcheck
553548
c.resetPage()
554549
}()
555-
556-
for c.seqOut < c.seqIn {
557-
select {
558-
case r := <-c.resultCh:
559-
c.pendingResults[r.seq] = r
560-
if err = c.writeInOrder(); err != nil {
561-
return
562-
}
563-
case <-c.ctx.Done():
564-
return c.eg.Wait()
565-
}
566-
}
567550
return c.eg.Wait()
568551
}
569552

0 commit comments

Comments
 (0)