@@ -3508,99 +3508,100 @@ func TestClientDisconnectKillsAutocommitInsert(t *testing.T) {
35083508 ts := servertestkit .CreateTidbTestSuite (t )
35093509 enableFastConnectionAliveMonitor (t )
35103510
3511- ts .RunTests (t , nil , func (dbt * testkit.DBTestKit ) {
3512- dbt .MustExec ("drop table if exists issue57531_insert" )
3513- dbt .MustExec ("create table issue57531_insert (a int primary key, b int)" )
3514-
3515- conn , err := dbt .GetDB ().Conn (context .Background ())
3516- require .NoError (t , err )
3517- defer func () {
3518- _ = conn .Close ()
3519- }()
3520- netConn := getRawNetConn (t , conn )
3521-
3522- done := make (chan error , 1 )
3523- go func () {
3524- _ , err := conn .ExecContext (context .Background (), "insert into issue57531_insert values (1, sleep(300))" )
3525- done <- err
3526- }()
3527-
3528- require .Eventually (t , func () bool {
3529- return processlistCountByInfo (t , dbt , "insert into issue57531_insert%" ) == 1
3530- }, 5 * time .Second , 50 * time .Millisecond )
3531-
3532- require .NoError (t , netConn .Close ())
3533-
3534- var execErr error
3535- require .Eventually (t , func () bool {
3536- select {
3537- case execErr = <- done :
3538- return true
3539- default :
3540- return false
3541- }
3542- }, 5 * time .Second , 50 * time .Millisecond )
3543- require .Error (t , execErr )
3544-
3545- require .Eventually (t , func () bool {
3546- return processlistCountByInfo (t , dbt , "insert into issue57531_insert%" ) == 0
3547- }, 5 * time .Second , 50 * time .Millisecond )
3548-
3549- var cnt int
3550- err = dbt .GetDB ().QueryRowContext (context .Background (), "select count(*) from issue57531_insert" ).Scan (& cnt )
3551- require .NoError (t , err )
3552- require .Equal (t , 0 , cnt )
3553- })
3511+ for _ , prepared := range []bool {false , true } {
3512+ name := "query"
3513+ if prepared {
3514+ name = "prepared"
3515+ }
3516+ t .Run (name , func (t * testing.T ) {
3517+ ts .RunTests (t , nil , func (dbt * testkit.DBTestKit ) {
3518+ tableName := "issue57531_insert_" + name
3519+ dbt .MustExec ("drop table if exists " + tableName )
3520+ dbt .MustExec ("create table " + tableName + " (a int primary key, b int)" )
3521+ runClientDisconnectAutocommitInsert (t , dbt , tableName , fmt .Sprintf ("insert into %s values (1, sleep(300))" , tableName ), prepared )
3522+ })
3523+ })
3524+ }
35543525}
35553526
35563527func TestClientDisconnectCancelsAutocommitInsertPrewrite (t * testing.T ) {
35573528 ts := servertestkit .CreateTidbTestSuite (t )
35583529 enableFastConnectionAliveMonitor (t )
35593530
3560- ts .RunTests (t , nil , func (dbt * testkit.DBTestKit ) {
3561- dbt .MustExec ("drop table if exists issue57531_prewrite" )
3562- dbt .MustExec ("create table issue57531_prewrite (a int primary key, b int)" )
3563- testfailpoint .Enable (t , "github.com/pingcap/tidb/pkg/store/mockstore/unistore/rpcPrewriteResult" , `return("notLeader")` )
3531+ for _ , prepared := range []bool {false , true } {
3532+ name := "query"
3533+ if prepared {
3534+ name = "prepared"
3535+ }
3536+ t .Run (name , func (t * testing.T ) {
3537+ ts .RunTests (t , nil , func (dbt * testkit.DBTestKit ) {
3538+ tableName := "issue57531_prewrite_" + name
3539+ dbt .MustExec ("drop table if exists " + tableName )
3540+ dbt .MustExec ("create table " + tableName + " (a int primary key, b int)" )
3541+ testfailpoint .Enable (t , "github.com/pingcap/tidb/pkg/store/mockstore/unistore/rpcPrewriteResult" , `return("notLeader")` )
3542+ runClientDisconnectAutocommitInsert (t , dbt , tableName , fmt .Sprintf ("insert into %s values (1, 1)" , tableName ), prepared )
3543+ })
3544+ })
3545+ }
3546+ }
35643547
3565- conn , err := dbt .GetDB ().Conn (context .Background ())
3548+ func runClientDisconnectAutocommitInsert (t * testing.T , dbt * testkit.DBTestKit , tableName , insertSQL string , prepared bool ) {
3549+ conn , err := dbt .GetDB ().Conn (context .Background ())
3550+ require .NoError (t , err )
3551+ defer func () {
3552+ _ = conn .Close ()
3553+ }()
3554+
3555+ var stmt * sql.Stmt
3556+ if prepared {
3557+ stmt , err = conn .PrepareContext (context .Background (), insertSQL )
35663558 require .NoError (t , err )
35673559 defer func () {
3568- _ = conn .Close ()
3560+ _ = stmt .Close ()
35693561 }()
3570- netConn := getRawNetConn (t , conn )
3562+ }
3563+ netConn := getRawNetConn (t , conn )
35713564
3572- done := make (chan error , 1 )
3573- go func () {
3574- _ , err := conn .ExecContext (context .Background (), "insert into issue57531_prewrite values (1, 1)" )
3575- done <- err
3576- }()
3565+ done := make (chan error , 1 )
3566+ go func () {
3567+ var execErr error
3568+ if prepared {
3569+ _ , execErr = stmt .ExecContext (context .Background ())
3570+ } else {
3571+ _ , execErr = conn .ExecContext (context .Background (), insertSQL )
3572+ }
3573+ done <- execErr
3574+ }()
35773575
3578- require .Eventually (t , func () bool {
3579- return processlistCountByInfo (t , dbt , "insert into issue57531_prewrite%" ) == 1
3580- }, 5 * time .Second , 50 * time .Millisecond )
3576+ pattern := fmt .Sprintf ("insert into %s%%" , tableName )
3577+ require .Eventually (t , func () bool {
3578+ return processlistCountByInfo (t , dbt , pattern ) == 1
3579+ }, 5 * time .Second , 50 * time .Millisecond )
3580+ processID , ok := processlistIDByInfo (t , dbt , pattern )
3581+ require .True (t , ok )
3582+ cleanupProcessByID (t , dbt .GetDB (), processID )
35813583
3582- require .NoError (t , netConn .Close ())
3584+ require .NoError (t , netConn .Close ())
35833585
3584- var execErr error
3585- require .Eventually (t , func () bool {
3586- select {
3587- case execErr = <- done :
3588- return true
3589- default :
3590- return false
3591- }
3592- }, 5 * time .Second , 50 * time .Millisecond )
3593- require .Error (t , execErr )
3586+ var execErr error
3587+ require .Eventually (t , func () bool {
3588+ select {
3589+ case execErr = <- done :
3590+ return true
3591+ default :
3592+ return false
3593+ }
3594+ }, 5 * time .Second , 50 * time .Millisecond )
3595+ require .Error (t , execErr )
35943596
3595- require .Eventually (t , func () bool {
3596- return processlistCountByInfo (t , dbt , "insert into issue57531_prewrite%" ) == 0
3597- }, 5 * time .Second , 50 * time .Millisecond )
3597+ require .Eventually (t , func () bool {
3598+ return processlistCountByInfo (t , dbt , pattern ) == 0
3599+ }, 5 * time .Second , 50 * time .Millisecond )
35983600
3599- var cnt int
3600- err = dbt .GetDB ().QueryRowContext (context .Background (), "select count(*) from issue57531_prewrite" ).Scan (& cnt )
3601- require .NoError (t , err )
3602- require .Equal (t , 0 , cnt )
3603- })
3601+ var cnt int
3602+ err = dbt .GetDB ().QueryRowContext (context .Background (), "select count(*) from " + tableName ).Scan (& cnt )
3603+ require .NoError (t , err )
3604+ require .Equal (t , 0 , cnt )
36043605}
36053606
36063607func enableFastConnectionAliveMonitor (t * testing.T ) {
@@ -3642,6 +3643,46 @@ func processlistCountByInfo(t *testing.T, dbt *testkit.DBTestKit, pattern string
36423643 return cnt
36433644}
36443645
3646+ func processlistIDByInfo (t * testing.T , dbt * testkit.DBTestKit , pattern string ) (uint64 , bool ) {
3647+ var id uint64
3648+ err := dbt .GetDB ().QueryRowContext (
3649+ context .Background (),
3650+ "select id from information_schema.processlist where info like ? limit 1" ,
3651+ pattern ,
3652+ ).Scan (& id )
3653+ if err == sql .ErrNoRows {
3654+ return 0 , false
3655+ }
3656+ require .NoError (t , err )
3657+ return id , true
3658+ }
3659+
3660+ func cleanupProcessByID (t * testing.T , db * sql.DB , processID uint64 ) {
3661+ t .Cleanup (func () {
3662+ ctx , cancel := context .WithTimeout (context .Background (), time .Second )
3663+ defer cancel ()
3664+
3665+ conn , err := db .Conn (ctx )
3666+ if err != nil {
3667+ return
3668+ }
3669+ defer func () {
3670+ _ = conn .Close ()
3671+ }()
3672+
3673+ var cnt int
3674+ err = conn .QueryRowContext (
3675+ ctx ,
3676+ "select count(*) from information_schema.processlist where id = ?" ,
3677+ processID ,
3678+ ).Scan (& cnt )
3679+ if err != nil || cnt == 0 {
3680+ return
3681+ }
3682+ _ , _ = conn .ExecContext (ctx , fmt .Sprintf ("kill query %d" , processID ))
3683+ })
3684+ }
3685+
36453686func TestCloseConnForUndeterminedError (t * testing.T ) {
36463687 cfg := util2 .NewTestConfig ()
36473688 cfg .Host = "127.0.0.1" // No network interface listening for mysql traffic
0 commit comments