Skip to content

Commit 8c1219a

Browse files
[INS-207] Add Role-Aware Resumption Support for Legacy S3 Scan (#4600)
* Added roles to legacy scan resumption
* Test to verify legacy chunks resumption with roles.
1 parent: 606a7ed; commit: 8c1219a

4 files changed

Lines changed: 248 additions & 3 deletions

File tree

pkg/sources/s3/checkpointer.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ func (p *Checkpointer) Reset() {
9898
type ResumeInfo struct {
9999
CurrentBucket string `json:"current_bucket"` // Current bucket being scanned
100100
StartAfter string `json:"start_after"` // Last processed object key
101+
Role string `json:"role"` // Role used for scanning
101102
}
102103

103104
// ResumePoint retrieves the last saved checkpoint state if one exists.
@@ -121,7 +122,7 @@ func (p *Checkpointer) ResumePoint(ctx context.Context) (ResumeInfo, error) {
121122
return resume, nil
122123
}
123124

124-
return ResumeInfo{CurrentBucket: resumeInfo.CurrentBucket, StartAfter: resumeInfo.StartAfter}, nil
125+
return ResumeInfo{CurrentBucket: resumeInfo.CurrentBucket, StartAfter: resumeInfo.StartAfter, Role: resumeInfo.Role}, nil
125126
}
126127

127128
// Complete marks the entire scanning operation as finished and clears the resume state.
@@ -215,7 +216,7 @@ func (p *Checkpointer) updateCheckpoint(bucket string, role string, lastKey stri
215216
return nil
216217
}
217218

218-
encoded, err := json.Marshal(&ResumeInfo{CurrentBucket: bucket, StartAfter: lastKey})
219+
encoded, err := json.Marshal(&ResumeInfo{CurrentBucket: bucket, StartAfter: lastKey, Role: role})
219220
if err != nil {
220221
return fmt.Errorf("failed to encode resume info: %w", err)
221222
}

pkg/sources/s3/checkpointer_test.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,56 @@ func TestCheckpointerResumption(t *testing.T) {
6161
assert.Equal(t, "key-11", finalResumeInfo.StartAfter)
6262
}
6363

64+
func TestCheckpointerResumptionWithRole(t *testing.T) {
65+
ctx := context.Background()
66+
67+
// First scan - process 6 objects then interrupt.
68+
initialProgress := &sources.Progress{}
69+
tracker := NewCheckpointer(ctx, initialProgress)
70+
role := "test-role"
71+
72+
firstPage := &s3.ListObjectsV2Output{
73+
Contents: make([]s3types.Object, 12), // Total of 12 objects
74+
}
75+
for i := range 12 {
76+
key := fmt.Sprintf("key-%d", i)
77+
firstPage.Contents[i] = s3types.Object{Key: &key}
78+
}
79+
80+
// Process first 6 objects.
81+
for i := range 6 {
82+
err := tracker.UpdateObjectCompletion(ctx, i, "test-bucket", role, firstPage.Contents)
83+
assert.NoError(t, err)
84+
}
85+
86+
// Verify resume info is set correctly.
87+
resumeInfo, err := tracker.ResumePoint(ctx)
88+
require.NoError(t, err)
89+
assert.Equal(t, "test-bucket", resumeInfo.CurrentBucket)
90+
assert.Equal(t, "key-5", resumeInfo.StartAfter)
91+
assert.Equal(t, role, resumeInfo.Role)
92+
93+
// Resume scan with existing progress.
94+
resumeTracker := NewCheckpointer(ctx, initialProgress)
95+
96+
resumePage := &s3.ListObjectsV2Output{
97+
Contents: firstPage.Contents[6:], // Remaining 6 objects
98+
}
99+
100+
// Process remaining objects.
101+
for i := range len(resumePage.Contents) {
102+
err := resumeTracker.UpdateObjectCompletion(ctx, i, "test-bucket", role, resumePage.Contents)
103+
assert.NoError(t, err)
104+
}
105+
106+
// Verify final resume info.
107+
finalResumeInfo, err := resumeTracker.ResumePoint(ctx)
108+
require.NoError(t, err)
109+
assert.Equal(t, "test-bucket", finalResumeInfo.CurrentBucket)
110+
assert.Equal(t, "key-11", finalResumeInfo.StartAfter)
111+
assert.Equal(t, role, finalResumeInfo.Role)
112+
}
113+
64114
func TestCheckpointerReset(t *testing.T) {
65115
tests := []struct {
66116
name string
@@ -111,6 +161,13 @@ func TestGetResumePoint(t *testing.T) {
111161
},
112162
expectedResumeInfo: ResumeInfo{CurrentBucket: "test-bucket", StartAfter: "test-key"},
113163
},
164+
{
165+
name: "valid resume info with role",
166+
progress: &sources.Progress{
167+
EncodedResumeInfo: `{"current_bucket":"test-bucket","start_after":"test-key","role":"test-role"}`,
168+
},
169+
expectedResumeInfo: ResumeInfo{CurrentBucket: "test-bucket", StartAfter: "test-key", Role: "test-role"},
170+
},
114171
{
115172
name: "empty encoded resume info",
116173
progress: &sources.Progress{EncodedResumeInfo: ""},
@@ -121,6 +178,13 @@ func TestGetResumePoint(t *testing.T) {
121178
EncodedResumeInfo: `{"current_bucket":"","start_after":"test-key"}`,
122179
},
123180
},
181+
{
182+
name: "no role in resume info",
183+
progress: &sources.Progress{
184+
EncodedResumeInfo: `{"current_bucket":"test-bucket","start_after":"test-key"}`,
185+
},
186+
expectedResumeInfo: ResumeInfo{CurrentBucket: "test-bucket", StartAfter: "test-key", Role: ""},
187+
},
124188
{
125189
name: "unmarshal error",
126190
progress: &sources.Progress{
@@ -257,6 +321,122 @@ func TestCheckpointerUpdate(t *testing.T) {
257321
})
258322
}
259323
}
324+
func TestCheckpointerUpdateWithRole(t *testing.T) {
325+
role := "test-role"
326+
tests := []struct {
327+
name string
328+
description string
329+
completedIdx int
330+
pageSize int
331+
preCompleted []int
332+
expectedKey string
333+
expectedRole string
334+
expectedLowestIncomplete int
335+
}{
336+
{
337+
name: "first object completed",
338+
description: "Basic case - completing first object",
339+
completedIdx: 0,
340+
pageSize: 3,
341+
expectedKey: "key-0",
342+
expectedRole: role,
343+
expectedLowestIncomplete: 1,
344+
},
345+
{
346+
name: "completing missing middle",
347+
description: "Completing object when previous is done",
348+
completedIdx: 1,
349+
pageSize: 3,
350+
preCompleted: []int{0},
351+
expectedKey: "key-1",
352+
expectedRole: role,
353+
expectedLowestIncomplete: 2,
354+
},
355+
{
356+
name: "all objects completed in order",
357+
description: "Completing final object in sequence",
358+
completedIdx: 2,
359+
pageSize: 3,
360+
preCompleted: []int{0, 1},
361+
expectedKey: "key-2",
362+
expectedRole: role,
363+
expectedLowestIncomplete: 3,
364+
},
365+
{
366+
name: "out of order completion before lowest",
367+
description: "Completing object before current lowest incomplete - should not affect checkpoint",
368+
completedIdx: 1,
369+
pageSize: 4,
370+
preCompleted: []int{0, 2, 3},
371+
expectedKey: "key-3",
372+
expectedRole: role,
373+
expectedLowestIncomplete: 4,
374+
},
375+
{
376+
name: "last index in max page",
377+
description: "Edge case - maximum page size boundary",
378+
completedIdx: 999,
379+
pageSize: 1000,
380+
preCompleted: func() []int {
381+
indices := make([]int, 999)
382+
for i := range indices {
383+
indices[i] = i
384+
}
385+
return indices
386+
}(),
387+
expectedKey: "key-999",
388+
expectedRole: role,
389+
expectedLowestIncomplete: 1000,
390+
},
391+
}
392+
393+
for _, tt := range tests {
394+
t.Run(tt.name, func(t *testing.T) {
395+
t.Parallel()
396+
397+
ctx := context.Background()
398+
progress := new(sources.Progress)
399+
tracker := &Checkpointer{
400+
progress: progress,
401+
completedObjects: make([]bool, tt.pageSize),
402+
completionOrder: make([]int, 0, tt.pageSize),
403+
lowestIncompleteIdx: 0,
404+
}
405+
406+
page := &s3.ListObjectsV2Output{Contents: make([]s3types.Object, tt.pageSize)}
407+
for i := range tt.pageSize {
408+
key := fmt.Sprintf("key-%d", i)
409+
page.Contents[i] = s3types.Object{Key: &key}
410+
}
411+
412+
// Setup pre-completed objects.
413+
for _, idx := range tt.preCompleted {
414+
tracker.completedObjects[idx] = true
415+
tracker.completionOrder = append(tracker.completionOrder, idx)
416+
}
417+
418+
// Find the correct lowest incomplete index after pre-completion.
419+
for i := range tt.pageSize {
420+
if !tracker.completedObjects[i] {
421+
tracker.lowestIncompleteIdx = i
422+
break
423+
}
424+
}
425+
426+
err := tracker.UpdateObjectCompletion(ctx, tt.completedIdx, "test-bucket", role, page.Contents)
427+
assert.NoError(t, err, "Unexpected error updating progress")
428+
429+
var info ResumeInfo
430+
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &info)
431+
assert.NoError(t, err, "Failed to decode resume info")
432+
assert.Equal(t, tt.expectedKey, info.StartAfter, "Incorrect resume point")
433+
assert.Equal(t, tt.expectedRole, info.Role, "Incorrect role")
434+
435+
assert.Equal(t, tt.expectedLowestIncomplete, tracker.lowestIncompleteIdx,
436+
"Incorrect lowest incomplete index")
437+
})
438+
}
439+
}
260440

261441
func TestCheckpointerUpdateUnitScan(t *testing.T) {
262442
ctx := context.Background()

pkg/sources/s3/s3.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ type resumePosition struct {
245245
startAfter string // The last processed object key within the bucket
246246
isNewScan bool // True if we're starting a fresh scan
247247
exactMatch bool // True if we found the exact bucket we were previously processing
248+
role string // The role used during the previous scan
248249
}
249250

250251
// determineResumePosition calculates where to resume scanning from based on the last saved checkpoint
@@ -282,6 +283,7 @@ func determineResumePosition(ctx context.Context, tracker *Checkpointer, buckets
282283
startAfter: resumePoint.StartAfter,
283284
index: startIdx,
284285
exactMatch: found,
286+
role: resumePoint.Role,
285287
}
286288
}
287289

@@ -306,12 +308,14 @@ func (s *Source) scanBuckets(
306308
"Resume bucket no longer available, starting from closest position",
307309
"original_bucket", pos.bucket,
308310
"position", pos.index,
311+
"role", pos.role,
309312
)
310313
default:
311314
ctx.Logger().Info(
312315
"Resuming scan from previous scan's bucket",
313316
"bucket", pos.bucket,
314317
"position", pos.index,
318+
"role", pos.role,
315319
)
316320
}
317321

@@ -327,7 +331,7 @@ func (s *Source) scanBuckets(
327331
)
328332

329333
var startAfter *string
330-
if bucket == pos.bucket && pos.startAfter != "" {
334+
if bucket == pos.bucket && pos.startAfter != "" && role == pos.role {
331335
startAfter = &pos.startAfter
332336
ctx.Logger().V(3).Info(
333337
"Resuming bucket scan",

pkg/sources/s3/s3_integration_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,3 +584,63 @@ func TestSource_ChunkUnit_Resumption_MultipleBucketsConcurrent(t *testing.T) {
584584
assert.Equal(t, wantCount, gotCount, "Chunk count mismatch for bucket %s", bucket)
585585
}
586586
}
587+
588+
func TestSourceChunksResumptionWithRole(t *testing.T) {
589+
t.Parallel()
590+
591+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
592+
defer cancel()
593+
secret, err := common.GetTestSecret(ctx)
594+
if err != nil {
595+
t.Fatal(fmt.Errorf("failed to access secret: %v", err))
596+
}
597+
598+
s3key := secret.MustGetField("AWS_S3_KEY")
599+
s3secret := secret.MustGetField("AWS_S3_SECRET")
600+
601+
src := new(Source)
602+
src.Progress = sources.Progress{
603+
Message: "Bucket: integration-resumption-tests",
604+
EncodedResumeInfo: "{\"current_bucket\":\"integration-resumption-tests\",\"start_after\":\"test-dir/\",\"role\":\"arn:aws:iam::619888638459:role/s3-test-assume-role\"}",
605+
SectionsCompleted: 0,
606+
SectionsRemaining: 1,
607+
}
608+
connection := &sourcespb.S3{
609+
Credential: &sourcespb.S3_AccessKey{
610+
AccessKey: &credentialspb.KeySecret{
611+
Key: s3key,
612+
Secret: s3secret,
613+
},
614+
},
615+
Buckets: []string{"integration-resumption-tests"},
616+
Roles: []string{"arn:aws:iam::619888638459:role/s3-test-assume-role"},
617+
EnableResumption: true,
618+
}
619+
conn, err := anypb.New(connection)
620+
require.NoError(t, err)
621+
622+
err = src.Init(ctx, "test name", 0, 0, false, conn, 2)
623+
require.NoError(t, err)
624+
625+
chunksCh := make(chan *sources.Chunk)
626+
var count int
627+
628+
cancelCtx, ctxCancel := context.WithCancel(ctx)
629+
defer ctxCancel()
630+
631+
go func() {
632+
defer close(chunksCh)
633+
err = src.Chunks(cancelCtx, chunksCh)
634+
assert.NoError(t, err, "Should not error during scan")
635+
}()
636+
637+
for range chunksCh {
638+
count++
639+
}
640+
641+
// Verify that we processed all remaining data on resume.
642+
// Also verify that we processed less than the total number of chunks for the source.
643+
sourceTotalChunkCount := 19787
644+
assert.Equal(t, 9638, count, "Should have processed all remaining data on resume")
645+
assert.Less(t, count, sourceTotalChunkCount, "Should have processed less than total chunks on resume")
646+
}

0 commit comments

Comments
 (0)