Skip to content

Commit ddb2668

Browse files
committed
server: address disconnect monitor review comments
1 parent f16a5a6 commit ddb2668

4 files changed

Lines changed: 141 additions & 83 deletions

File tree

pkg/server/conn.go

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2090,27 +2090,37 @@ func setResourceGroupTaggerForMultiStmtPrefetch(snapshot kv.Snapshot, sqls strin
20902090
}
20912091
}
20922092

2093+
// setSQLKillerConnectionAlive installs a connection-liveness probe on the
2094+
// session SQLKiller and starts a background monitor for the current statement.
2095+
// The returned cleanup is idempotent and must be called when the statement is
2096+
// done to stop the monitor and clear the probe.
20932097
func (cc *clientConn) setSQLKillerConnectionAlive() func() {
20942098
fn := func() bool {
20952099
if cc.bufReadConn != nil {
2100+
// IsAlive returns 0 only when the connection is known dead. Treat
2101+
// unknown states as alive so that a query is not interrupted just because
2102+
// the liveness check itself could not run.
20962103
return cc.bufReadConn.IsAlive() != 0
20972104
}
20982105
return true
20992106
}
21002107
cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
21012108
stopMonitor := make(chan struct{})
2102-
go cc.monitorConnectionAlive(fn, stopMonitor)
2109+
doneMonitor := make(chan struct{})
2110+
go cc.monitorConnectionAlive(fn, stopMonitor, doneMonitor)
21032111

21042112
var clearOnce sync.Once
21052113
return func() {
21062114
clearOnce.Do(func() {
21072115
close(stopMonitor)
2116+
<-doneMonitor
21082117
cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)
21092118
})
21102119
}
21112120
}
21122121

2113-
func (cc *clientConn) monitorConnectionAlive(isAlive func() bool, stop <-chan struct{}) {
2122+
func (cc *clientConn) monitorConnectionAlive(isAlive func() bool, stop <-chan struct{}, done chan<- struct{}) {
2123+
defer close(done)
21142124
checkInterval := time.Second
21152125
failpoint.Inject("mockConnectionAliveMonitorInterval", func(val failpoint.Value) {
21162126
if interval, ok := val.(int); ok {
@@ -2123,6 +2133,11 @@ func (cc *clientConn) monitorConnectionAlive(isAlive func() bool, stop <-chan st
21232133
select {
21242134
case <-ticker.C:
21252135
if !isAlive() {
2136+
select {
2137+
case <-stop:
2138+
return
2139+
default:
2140+
}
21262141
cc.ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
21272142
cc.cancelDispatch()
21282143
return
@@ -2187,6 +2202,7 @@ func (cc *clientConn) handleStmt(
21872202
monitoringConnectionAlive := shouldMonitorConnectionAliveDuringExecute(stmt, cc.ctx.GetSessionVars())
21882203
if monitoringConnectionAlive {
21892204
clearConnectionAlive = cc.setSQLKillerConnectionAlive()
2205+
defer clearConnectionAlive()
21902206
}
21912207
rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
21922208
if rs == nil || err != nil {
@@ -2224,6 +2240,7 @@ func (cc *clientConn) handleStmt(
22242240
}
22252241
if !monitoringConnectionAlive {
22262242
clearConnectionAlive = cc.setSQLKillerConnectionAlive()
2243+
defer clearConnectionAlive()
22272244
}
22282245
cc.ctx.GetSessionVars().SQLKiller.SetFinishFunc(
22292246
func() {
@@ -2233,7 +2250,6 @@ func (cc *clientConn) handleStmt(
22332250
cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true)
22342251
defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false)
22352252
defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc()
2236-
defer clearConnectionAlive()
22372253
if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil {
22382254
return retryable, err
22392255
}

pkg/server/conn_stmt.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -316,6 +316,7 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
316316
monitoringConnectionAlive = shouldMonitorConnectionAliveDuringExecute(planCacheStmt.PreparedAst.Stmt, vars)
317317
if monitoringConnectionAlive {
318318
clearConnectionAlive = cc.setSQLKillerConnectionAlive()
319+
defer clearConnectionAlive()
319320
}
320321
}
321322
rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt)
@@ -326,8 +327,8 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
326327
if rs != nil {
327328
if !monitoringConnectionAlive {
328329
clearConnectionAlive = cc.setSQLKillerConnectionAlive()
330+
defer clearConnectionAlive()
329331
}
330-
defer clearConnectionAlive()
331332
defer func() {
332333
if !lazy {
333334
rs.Close()

pkg/server/tests/commontest/tidb_test.go

Lines changed: 118 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -3508,99 +3508,100 @@ func TestClientDisconnectKillsAutocommitInsert(t *testing.T) {
35083508
ts := servertestkit.CreateTidbTestSuite(t)
35093509
enableFastConnectionAliveMonitor(t)
35103510

3511-
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
3512-
dbt.MustExec("drop table if exists issue57531_insert")
3513-
dbt.MustExec("create table issue57531_insert (a int primary key, b int)")
3514-
3515-
conn, err := dbt.GetDB().Conn(context.Background())
3516-
require.NoError(t, err)
3517-
defer func() {
3518-
_ = conn.Close()
3519-
}()
3520-
netConn := getRawNetConn(t, conn)
3521-
3522-
done := make(chan error, 1)
3523-
go func() {
3524-
_, err := conn.ExecContext(context.Background(), "insert into issue57531_insert values (1, sleep(300))")
3525-
done <- err
3526-
}()
3527-
3528-
require.Eventually(t, func() bool {
3529-
return processlistCountByInfo(t, dbt, "insert into issue57531_insert%") == 1
3530-
}, 5*time.Second, 50*time.Millisecond)
3531-
3532-
require.NoError(t, netConn.Close())
3533-
3534-
var execErr error
3535-
require.Eventually(t, func() bool {
3536-
select {
3537-
case execErr = <-done:
3538-
return true
3539-
default:
3540-
return false
3541-
}
3542-
}, 5*time.Second, 50*time.Millisecond)
3543-
require.Error(t, execErr)
3544-
3545-
require.Eventually(t, func() bool {
3546-
return processlistCountByInfo(t, dbt, "insert into issue57531_insert%") == 0
3547-
}, 5*time.Second, 50*time.Millisecond)
3548-
3549-
var cnt int
3550-
err = dbt.GetDB().QueryRowContext(context.Background(), "select count(*) from issue57531_insert").Scan(&cnt)
3551-
require.NoError(t, err)
3552-
require.Equal(t, 0, cnt)
3553-
})
3511+
for _, prepared := range []bool{false, true} {
3512+
name := "query"
3513+
if prepared {
3514+
name = "prepared"
3515+
}
3516+
t.Run(name, func(t *testing.T) {
3517+
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
3518+
tableName := "issue57531_insert_" + name
3519+
dbt.MustExec("drop table if exists " + tableName)
3520+
dbt.MustExec("create table " + tableName + " (a int primary key, b int)")
3521+
runClientDisconnectAutocommitInsert(t, dbt, tableName, fmt.Sprintf("insert into %s values (1, sleep(300))", tableName), prepared)
3522+
})
3523+
})
3524+
}
35543525
}
35553526

35563527
func TestClientDisconnectCancelsAutocommitInsertPrewrite(t *testing.T) {
35573528
ts := servertestkit.CreateTidbTestSuite(t)
35583529
enableFastConnectionAliveMonitor(t)
35593530

3560-
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
3561-
dbt.MustExec("drop table if exists issue57531_prewrite")
3562-
dbt.MustExec("create table issue57531_prewrite (a int primary key, b int)")
3563-
testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/store/mockstore/unistore/rpcPrewriteResult", `return("notLeader")`)
3531+
for _, prepared := range []bool{false, true} {
3532+
name := "query"
3533+
if prepared {
3534+
name = "prepared"
3535+
}
3536+
t.Run(name, func(t *testing.T) {
3537+
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
3538+
tableName := "issue57531_prewrite_" + name
3539+
dbt.MustExec("drop table if exists " + tableName)
3540+
dbt.MustExec("create table " + tableName + " (a int primary key, b int)")
3541+
testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/store/mockstore/unistore/rpcPrewriteResult", `return("notLeader")`)
3542+
runClientDisconnectAutocommitInsert(t, dbt, tableName, fmt.Sprintf("insert into %s values (1, 1)", tableName), prepared)
3543+
})
3544+
})
3545+
}
3546+
}
35643547

3565-
conn, err := dbt.GetDB().Conn(context.Background())
3548+
func runClientDisconnectAutocommitInsert(t *testing.T, dbt *testkit.DBTestKit, tableName, insertSQL string, prepared bool) {
3549+
conn, err := dbt.GetDB().Conn(context.Background())
3550+
require.NoError(t, err)
3551+
defer func() {
3552+
_ = conn.Close()
3553+
}()
3554+
3555+
var stmt *sql.Stmt
3556+
if prepared {
3557+
stmt, err = conn.PrepareContext(context.Background(), insertSQL)
35663558
require.NoError(t, err)
35673559
defer func() {
3568-
_ = conn.Close()
3560+
_ = stmt.Close()
35693561
}()
3570-
netConn := getRawNetConn(t, conn)
3562+
}
3563+
netConn := getRawNetConn(t, conn)
35713564

3572-
done := make(chan error, 1)
3573-
go func() {
3574-
_, err := conn.ExecContext(context.Background(), "insert into issue57531_prewrite values (1, 1)")
3575-
done <- err
3576-
}()
3565+
done := make(chan error, 1)
3566+
go func() {
3567+
var execErr error
3568+
if prepared {
3569+
_, execErr = stmt.ExecContext(context.Background())
3570+
} else {
3571+
_, execErr = conn.ExecContext(context.Background(), insertSQL)
3572+
}
3573+
done <- execErr
3574+
}()
35773575

3578-
require.Eventually(t, func() bool {
3579-
return processlistCountByInfo(t, dbt, "insert into issue57531_prewrite%") == 1
3580-
}, 5*time.Second, 50*time.Millisecond)
3576+
pattern := fmt.Sprintf("insert into %s%%", tableName)
3577+
require.Eventually(t, func() bool {
3578+
return processlistCountByInfo(t, dbt, pattern) == 1
3579+
}, 5*time.Second, 50*time.Millisecond)
3580+
processID, ok := processlistIDByInfo(t, dbt, pattern)
3581+
require.True(t, ok)
3582+
cleanupProcessByID(t, dbt.GetDB(), processID)
35813583

3582-
require.NoError(t, netConn.Close())
3584+
require.NoError(t, netConn.Close())
35833585

3584-
var execErr error
3585-
require.Eventually(t, func() bool {
3586-
select {
3587-
case execErr = <-done:
3588-
return true
3589-
default:
3590-
return false
3591-
}
3592-
}, 5*time.Second, 50*time.Millisecond)
3593-
require.Error(t, execErr)
3586+
var execErr error
3587+
require.Eventually(t, func() bool {
3588+
select {
3589+
case execErr = <-done:
3590+
return true
3591+
default:
3592+
return false
3593+
}
3594+
}, 5*time.Second, 50*time.Millisecond)
3595+
require.Error(t, execErr)
35943596

3595-
require.Eventually(t, func() bool {
3596-
return processlistCountByInfo(t, dbt, "insert into issue57531_prewrite%") == 0
3597-
}, 5*time.Second, 50*time.Millisecond)
3597+
require.Eventually(t, func() bool {
3598+
return processlistCountByInfo(t, dbt, pattern) == 0
3599+
}, 5*time.Second, 50*time.Millisecond)
35983600

3599-
var cnt int
3600-
err = dbt.GetDB().QueryRowContext(context.Background(), "select count(*) from issue57531_prewrite").Scan(&cnt)
3601-
require.NoError(t, err)
3602-
require.Equal(t, 0, cnt)
3603-
})
3601+
var cnt int
3602+
err = dbt.GetDB().QueryRowContext(context.Background(), "select count(*) from "+tableName).Scan(&cnt)
3603+
require.NoError(t, err)
3604+
require.Equal(t, 0, cnt)
36043605
}
36053606

36063607
func enableFastConnectionAliveMonitor(t *testing.T) {
@@ -3642,6 +3643,46 @@ func processlistCountByInfo(t *testing.T, dbt *testkit.DBTestKit, pattern string
36423643
return cnt
36433644
}
36443645

3646+
func processlistIDByInfo(t *testing.T, dbt *testkit.DBTestKit, pattern string) (uint64, bool) {
3647+
var id uint64
3648+
err := dbt.GetDB().QueryRowContext(
3649+
context.Background(),
3650+
"select id from information_schema.processlist where info like ? limit 1",
3651+
pattern,
3652+
).Scan(&id)
3653+
if err == sql.ErrNoRows {
3654+
return 0, false
3655+
}
3656+
require.NoError(t, err)
3657+
return id, true
3658+
}
3659+
3660+
func cleanupProcessByID(t *testing.T, db *sql.DB, processID uint64) {
3661+
t.Cleanup(func() {
3662+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
3663+
defer cancel()
3664+
3665+
conn, err := db.Conn(ctx)
3666+
if err != nil {
3667+
return
3668+
}
3669+
defer func() {
3670+
_ = conn.Close()
3671+
}()
3672+
3673+
var cnt int
3674+
err = conn.QueryRowContext(
3675+
ctx,
3676+
"select count(*) from information_schema.processlist where id = ?",
3677+
processID,
3678+
).Scan(&cnt)
3679+
if err != nil || cnt == 0 {
3680+
return
3681+
}
3682+
_, _ = conn.ExecContext(ctx, fmt.Sprintf("kill query %d", processID))
3683+
})
3684+
}
3685+
36453686
func TestCloseConnForUndeterminedError(t *testing.T) {
36463687
cfg := util2.NewTestConfig()
36473688
cfg.Host = "127.0.0.1" // No network interface listening for mysql traffic

pkg/util/sqlkiller/sqlkiller.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,7 @@ func (killer *SQLKiller) HandleSignal() error {
220220
} else if now.Sub(*lastCheckTime) > checkConnectionAliveDur {
221221
killer.lastCheckTime.Store(&now)
222222
if !(*fn)() {
223-
atomic.CompareAndSwapUint32(&killer.Signal, 0, QueryInterrupted)
223+
killer.sendKillSignal(QueryInterrupted)
224224
}
225225
}
226226
}
@@ -238,7 +238,7 @@ func (killer *SQLKiller) HandleSignal() error {
238238
func (killer *SQLKiller) CheckConnectionAlive() {
239239
fn := killer.IsConnectionAlive.Load()
240240
if fn != nil && !(*fn)() {
241-
atomic.CompareAndSwapUint32(&killer.Signal, 0, QueryInterrupted)
241+
killer.sendKillSignal(QueryInterrupted)
242242
}
243243
}
244244

0 commit comments

Comments
 (0)