Skip to content

Commit 99821d3

Browse files
committed
Recheck takeover liveness and classify reconnect failures
1 parent dff7c3c commit 99821d3

5 files changed

Lines changed: 138 additions & 0 deletions

File tree

controlplane/flight_ingress.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package controlplane
33
import (
44
"context"
55
"crypto/tls"
6+
"errors"
67
"fmt"
78
"log/slog"
89
"sync"
@@ -136,6 +137,9 @@ func (p *orgRoutedSessionProvider) ReconnectSession(ctx context.Context, record
136137

137138
pid, executor, err := sessions.ReconnectFlightSession(ctx, record.Username, record.WorkerID, record.OwnerEpoch)
138139
if err != nil {
140+
if errors.Is(err, configstore.ErrWorkerOwnerEpochMismatch) {
141+
return 0, nil, flightsqlingress.MarkDurableReconnectTerminal(err)
142+
}
139143
return 0, nil, err
140144
}
141145

controlplane/k8s_pool.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,11 @@ func (p *K8sWorkerPool) reserveClaimedWorker(ctx context.Context, claimed *confi
11801180
if reservedRecord != nil {
11811181
p.persistWorkerRecord(reservedRecord)
11821182
}
1183+
if err := p.checkReservedWorkerLiveness(ctx, worker); err != nil {
1184+
slog.Warn("Claimed worker failed liveness recheck.", "worker", worker.ID, "pod", worker.PodName(), "error", err)
1185+
p.retireWorkerWithReason(worker.ID, RetireReasonCrash)
1186+
return nil, err
1187+
}
11831188
return worker, nil
11841189
}
11851190

controlplane/k8s_pool_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,11 @@ func TestK8sPoolClaimSpecificWorkerTakesOverRuntimeWorker(t *testing.T) {
929929
pool.runtimeStore = store
930930
worker := &ManagedWorker{ID: 44, done: make(chan struct{})}
931931
pool.workers[worker.ID] = worker
932+
livenessChecked := false
933+
pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error {
934+
livenessChecked = true
935+
return nil
936+
}
932937

933938
claimed, err := pool.claimSpecificWorker(context.Background(), 44, 7, &WorkerAssignment{
934939
OrgID: "analytics",
@@ -966,6 +971,9 @@ func TestK8sPoolClaimSpecificWorkerTakesOverRuntimeWorker(t *testing.T) {
966971
if state.Assignment == nil || state.Assignment.OrgID != "analytics" {
967972
t.Fatalf("expected analytics assignment, got %#v", state.Assignment)
968973
}
974+
if !livenessChecked {
975+
t.Fatal("expected claimSpecificWorker to recheck worker liveness")
976+
}
969977
}
970978

971979
func TestK8sPoolClaimSpecificWorkerReturnsEpochMismatchError(t *testing.T) {
@@ -992,6 +1000,42 @@ func TestK8sPoolClaimSpecificWorkerReturnsEpochMismatchError(t *testing.T) {
9921000
}
9931001
}
9941002

1003+
func TestK8sPoolClaimSpecificWorkerRetiresUnhealthyWorker(t *testing.T) {
1004+
pool, _ := newTestK8sPool(t, 5)
1005+
leaseExpiry := time.Date(2026, time.March, 20, 19, 30, 0, 0, time.UTC)
1006+
store := &captureRuntimeWorkerStore{
1007+
takenOver: &configstore.WorkerRecord{
1008+
WorkerID: 44,
1009+
PodName: "duckgres-worker-test-cp-44",
1010+
State: configstore.WorkerStateReserved,
1011+
OrgID: "analytics",
1012+
OwnerCPInstanceID: pool.cpInstanceID,
1013+
OwnerEpoch: 8,
1014+
LeaseExpiresAt: leaseExpiry,
1015+
},
1016+
}
1017+
pool.runtimeStore = store
1018+
pool.workers[44] = &ManagedWorker{ID: 44, done: make(chan struct{})}
1019+
pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error {
1020+
return errors.New("dead worker")
1021+
}
1022+
1023+
claimed, err := pool.claimSpecificWorker(context.Background(), 44, 7, &WorkerAssignment{
1024+
OrgID: "analytics",
1025+
LeaseExpiresAt: leaseExpiry,
1026+
MaxWorkers: 3,
1027+
})
1028+
if err == nil {
1029+
t.Fatal("expected unhealthy claimed worker to fail liveness recheck")
1030+
}
1031+
if claimed != nil {
1032+
t.Fatalf("expected no claimed worker, got %#v", claimed)
1033+
}
1034+
if _, ok := pool.Worker(44); ok {
1035+
t.Fatal("expected unhealthy worker to be retired from the pool")
1036+
}
1037+
}
1038+
9951039
func TestK8sPoolReserveSharedWorkerCreatesRuntimeSpawningSlotWhenPoolIsCold(t *testing.T) {
9961040
pool, _ := newTestK8sPool(t, 5)
9971041
leaseExpiry := time.Date(2026, time.March, 20, 20, 0, 0, 0, time.UTC)

server/flightsqlingress/ingress.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"database/sql"
88
"encoding/base64"
99
"encoding/hex"
10+
"errors"
1011
"fmt"
1112
"log/slog"
1213
"net"
@@ -39,6 +40,15 @@ const (
3940
defaultFlightSessionHeaderKey = "x-duckgres-session"
4041
)
4142

43+
// ErrDurableReconnectTerminal is the sentinel carried by reconnect failures that
// have been classified as terminal. Detect it with errors.Is; in this file,
// reconnectByToken closes the durable session when a reconnect error matches it.
var ErrDurableReconnectTerminal = errors.New("durable reconnect terminal")

// MarkDurableReconnectTerminal tags err as a terminal durable-reconnect failure.
//
// The returned error matches ErrDurableReconnectTerminal under errors.Is while
// keeping err itself reachable through errors.Is and errors.As. A nil err yields
// nil. An err that already matches the sentinel is returned unchanged, so marking
// the same failure at more than one layer does not stack repeated
// "durable reconnect terminal: " prefixes onto the message.
func MarkDurableReconnectTerminal(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, ErrDurableReconnectTerminal):
		// Already marked: wrapping again would only duplicate the prefix.
		return err
	}
	return fmt.Errorf("%w: %w", ErrDurableReconnectTerminal, err)
}
51+
4252
const (
4353
ReapTriggerPeriodic = "periodic"
4454
ReapTriggerForced = "forced"
@@ -1735,6 +1745,9 @@ func (s *flightAuthSessionStore) reconnectByToken(ctx context.Context, token str
17351745
pid, executor, err := s.reconnector.ReconnectSession(ctx, *record)
17361746
if err != nil {
17371747
slog.Warn("Reconnecting durable Flight session failed.", "token", token, "error", err)
1748+
if errors.Is(err, ErrDurableReconnectTerminal) {
1749+
_ = s.durableStore.CloseSession(token, time.Now())
1750+
}
17381751
return nil, false
17391752
}
17401753

server/flightsqlingress/ingress_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,78 @@ func TestFlightAuthSessionStoreReconnectRefreshesDurableSessionMetadata(t *testi
956956
}
957957
}
958958

959+
func TestFlightAuthSessionStoreReconnectFailureUpdatesDurableSessionState(t *testing.T) {
960+
tests := []struct {
961+
name string
962+
reconnectErr error
963+
wantState DurableSessionState
964+
wantReconnectCall int
965+
}{
966+
{
967+
name: "terminal stale ownership closes durable session",
968+
reconnectErr: MarkDurableReconnectTerminal(errors.New("stale owner")),
969+
wantState: DurableSessionStateClosed,
970+
wantReconnectCall: 1,
971+
},
972+
{
973+
name: "transient reconnect failure leaves durable session active",
974+
reconnectErr: context.DeadlineExceeded,
975+
wantState: DurableSessionStateActive,
976+
wantReconnectCall: 2,
977+
},
978+
}
979+
980+
for _, tt := range tests {
981+
t.Run(tt.name, func(t *testing.T) {
982+
durable := &captureDurableSessionStore{
983+
records: map[string]DurableSessionRecord{
984+
"durable-token": {
985+
SessionToken: "durable-token",
986+
Username: "postgres",
987+
OrgID: "analytics",
988+
WorkerID: 17,
989+
OwnerEpoch: 4,
990+
CPInstanceID: "cp-old:boot-a",
991+
State: DurableSessionStateActive,
992+
ExpiresAt: time.Now().Add(time.Hour),
993+
LastSeenAt: time.Now().Add(-time.Minute),
994+
},
995+
},
996+
}
997+
reconnectCalls := 0
998+
provider := &testDurableSessionProvider{
999+
durableStore: durable,
1000+
reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) {
1001+
reconnectCalls++
1002+
return 0, nil, tt.reconnectErr
1003+
},
1004+
}
1005+
store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Minute, time.Hour, 0, Options{})
1006+
1007+
if session, ok := store.GetByTokenContext(context.Background(), "durable-token"); ok || session != nil {
1008+
t.Fatal("expected durable token reconnect to fail")
1009+
}
1010+
record, err := durable.GetSession("durable-token")
1011+
if err != nil {
1012+
t.Fatalf("GetSession: %v", err)
1013+
}
1014+
if record == nil {
1015+
t.Fatal("expected durable session record to remain present")
1016+
}
1017+
if record.State != tt.wantState {
1018+
t.Fatalf("expected durable session state %q, got %q", tt.wantState, record.State)
1019+
}
1020+
1021+
if session, ok := store.GetByTokenContext(context.Background(), "durable-token"); ok || session != nil {
1022+
t.Fatal("expected second durable token lookup to fail")
1023+
}
1024+
if reconnectCalls != tt.wantReconnectCall {
1025+
t.Fatalf("expected %d reconnect attempts, got %d", tt.wantReconnectCall, reconnectCalls)
1026+
}
1027+
})
1028+
}
1029+
}
1030+
9591031
func TestFlightAuthSessionStoreRejectsNewSessionsWhileDraining(t *testing.T) {
9601032
provider := &testDurableSessionProvider{
9611033
createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {

0 commit comments

Comments
 (0)