Skip to content

Commit 47a6dcd

Browse files
authored
Wait for worker query close before session reuse (#751)
* Wait for worker query close before session reuse * Trim worker close wait path * Reduce close wait test setup duplication * Cover cancel reuse close wait edges * Fix cancel reuse e2e client * Release worker query handle when DoGet fails * Cancel DoGet before schema-error wait
1 parent 57f5332 commit 47a6dcd

9 files changed

Lines changed: 877 additions & 33 deletions

File tree

duckdbservice/flight_handler.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,58 @@ func (h *FlightSQLHandler) doHealthCheck(body []byte, stream flight.FlightServic
441441
return sendActionResult(stream, &flight.Result{Body: resp})
442442
}
443443

444+
func (h *FlightSQLHandler) doWaitSessionIdle(body []byte, stream flight.FlightService_DoActionServer) error {
445+
var req server.WorkerWaitSessionIdlePayload
446+
if err := json.Unmarshal(body, &req); err != nil {
447+
return status.Errorf(codes.InvalidArgument, "invalid WaitSessionIdle request: %v", err)
448+
}
449+
if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil {
450+
return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err)
451+
}
452+
session, err := h.sessionFromContext(stream.Context())
453+
if err != nil {
454+
return err
455+
}
456+
if err := session.waitOperationIdle(stream.Context()); err != nil {
457+
switch {
458+
case errors.Is(err, context.Canceled):
459+
return status.Errorf(codes.Canceled, "wait for session idle: %v", err)
460+
case errors.Is(err, context.DeadlineExceeded):
461+
return status.Errorf(codes.DeadlineExceeded, "wait for session idle: %v", err)
462+
default:
463+
return status.Errorf(codes.Internal, "wait for session idle: %v", err)
464+
}
465+
}
466+
467+
resp, _ := json.Marshal(map[string]bool{"ok": true})
468+
return sendActionResult(stream, &flight.Result{Body: resp})
469+
}
470+
471+
func (h *FlightSQLHandler) doReleaseQueryHandle(body []byte, stream flight.FlightService_DoActionServer) error {
472+
var req server.WorkerReleaseQueryHandlePayload
473+
if err := json.Unmarshal(body, &req); err != nil {
474+
return status.Errorf(codes.InvalidArgument, "invalid ReleaseQueryHandle request: %v", err)
475+
}
476+
if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil {
477+
return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err)
478+
}
479+
session, err := h.sessionFromContext(stream.Context())
480+
if err != nil {
481+
return err
482+
}
483+
ticket, err := flightsql.GetStatementQueryTicket(&flight.Ticket{Ticket: req.Ticket})
484+
if err != nil {
485+
return status.Errorf(codes.InvalidArgument, "invalid statement ticket: %v", err)
486+
}
487+
handleID := string(ticket.GetStatementHandle())
488+
if handle, ok := popQueryHandle(session, handleID); ok {
489+
releaseQueryHandleValue(handle)
490+
}
491+
492+
resp, _ := json.Marshal(map[string]bool{"ok": true})
493+
return sendActionResult(stream, &flight.Result{Body: resp})
494+
}
495+
444496
// Flight SQL method implementations
445497

446498
func (h *FlightSQLHandler) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery,

duckdbservice/flight_handler_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,124 @@ func TestCreateSessionRejectsWhileDraining(t *testing.T) {
473473
}
474474
}
475475

476+
func TestWaitSessionIdleBlocksUntilOperationReleases(t *testing.T) {
477+
session := &Session{
478+
ID: "session-1",
479+
Username: "alice",
480+
CreatedAt: time.Now(),
481+
queries: make(map[string]*QueryHandle),
482+
txns: make(map[string]*trackedTx),
483+
txnOwner: make(map[string]string),
484+
}
485+
pool := &SessionPool{
486+
sessions: map[string]*Session{session.ID: session},
487+
stopRefresh: make(map[string]func()),
488+
warmupDone: make(chan struct{}),
489+
startTime: time.Now(),
490+
}
491+
close(pool.warmupDone)
492+
handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator}
493+
494+
finishOperation, ok := session.beginOperation()
495+
if !ok {
496+
t.Fatal("beginOperation rejected test session")
497+
}
498+
499+
body, err := json.Marshal(server.WorkerWaitSessionIdlePayload{})
500+
if err != nil {
501+
t.Fatalf("marshal request: %v", err)
502+
}
503+
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID))
504+
stream := &mockDoActionStream{ctx: ctx}
505+
done := make(chan error, 1)
506+
go func() {
507+
done <- handler.doWaitSessionIdle(body, stream)
508+
}()
509+
510+
select {
511+
case err := <-done:
512+
t.Fatalf("WaitSessionIdle returned before operation released: %v", err)
513+
case <-time.After(50 * time.Millisecond):
514+
}
515+
516+
finishOperation()
517+
select {
518+
case err := <-done:
519+
if err != nil {
520+
t.Fatalf("WaitSessionIdle returned error: %v", err)
521+
}
522+
case <-time.After(time.Second):
523+
t.Fatal("WaitSessionIdle did not return after operation released")
524+
}
525+
if len(stream.results) != 1 {
526+
t.Fatalf("expected one action result, got %d", len(stream.results))
527+
}
528+
}
529+
530+
func TestReleaseQueryHandleReleasesAbandonedOperation(t *testing.T) {
531+
pool := &SessionPool{
532+
sessions: make(map[string]*Session),
533+
stopRefresh: make(map[string]func()),
534+
warmupDone: make(chan struct{}),
535+
startTime: time.Now(),
536+
}
537+
close(pool.warmupDone)
538+
session := &Session{
539+
ID: "session-1",
540+
Username: "alice",
541+
CreatedAt: time.Now(),
542+
queries: make(map[string]*QueryHandle),
543+
txns: make(map[string]*trackedTx),
544+
txnOwner: make(map[string]string),
545+
}
546+
pool.sessions[session.ID] = session
547+
handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator}
548+
549+
finishOperation, ok := session.beginOperation()
550+
if !ok {
551+
t.Fatal("beginOperation rejected test session")
552+
}
553+
finishDrain, err := pool.beginDrainWork(false)
554+
if err != nil {
555+
t.Fatalf("beginDrainWork: %v", err)
556+
}
557+
session.queries["query-1"] = &QueryHandle{
558+
Query: "SELECT 1",
559+
createdAt: time.Now(),
560+
finishDrain: finishDrain,
561+
finishOperation: finishOperation,
562+
}
563+
564+
ticket, err := flightsql.CreateStatementQueryTicket([]byte("query-1"))
565+
if err != nil {
566+
t.Fatalf("create ticket: %v", err)
567+
}
568+
body, err := json.Marshal(server.WorkerReleaseQueryHandlePayload{Ticket: ticket})
569+
if err != nil {
570+
t.Fatalf("marshal request: %v", err)
571+
}
572+
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID))
573+
stream := &mockDoActionStream{ctx: ctx}
574+
575+
if err := handler.doReleaseQueryHandle(body, stream); err != nil {
576+
t.Fatalf("ReleaseQueryHandle: %v", err)
577+
}
578+
if _, ok := session.queries["query-1"]; ok {
579+
t.Fatal("query handle was not removed")
580+
}
581+
if got := pool.ActiveDrainWork(); got != 0 {
582+
t.Fatalf("active drain work=%d, want 0", got)
583+
}
584+
finishOperation2, ok := session.beginOperation()
585+
if !ok {
586+
t.Fatal("operation gate was not released")
587+
}
588+
finishOperation2()
589+
if len(stream.results) != 1 {
590+
t.Fatalf("expected one action result, got %d", len(stream.results))
591+
}
592+
}
593+
476594
func TestCreateSessionSendFailureDestroysSession(t *testing.T) {
477595
pool := &SessionPool{
478596
sessions: make(map[string]*Session),

duckdbservice/service.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ type Session struct {
123123
txnOwner map[string]string
124124
closed bool
125125
operationOpen bool
126+
operationIdle chan struct{}
126127
connWork int
127128
connWorkDone *sync.Cond
128129
handleCounter atomic.Uint64
@@ -234,13 +235,18 @@ func (s *Session) beginOperation() (func(), bool) {
234235
return nil, false
235236
}
236237
s.operationOpen = true
238+
s.operationIdle = make(chan struct{})
237239
s.mu.Unlock()
238240

239241
var once sync.Once
240242
return func() {
241243
once.Do(func() {
242244
s.mu.Lock()
243245
s.operationOpen = false
246+
if s.operationIdle != nil {
247+
close(s.operationIdle)
248+
s.operationIdle = nil
249+
}
244250
s.mu.Unlock()
245251
})
246252
}, true
@@ -257,18 +263,45 @@ func (s *Session) beginOperationForTransaction(txnKey string) (func(), bool, boo
257263
return nil, true, false
258264
}
259265
s.operationOpen = true
266+
s.operationIdle = make(chan struct{})
260267
s.mu.Unlock()
261268

262269
var once sync.Once
263270
return func() {
264271
once.Do(func() {
265272
s.mu.Lock()
266273
s.operationOpen = false
274+
if s.operationIdle != nil {
275+
close(s.operationIdle)
276+
s.operationIdle = nil
277+
}
267278
s.mu.Unlock()
268279
})
269280
}, true, true
270281
}
271282

283+
func (s *Session) waitOperationIdle(ctx context.Context) error {
284+
for {
285+
s.mu.RLock()
286+
if !s.operationOpen {
287+
s.mu.RUnlock()
288+
return nil
289+
}
290+
idle := s.operationIdle
291+
s.mu.RUnlock()
292+
293+
if idle == nil {
294+
return errors.New("session operation idle signal missing")
295+
}
296+
select {
297+
case <-idle:
298+
return nil
299+
case <-ctx.Done():
300+
return ctx.Err()
301+
}
302+
}
303+
}
304+
272305
// beginConnWork fences any operation that uses the session connection while a
273306
// raw SQL transaction may be open. It intentionally does not mutate queryActive:
274307
// conn work includes COPY receive and metadata/planning work, while queryActive
@@ -1651,6 +1684,10 @@ func (s *customActionServer) DoAction(cmd *flight.Action, stream flight.FlightSe
16511684
return s.handler.doDestroySession(cmd.Body, stream)
16521685
case "HealthCheck":
16531686
return s.handler.doHealthCheck(cmd.Body, stream)
1687+
case "WaitSessionIdle":
1688+
return s.handler.doWaitSessionIdle(cmd.Body, stream)
1689+
case "ReleaseQueryHandle":
1690+
return s.handler.doReleaseQueryHandle(cmd.Body, stream)
16541691
default:
16551692
// Fall through to standard flightsql action router (BeginTransaction, etc.)
16561693
return s.FlightServer.DoAction(cmd, stream)

0 commit comments

Comments
 (0)