Skip to content

Commit 7d3162c

Browse files
authored
ttl: honor scan task cancellation across statement boundaries (#67285)
ref #66982
1 parent 9670037 commit 7d3162c

10 files changed

Lines changed: 150 additions & 22 deletions

File tree

pkg/dxf/framework/storage/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ go_test(
4949
],
5050
embed = [":storage"],
5151
flaky = True,
52-
shard_count = 28,
52+
shard_count = 29,
5353
deps = [
5454
"//pkg/config",
5555
"//pkg/dxf/framework/proto",

pkg/dxf/framework/storage/task_state_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ import (
2121
"slices"
2222
"sync/atomic"
2323
"testing"
24+
"time"
2425

2526
"github.com/pingcap/tidb/pkg/dxf/framework/proto"
2627
"github.com/pingcap/tidb/pkg/dxf/framework/storage"
2728
"github.com/pingcap/tidb/pkg/dxf/framework/testutil"
2829
"github.com/pingcap/tidb/pkg/kv"
2930
"github.com/pingcap/tidb/pkg/sessionctx"
31+
"github.com/pingcap/tidb/pkg/testkit"
3032
"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
3133
tidbutil "github.com/pingcap/tidb/pkg/util"
3234
"github.com/pingcap/tidb/pkg/util/sqlexec"
@@ -135,6 +137,33 @@ func TestTaskState(t *testing.T) {
135137
checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone)
136138
}
137139

140+
func TestWithNewTxnRollbackOnCanceledCtx(t *testing.T) {
141+
_, _ = testkit.CreateMockStoreAndDomain(t)
142+
gm, err := storage.GetTaskManager()
143+
require.NoError(t, err)
144+
145+
ctx, cancel := context.WithCancel(util.WithInternalSourceType(context.Background(), kv.InternalDistTask))
146+
require.NotPanics(t, func() {
147+
err := gm.WithNewTxn(ctx, func(se sessionctx.Context) error {
148+
timer := time.AfterFunc(100*time.Millisecond, cancel)
149+
defer timer.Stop()
150+
151+
_, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "select sleep(10)")
152+
if err != nil {
153+
return err
154+
}
155+
return ctx.Err()
156+
})
157+
require.ErrorIs(t, err, context.Canceled)
158+
})
159+
160+
verifyCtx := util.WithInternalSourceType(context.Background(), kv.InternalDistTask)
161+
require.NoError(t, gm.WithNewTxn(verifyCtx, func(se sessionctx.Context) error {
162+
_, err := sqlexec.ExecSQL(verifyCtx, se.GetSQLExecutor(), "select 1")
163+
return err
164+
}))
165+
}
166+
138167
func TestUpdateTaskExtraParams(t *testing.T) {
139168
_, gm, ctx := testutil.InitTableTest(t)
140169
require.NoError(t, gm.InitMeta(ctx, ":4000", ""))

pkg/dxf/framework/storage/task_table.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,21 +213,25 @@ func (mgr *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) err
213213
func (mgr *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
214214
ctx = clitutil.WithInternalSourceType(ctx, kv.InternalDistTask)
215215
return mgr.WithNewSession(func(se sessionctx.Context) (err error) {
216+
// Keep BEGIN on the SQL path so the session enters transaction mode with the usual statement semantics.
217+
// Commit / rollback use session methods instead, because cleanup still has to finish after caller
218+
// cancellation and issuing SQL text there can leave the pooled internal session with a live txn.
216219
_, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "begin")
217220
if err != nil {
218221
return err
219222
}
220223

221224
success := false
222225
defer func() {
223-
sql := "rollback"
224226
if success {
225-
sql = "commit"
226-
}
227-
_, commitErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql)
228-
if err == nil && commitErr != nil {
229-
err = commitErr
227+
commitErr := se.CommitTxn(ctx)
228+
if err == nil && commitErr != nil {
229+
err = commitErr
230+
}
231+
return
230232
}
233+
234+
se.RollbackTxn(clitutil.WithInternalSourceType(context.Background(), kv.InternalDistTask))
231235
}()
232236

233237
if err = fn(se); err != nil {

pkg/executor/select.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
10031003
vars.MemTracker.SessionID.Store(vars.ConnectionID)
10041004
vars.MemTracker.Killer = &vars.SQLKiller
10051005
vars.DiskTracker.Killer = &vars.SQLKiller
1006+
if vars.InRestrictedSQL && vars.InternalSQLScanUserTable {
1007+
failpoint.InjectCall("beforeResetSQLKillerForTTLScan", s)
1008+
}
10061009
vars.SQLKiller.Reset()
10071010
vars.SQLKiller.ConnID.Store(vars.ConnectionID)
10081011
vars.ResetRelevantOptVarsAndFixes(false)

pkg/executor/test/executor/executor_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2589,7 +2589,7 @@ func TestQueryWithKill(t *testing.T) {
25892589
}
25902590
}
25912591
if err != nil {
2592-
require.Equal(t, context.Canceled, err)
2592+
require.ErrorIs(t, err, context.Canceled)
25932593
}
25942594
if rs != nil {
25952595
rs.Close()

pkg/session/session.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,6 +2441,10 @@ func (s *session) executeStmtImpl(ctx context.Context, stmtNode ast.StmtNode) (s
24412441
if err := executor.ResetContextOfStmt(s, stmtNode); err != nil {
24422442
return nil, err
24432443
}
2444+
// ResetContextOfStmt clears SQLKiller, so honor a canceled caller before executing the next statement.
2445+
if err := ctx.Err(); err != nil {
2446+
return nil, err
2447+
}
24442448
ruv2Metrics := execdetails.RUV2MetricsFromContext(ctx)
24452449
if ruv2Metrics == nil {
24462450
ruv2Metrics = execdetails.NewRUV2Metrics()

pkg/ttl/ttlworker/scan.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,14 @@ func (t *ttlScanTask) doScan(ctx context.Context, delCh chan<- *ttlDeleteTask, s
137137
}
138138

139139
func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDeleteTask, rawSess session.Session) error {
140-
// TODO: merge the ctx and the taskCtx in ttl scan task, to allow both "cancel" and gracefully stop workers
141-
// now, the taskCtx is only check at the beginning of every loop
142140
taskCtx := t.ctx
143141
tracer := metrics.PhaseTracerFromCtx(ctx)
144142
defer tracer.EnterPhase(tracer.Phase())
145143

146144
tracer.EnterPhase(metrics.PhaseOther)
145+
// scanCtx is the context used for SQL execution; it is canceled as soon as either the worker or the TTL task stops.
146+
scanCtx, cancelScanCtx := context.WithCancel(ctx)
147+
defer cancelScanCtx()
147148
doScanFinished, setDoScanFinished := context.WithCancel(context.Background())
148149
wg := util.WaitGroupWrapper{}
149150
wg.Run(func() {
@@ -153,6 +154,7 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
153154
case <-doScanFinished.Done():
154155
return
155156
}
157+
cancelScanCtx()
156158
logger := t.taskLogger(logutil.BgLogger())
157159
logger.Info("kill the running statement in scan task because the task or worker cancelled")
158160
rawSess.KillStmt()
@@ -201,7 +203,7 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
201203
)
202204
}
203205

204-
sess, restoreSession, err := NewScanSession(rawSess, t.tbl, t.ExpireTime)
206+
sess, restoreSession, err := NewScanSession(scanCtx, rawSess, t.tbl, t.ExpireTime)
205207
if err != nil {
206208
return err
207209
}
@@ -242,11 +244,11 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
242244
}
243245

244246
sqlStart := time.Now()
245-
rows, retryable, sqlErr := sess.ExecuteSQLWithCheck(ctx, sql)
247+
rows, retryable, sqlErr := sess.ExecuteSQLWithCheck(scanCtx, sql)
246248
selectInterval := time.Since(sqlStart)
247249
if sqlErr != nil {
248250
metrics.SelectErrorDuration.Observe(selectInterval.Seconds())
249-
needRetry := retryable && retryTimes < scanTaskExecuteSQLMaxRetry && ctx.Err() == nil && t.ctx.Err() == nil
251+
needRetry := retryable && retryTimes < scanTaskExecuteSQLMaxRetry && scanCtx.Err() == nil
250252
logutil.BgLogger().Warn("execute query for ttl scan task failed",
251253
zap.String("SQL", sql),
252254
zap.Int("retryTimes", retryTimes),
@@ -262,8 +264,8 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
262264

263265
tracer.EnterPhase(metrics.PhaseWaitRetry)
264266
select {
265-
case <-ctx.Done():
266-
return ctx.Err()
267+
case <-scanCtx.Done():
268+
return scanCtx.Err()
267269
case <-time.After(scanTaskExecuteSQLRetryInterval):
268270
}
269271
tracer.EnterPhase(metrics.PhaseOther)
@@ -289,8 +291,8 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
289291

290292
tracer.EnterPhase(metrics.PhaseDispatch)
291293
select {
292-
case <-ctx.Done():
293-
return ctx.Err()
294+
case <-scanCtx.Done():
295+
return scanCtx.Err()
294296
case delCh <- delTask:
295297
t.statistics.IncTotalRows(len(lastResult))
296298
}

pkg/ttl/ttlworker/scan_integration_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ import (
2323
"time"
2424

2525
"github.com/pingcap/tidb/pkg/parser/ast"
26+
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
2627
"github.com/pingcap/tidb/pkg/testkit"
28+
"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
2729
"github.com/pingcap/tidb/pkg/testkit/testflag"
2830
"github.com/pingcap/tidb/pkg/ttl/cache"
2931
"github.com/pingcap/tidb/pkg/ttl/ttlworker"
@@ -97,3 +99,87 @@ func TestCancelWhileScan(t *testing.T) {
9799
close(delCh)
98100
wg.Wait()
99101
}
102+
103+
func TestCancelWhileScanAtStatementBoundary(t *testing.T) {
104+
store, dom := testkit.CreateMockStoreAndDomain(t)
105+
tk := testkit.NewTestKit(t, store)
106+
107+
origBatchSize := vardef.TTLScanBatchSize.Load()
108+
vardef.TTLScanBatchSize.Store(30)
109+
t.Cleanup(func() {
110+
vardef.TTLScanBatchSize.Store(origBatchSize)
111+
})
112+
113+
tk.MustExec("create table test.t (id int primary key, created_at datetime) TTL= created_at + interval 1 hour")
114+
tk.MustExec("split table test.t between (0) and (30000) regions 30")
115+
for i := range 30 {
116+
tk.MustExec(fmt.Sprintf("insert into test.t values (%d, NOW() - INTERVAL 24 HOUR)", i*1000))
117+
}
118+
testTable, err := dom.InfoSchema().TableByName(context.Background(), ast.NewCIStr("test"), ast.NewCIStr("t"))
119+
require.NoError(t, err)
120+
testPhysicalTableCache, err := cache.NewPhysicalTable(ast.NewCIStr("test"), testTable.Meta(), ast.NewCIStr(""))
121+
require.NoError(t, err)
122+
123+
testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/store/copr/sleepCoprRequest", "return(2000)")
124+
125+
taskCtx, cancelTask := context.WithCancel(context.Background())
126+
defer cancelTask()
127+
ttlTask := ttlworker.NewTTLScanTask(taskCtx, testPhysicalTableCache, &cache.TTLTask{
128+
JobID: "test",
129+
TableID: testTable.Meta().ID,
130+
ScanID: 1,
131+
ScanRangeStart: nil,
132+
ScanRangeEnd: nil,
133+
ExpireTime: time.Now().Add(-12 * time.Hour),
134+
OwnerID: "test",
135+
OwnerAddr: "test",
136+
OwnerHBTime: time.Now(),
137+
Status: cache.TaskStatusRunning,
138+
StatusUpdateTime: time.Now(),
139+
State: &cache.TTLTaskState{},
140+
CreatedTime: time.Now(),
141+
})
142+
143+
triggerCancel := make(chan struct{})
144+
var cancelOnce sync.Once
145+
testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/executor/beforeResetSQLKillerForTTLScan", func(stmt ast.StmtNode) {
146+
if _, ok := stmt.(*ast.SelectStmt); !ok {
147+
return
148+
}
149+
150+
cancelOnce.Do(func() {
151+
cancelTask()
152+
close(triggerCancel)
153+
time.Sleep(100 * time.Millisecond)
154+
})
155+
})
156+
157+
delCh := make(chan *ttlworker.TTLDeleteTask)
158+
doneCh := make(chan struct{})
159+
go func() {
160+
defer close(doneCh)
161+
for range delCh {
162+
}
163+
}()
164+
165+
doScanDone := make(chan struct{})
166+
go func() {
167+
defer close(doScanDone)
168+
ttlTask.DoScan(context.Background(), delCh, dom.AdvancedSysSessionPool())
169+
}()
170+
171+
select {
172+
case <-triggerCancel:
173+
case <-time.After(10 * time.Second):
174+
require.FailNow(t, "TTL scan SELECT was not reached")
175+
}
176+
177+
select {
178+
case <-doScanDone:
179+
case <-time.After(time.Second):
180+
require.FailNow(t, "TTL scan was not canceled within 1s after statement-boundary cancel")
181+
}
182+
183+
close(delCh)
184+
<-doneCh
185+
}

pkg/ttl/ttlworker/session.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ func newTableSession(se session.Session, tbl *cache.PhysicalTable, expire time.T
196196
}
197197

198198
// NewScanSession creates a session for scan
199-
func NewScanSession(se session.Session, tbl *cache.PhysicalTable, expire time.Time) (*ttlTableSession, func() error, error) {
199+
func NewScanSession(ctx context.Context, se session.Session, tbl *cache.PhysicalTable, expire time.Time) (*ttlTableSession, func() error, error) {
200200
origConcurrency := se.GetSessionVars().DistSQLScanConcurrency()
201201
origPaging := se.GetSessionVars().EnablePaging
202202
se.GetSessionVars().InternalSQLScanUserTable = true
@@ -218,7 +218,7 @@ func NewScanSession(se session.Session, tbl *cache.PhysicalTable, expire time.Ti
218218
}
219219

220220
// Set the distsql scan concurrency to 1 to reduce the number of cop tasks in TTL scan.
221-
if _, err := se.ExecuteSQL(context.Background(), "set @@tidb_distsql_scan_concurrency=1"); err != nil {
221+
if _, err := se.ExecuteSQL(ctx, "set @@tidb_distsql_scan_concurrency=1"); err != nil {
222222
terror.Log(restore())
223223
return nil, nil, err
224224
}
@@ -227,7 +227,7 @@ func NewScanSession(se session.Session, tbl *cache.PhysicalTable, expire time.Ti
227227
// If `tidb_enable_paging` is enabled, it may have multiple cop tasks even in one region that makes some extra
228228
// processed keys in TiKV side, see issue: https://github.com/pingcap/tidb/issues/58342.
229229
// Disable it to make the scan more efficient.
230-
if _, err := se.ExecuteSQL(context.Background(), "set @@tidb_enable_paging=OFF"); err != nil {
230+
if _, err := se.ExecuteSQL(ctx, "set @@tidb_enable_paging=OFF"); err != nil {
231231
terror.Log(restore())
232232
return nil, nil, err
233233
}

pkg/ttl/ttlworker/session_integration_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ func TestNewScanSession(t *testing.T) {
380380
called := false
381381
require.NoError(t, ttlworker.WithSessionForTest(pool, func(se session.Session) error {
382382
require.False(t, called)
383-
tblSe, restore, err := ttlworker.NewScanSession(se, &cache.PhysicalTable{}, time.Now())
383+
tblSe, restore, err := ttlworker.NewScanSession(context.Background(), se, &cache.PhysicalTable{}, time.Now())
384384
called = true
385385
if errSQL == "" {
386386
// success case
@@ -424,7 +424,7 @@ func TestNewScanSession(t *testing.T) {
424424
}, newFaultAfterCount(0)))
425425
require.NoError(t, ttlworker.WithSessionForTest(pool, func(se session.Session) error {
426426
require.False(t, called)
427-
tblSe, restore, err := ttlworker.NewScanSession(se, &cache.PhysicalTable{}, time.Now())
427+
tblSe, restore, err := ttlworker.NewScanSession(context.Background(), se, &cache.PhysicalTable{}, time.Now())
428428
called = true
429429
require.NoError(t, err)
430430
require.NotNil(t, tblSe)

0 commit comments

Comments
 (0)