-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbatch.go
More file actions
65 lines (56 loc) · 2.01 KB
/
batch.go
File metadata and controls
65 lines (56 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package sequel
import (
	"context"
	"database/sql"
	"fmt"
	"slices"
	"strconv"
	"strings"
)
// Executor is the minimal statement-execution capability that [Batch] needs:
// a single ExecContext method with the same signature as
// [sql.DB.ExecContext]. Both [DB] and [Tx] satisfy it, so callers can pass
// either one.
type Executor interface {
	ExecContext(ctx context.Context, stmt string, params ...any) (sql.Result, error)
}
const (
	// BatchSize is the largest number of items that [Batch] places into a
	// single multi-row INSERT statement.
	BatchSize = 100
)
// Batch inserts a slice of items into the given table using multi-row INSERT
// statements, executing one statement per chunk of at most [BatchSize] items.
// The columns parameter specifies the column names, and the extractValues
// function maps each item to its column values. The slice returned by
// extractValues must have exactly len(columns) elements; if it does not, Batch
// returns an error before executing the chunk containing that item. onConflict
// is appended verbatim after the VALUES list (pass "" for none). Batch does
// nothing if items is empty.
//
// Table, columns and onConflict are not sanitized; they must come from a
// trusted source. The extractValues function is called once per item, in item
// order, and never concurrently.
//
// Chunks are executed one after another and Batch opens no transaction of its
// own: when it returns an error, the chunks before the failing one have already
// been executed. Run it inside a transaction if the insert must be
// all-or-nothing.
func Batch[T any](ctx context.Context, exec Executor, table string, columns []string, onConflict string, items []T, extractValues func(T) []any) error {
	// Record the first item whose value count disagrees with columns. Such a
	// row would otherwise produce a statement whose VALUES tuples and
	// placeholders don't line up with the column list, surfacing only as a
	// far less specific database error.
	var (
		item     int
		widthErr error
	)
	extract := func(v T) []any {
		vals := extractValues(v)
		if widthErr == nil && len(vals) != len(columns) {
			widthErr = fmt.Errorf("item %d: extractValues returned %d values, want %d (one per column)", item, len(vals), len(columns))
		}
		item++
		return vals
	}

	batch := 0
	for chunk := range slices.Chunk(items, BatchSize) {
		query, args := batchQuery(table, columns, onConflict, chunk, extract)
		if widthErr != nil {
			return fmt.Errorf("batch %d (%d items) failed: %w", batch, len(chunk), widthErr)
		}
		if _, err := exec.ExecContext(ctx, query, args...); err != nil {
			return fmt.Errorf("batch %d (%d items) failed: %w", batch, len(chunk), err)
		}
		batch++
	}
	return nil
}
// batchQuery builds one multi-row INSERT statement for items and returns it
// together with its arguments in placeholder order. The statement has the form
//
//	INSERT INTO table (c1, c2, ...) VALUES ($1, $2, ...), (...), ... onConflict
//
// with one parenthesized tuple per item holding one $N placeholder per value
// returned by extractValues. onConflict is appended verbatim after a single
// space, even when it is empty. extractValues is called once per item, in
// order. Nothing is escaped: table, columns and onConflict must be trusted.
func batchQuery[T any](table string, columns []string, onConflict string, items []T, extractValues func(T) []any) (string, []any) {
	args := make([]any, 0, len(items)*len(columns))

	var b strings.Builder
	b.WriteString("INSERT INTO ")
	b.WriteString(table)
	b.WriteString(" (")
	b.WriteString(strings.Join(columns, ", "))
	b.WriteString(") VALUES ")
	for i, item := range items {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteByte('(')
		for j, v := range extractValues(item) {
			if j > 0 {
				b.WriteString(", ")
			}
			args = append(args, v)
			// Number each placeholder by its position in args rather than
			// from the row/column index, so placeholders and arguments stay
			// aligned even when a row's width differs from len(columns).
			b.WriteByte('$')
			b.WriteString(strconv.Itoa(len(args)))
		}
		b.WriteByte(')')
	}
	b.WriteByte(' ')
	b.WriteString(onConflict)
	return b.String(), args
}