Skip to content

Commit 6cc2e5b

Browse files
committed
update multipart upload progress using reader
1 parent e511524 commit 6cc2e5b

2 files changed

Lines changed: 145 additions & 8 deletions

File tree

internal/turso/tursoServer.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type progressReader struct {
2121
reader io.Reader
2222
totalSize int64
2323
bytesRead int64
24+
baseBytes int64 // Bytes already uploaded before progressReader started
2425
startTime time.Time
2526
onProgress func(progressPct int, uploadedBytes int64, totalBytes int64, elapsedTime time.Duration, done bool)
2627
lastUpdate int // Last reported progress percentage. Initially -1 to ensure first update is always sent.
@@ -30,13 +31,14 @@ func (pr *progressReader) Read(p []byte) (int, error) {
3031
n, err := pr.reader.Read(p)
3132
if n > 0 {
3233
pr.bytesRead += int64(n)
33-
progressPct := int(float64(pr.bytesRead) / float64(pr.totalSize) * 100)
34+
totalUploaded := pr.baseBytes + pr.bytesRead
35+
progressPct := int(float64(totalUploaded) / float64(pr.totalSize) * 100)
3436

3537
// Only call progress if we've made at least 1% progress or if we're done
3638
if progressPct > pr.lastUpdate || errors.Is(err, io.EOF) {
3739
elapsedTime := time.Since(pr.startTime)
3840
pr.lastUpdate = progressPct
39-
pr.onProgress(progressPct, pr.bytesRead, pr.totalSize, elapsedTime, errors.Is(err, io.EOF))
41+
pr.onProgress(progressPct, totalUploaded, pr.totalSize, elapsedTime, errors.Is(err, io.EOF))
4042
}
4143
}
4244
return n, err
@@ -185,6 +187,7 @@ func (i *TursoServerClient) startMultipartUpload(dbSize int64) (int64, error) {
185187
func (i *TursoServerClient) uploadChunks(chunkSize int64, file io.Reader, totalSize int64, startTime time.Time, remoteEncryptionCipher, remoteEncryptionKey string, onUploadProgress func(progressPct int, uploadedBytes int64, totalBytes int64, elapsedTime time.Duration, done bool)) (int64, error) {
186188
var uploadedBytes int64 = 0
187189
chunkID := 0
190+
lastProgressPct := -1
188191

189192
for uploadedBytes < totalSize {
190193
remaining := totalSize - uploadedBytes
@@ -194,6 +197,16 @@ func (i *TursoServerClient) uploadChunks(chunkSize int64, file io.Reader, totalS
194197
}
195198

196199
chunkReader := io.LimitReader(file, currentChunkSize)
200+
201+
progressTracker := &progressReader{
202+
reader: chunkReader,
203+
totalSize: totalSize,
204+
baseBytes: uploadedBytes,
205+
startTime: startTime,
206+
onProgress: onUploadProgress,
207+
lastUpdate: lastProgressPct,
208+
}
209+
197210
chunkPath := fmt.Sprintf("/v2/upload/chunk/%d", chunkID)
198211

199212
var headers = map[string]string{}
@@ -203,7 +216,7 @@ func (i *TursoServerClient) uploadChunks(chunkSize int64, file io.Reader, totalS
203216
}
204217
headers["Content-Length"] = strconv.FormatInt(currentChunkSize, 10)
205218

206-
r, err := i.client.PutBinary(chunkPath, chunkReader, headers)
219+
r, err := i.client.PutBinary(chunkPath, progressTracker, headers)
207220
if err != nil {
208221
return 0, fmt.Errorf("failed to upload chunk %d: %w", chunkID, err)
209222
}
@@ -221,10 +234,7 @@ func (i *TursoServerClient) uploadChunks(chunkSize int64, file io.Reader, totalS
221234
}
222235

223236
uploadedBytes += currentChunkSize
224-
progressPct := int(float64(uploadedBytes) / float64(totalSize) * 100)
225-
elapsedTime := time.Since(startTime)
226-
//TODO update progress more smoothly than after every chunk
227-
onUploadProgress(progressPct, uploadedBytes, totalSize, elapsedTime, false)
237+
lastProgressPct = progressTracker.lastUpdate
228238

229239
chunkID++
230240
}

internal/turso/tursoServer_test.go

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ func createTestFileWithContent(t *testing.T, content []byte) string {
175175
func createTestFile(t *testing.T, size int64) string {
176176
t.Helper()
177177

178-
// Generate deterministic data pattern
179178
pattern := []byte("TESTDATA")
180179
content := make([]byte, size)
181180
for i := int64(0); i < size; i++ {
@@ -926,3 +925,131 @@ func TestProgressReader_BytesReadAccurate(t *testing.T) {
926925

927926
require.Equal(t, int64(500), lastUploadedBytes)
928927
}
928+
929+
func TestProgressReader_WithBaseBytes(t *testing.T) {
930+
// Simulate reading from a second chunk of a 200-byte total upload
931+
// where the first chunk (100 bytes) has already been uploaded.
932+
data := strings.Repeat("x", 100) // Second chunk: 100 bytes
933+
reader := strings.NewReader(data)
934+
935+
var progressUpdates []struct {
936+
pct int
937+
uploaded int64
938+
total int64
939+
}
940+
941+
pr := &progressReader{
942+
reader: reader,
943+
totalSize: 200,
944+
baseBytes: 100,
945+
startTime: time.Now(),
946+
onProgress: func(progressPct int, uploadedBytes int64, totalBytes int64, elapsedTime time.Duration, done bool) {
947+
progressUpdates = append(progressUpdates, struct {
948+
pct int
949+
uploaded int64
950+
total int64
951+
}{progressPct, uploadedBytes, totalBytes})
952+
},
953+
lastUpdate: 50,
954+
}
955+
956+
buf := make([]byte, 10)
957+
for {
958+
_, err := pr.Read(buf)
959+
if err == io.EOF {
960+
break
961+
}
962+
if err != nil {
963+
t.Fatalf("unexpected error: %v", err)
964+
}
965+
}
966+
967+
require.NotEmpty(t, progressUpdates, "expected progress updates")
968+
969+
firstUpdate := progressUpdates[0]
970+
require.Greater(t, firstUpdate.pct, 50, "first progress update should be > 50%%")
971+
972+
for i, update := range progressUpdates {
973+
require.Greater(t, update.uploaded, int64(100), "update %d: uploadedBytes should be > 100 (baseBytes)", i)
974+
require.Equal(t, int64(200), update.total, "update %d: totalBytes should be 200", i)
975+
}
976+
977+
lastUpdate := progressUpdates[len(progressUpdates)-1]
978+
require.Equal(t, 100, lastUpdate.pct, "final progress should be 100%%")
979+
require.Equal(t, int64(200), lastUpdate.uploaded, "final uploadedBytes should be 200")
980+
}
981+
982+
func TestProgressReader_CumulativeAcrossChunks(t *testing.T) {
983+
totalSize := int64(300)
984+
chunkSize := int64(100)
985+
986+
var allUpdates []struct {
987+
pct int
988+
uploaded int64
989+
}
990+
991+
// Simulate reading 3 chunks
992+
var baseBytes int64 = 0
993+
lastPct := -1
994+
995+
for chunk := 0; chunk < 3; chunk++ {
996+
data := bytes.Repeat([]byte("x"), int(chunkSize))
997+
reader := bytes.NewReader(data)
998+
999+
pr := &progressReader{
1000+
reader: reader,
1001+
totalSize: totalSize,
1002+
baseBytes: baseBytes,
1003+
startTime: time.Now(),
1004+
onProgress: func(progressPct int, uploadedBytes int64, totalBytes int64, elapsedTime time.Duration, done bool) {
1005+
allUpdates = append(allUpdates, struct {
1006+
pct int
1007+
uploaded int64
1008+
}{progressPct, uploadedBytes})
1009+
},
1010+
lastUpdate: lastPct,
1011+
}
1012+
1013+
// Read all of this chunk
1014+
buf := make([]byte, 10)
1015+
for {
1016+
_, err := pr.Read(buf)
1017+
if err == io.EOF {
1018+
break
1019+
}
1020+
require.NoError(t, err)
1021+
}
1022+
1023+
baseBytes += chunkSize
1024+
lastPct = pr.lastUpdate
1025+
}
1026+
1027+
require.NotEmpty(t, allUpdates, "expected progress updates")
1028+
require.LessOrEqual(t, allUpdates[0].pct, 10, "first update should be low")
1029+
require.Equal(t, 100, allUpdates[len(allUpdates)-1].pct, "last update should be 100%%")
1030+
for i := 1; i < len(allUpdates); i++ {
1031+
require.GreaterOrEqual(t, allUpdates[i].pct, allUpdates[i-1].pct, "progress went backwards: %d%% -> %d%%", allUpdates[i-1].pct, allUpdates[i].pct)
1032+
}
1033+
for i := 1; i < len(allUpdates); i++ {
1034+
require.GreaterOrEqual(t, allUpdates[i].uploaded, allUpdates[i-1].uploaded, "uploadedBytes went backwards: %d -> %d", allUpdates[i-1].uploaded, allUpdates[i].uploaded)
1035+
}
1036+
}
1037+
1038+
func TestUploadFileMultipart_SmoothProgress(t *testing.T) {
1039+
mock := NewMockTursoServer()
1040+
mock.chunkSize = 100 * 1024
1041+
defer mock.Close()
1042+
1043+
client := createTestClient(t, mock.URL)
1044+
testFile := createTestFile(t, 1024*1024)
1045+
progress := NewProgressRecorder()
1046+
1047+
err := client.UploadFileMultipart(testFile, "", "", progress.Callback())
1048+
require.NoError(t, err)
1049+
1050+
calls := progress.GetCalls()
1051+
1052+
require.Greater(t, len(calls), 15, "Expected smooth progress with many updates, got only %d", len(calls))
1053+
progress.VerifyProgressIncreasing(t)
1054+
require.True(t, calls[len(calls)-1].Done, "Final callback should have Done=true")
1055+
}

0 commit comments

Comments
 (0)