-
Notifications
You must be signed in to change notification settings - Fork 4
feat(go): use load data statement for bulk ingestion #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,15 +19,20 @@ import ( | |
| "database/sql" | ||
| "errors" | ||
| "fmt" | ||
| "io" | ||
| "strings" | ||
| "sync/atomic" | ||
|
|
||
| "github.com/adbc-drivers/driverbase-go/driverbase" | ||
| sqlwrapper "github.com/adbc-drivers/driverbase-go/sqlwrapper" | ||
| "github.com/apache/arrow-adbc/go/adbc" | ||
| "github.com/apache/arrow-go/v18/arrow" | ||
| "github.com/apache/arrow-go/v18/arrow/array" | ||
| gomysql "github.com/go-sql-driver/mysql" | ||
| ) | ||
|
|
||
| var loadReaderCounter atomic.Uint64 | ||
|
|
||
| const ( | ||
| // Default num of rows per batch for batched INSERT | ||
| MySQLDefaultIngestBatchSize = 1000 | ||
|
|
@@ -214,13 +219,17 @@ func (c *mysqlConnectionImpl) GetPlaceholder(field *arrow.Field, index int) stri | |
| // Ensure mysqlConnectionImpl implements BulkIngester | ||
| var _ sqlwrapper.BulkIngester = (*mysqlConnectionImpl)(nil) | ||
|
|
||
| // ExecuteBulkIngest performs MySQL bulk ingest using batched INSERT statements. | ||
| // ExecuteBulkIngest performs MySQL bulk ingest using LOAD DATA LOCAL INFILE with a fallback to batched INSERTs. | ||
| func (c *mysqlConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) { | ||
| schema := stream.Schema() | ||
| if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options); err != nil { | ||
| return -1, c.ErrorHelper.WrapIO(err, "failed to create table") | ||
| } | ||
|
|
||
| if c.isLoadDataEnabled(ctx, conn) { | ||
| return c.executeLoadDataIngest(ctx, conn, options, stream) | ||
| } | ||
|
|
||
| // Validate MySQL-specific options | ||
| if options.MaxQuerySizeBytes > 0 { | ||
| return -1, c.ErrorHelper.InvalidArgument( | ||
|
|
@@ -246,6 +255,73 @@ func (c *mysqlConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwr | |
| ) | ||
| } | ||
|
|
||
| // isLoadDataEnabled checks if LOAD DATA LOCAL INFILE is enabled on the server. | ||
| func (c *mysqlConnectionImpl) isLoadDataEnabled(ctx context.Context, conn *sqlwrapper.LoggingConn) bool { | ||
| var localInfile int | ||
| err := conn.QueryRowContext(ctx, "SELECT @@local_infile").Scan(&localInfile) | ||
| return err == nil && localInfile == 1 | ||
| } | ||
|
|
||
| // executeLoadDataIngest performs bulk ingestion using the LOAD DATA LOCAL INFILE command. | ||
| func (c *mysqlConnectionImpl) executeLoadDataIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (int64, error) { | ||
| r, w := io.Pipe() | ||
| readerId := loadReaderCounter.Add(1) | ||
| readerName := fmt.Sprintf("adbc_ingest_%s_%d", options.TableName, readerId) | ||
|
|
||
| gomysql.RegisterReaderHandler(readerName, func() io.Reader { | ||
| return r | ||
| }) | ||
| defer gomysql.DeregisterReaderHandler(readerName) | ||
| batchSize := options.IngestBatchSize | ||
| if batchSize <= 0 { | ||
| batchSize = 10000 // Default batch size for streaming chunks | ||
| } | ||
|
|
||
| it, err := sqlwrapper.NewRowBufferIterator(stream, batchSize, c.TypeConverter) | ||
| if err != nil { | ||
| return -1, c.ErrorHelper.WrapIO(err, "failed to create row buffer iterator") | ||
| } | ||
|
|
||
| numCols := len(stream.Schema().Fields()) | ||
| go func() { | ||
| config := CSVConfig{ | ||
| FieldDelimiter: '\t', | ||
| LineTerminator: '\n', | ||
| NullValue: "\\N", | ||
| EscapeBackslash: true, | ||
| } | ||
| err := arrowToCSV(ctx, w, it, numCols, config) | ||
| if err != nil { | ||
| _ = w.CloseWithError(err) | ||
| } else { | ||
| _ = w.Close() | ||
| } | ||
| }() | ||
|
|
||
| var colNames []string | ||
| for _, field := range stream.Schema().Fields() { | ||
| colNames = append(colNames, quoteIdentifier(field.Name)) | ||
| } | ||
| colsList := strings.Join(colNames, ", ") | ||
|
|
||
| query := fmt.Sprintf( | ||
| "LOAD DATA LOCAL INFILE 'Reader::%s' INTO TABLE %s CHARACTER SET utf8mb4 FIELDS TERMINATED BY '\\t' ESCAPED BY '\\\\' LINES TERMINATED BY '\\n' (%s)", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| readerName, c.QuoteIdentifier(options.TableName), colsList, | ||
| ) | ||
|
|
||
| res, err := conn.ExecContext(ctx, query) | ||
| if err != nil { | ||
| return -1, c.ErrorHelper.WrapIO(err, "failed to execute LOAD DATA statement") | ||
| } | ||
|
Comment on lines
+312
to
+315
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If ExecContext fails, the function returns, but the CSV writer goroutine keeps running and will block forever when the pipe buffer fills, which could result in resource leaks. We can add |
||
|
|
||
| rowCount, err := res.RowsAffected() | ||
| if err != nil { | ||
| return -1, c.ErrorHelper.WrapIO(err, "failed to get rows affected") | ||
| } | ||
|
|
||
| return rowCount, nil | ||
| } | ||
|
|
||
| // createTableIfNeeded creates the table based on the ingest mode | ||
| func (c *mysqlConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions) error { | ||
| switch options.Mode { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,107 @@ | ||||||
| // Copyright (c) 2025 ADBC Drivers Contributors | ||||||
| // | ||||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| // you may not use this file except in compliance with the License. | ||||||
| // You may obtain a copy of the License at | ||||||
| // | ||||||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| // | ||||||
| // Unless required by applicable law or agreed to in writing, software | ||||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| // See the License for the specific language governing permissions and | ||||||
| // limitations under the License. | ||||||
|
|
||||||
| package mysql | ||||||
|
|
||||||
| import ( | ||||||
| "context" | ||||||
| "fmt" | ||||||
| "io" | ||||||
| "strings" | ||||||
|
|
||||||
| "github.com/adbc-drivers/driverbase-go/sqlwrapper" | ||||||
| ) | ||||||
|
|
||||||
// CSVConfig defines the configuration for Arrow-to-CSV/TSV conversion.
type CSVConfig struct {
	// FieldDelimiter is the byte written between column values, e.g. '\t'.
	FieldDelimiter byte
	// LineTerminator is the byte written at the end of each row, e.g. '\n'.
	LineTerminator byte
	// NullValue is the literal emitted for a nil value (MySQL expects `\N`).
	NullValue string
	// EscapeBackslash enables backslash-escaping of special characters
	// (backslash, backspace, Ctrl-Z, NUL, carriage return).
	EscapeBackslash bool
}
|
|
||||||
| // arrowToCSV reads from a RowBufferIterator and streams data in CSV/TSV format into the provided io.Writer. | ||||||
| func arrowToCSV(ctx context.Context, w io.Writer, it *sqlwrapper.RowBufferIterator, numCols int, config CSVConfig) error { | ||||||
| var buf strings.Builder | ||||||
|
|
||||||
| for it.Next() { | ||||||
| select { | ||||||
| case <-ctx.Done(): | ||||||
| return ctx.Err() | ||||||
| default: | ||||||
| } | ||||||
|
|
||||||
| rows, rowCount := it.CurrentBatch() | ||||||
|
|
||||||
| buf.Reset() | ||||||
| for rowIdx := 0; rowIdx < rowCount; rowIdx++ { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| for colIdx := 0; colIdx < numCols; colIdx++ { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| if colIdx > 0 { | ||||||
| buf.WriteByte(config.FieldDelimiter) | ||||||
| } | ||||||
|
|
||||||
| val := rows[rowIdx*numCols+colIdx] | ||||||
| buf.WriteString(formatValueForCSV(val, config)) | ||||||
| } | ||||||
| buf.WriteByte(config.LineTerminator) | ||||||
| } | ||||||
| if _, err := io.WriteString(w, buf.String()); err != nil { | ||||||
| return fmt.Errorf("failed to write batch to pipe: %w", err) | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| return it.Err() | ||||||
| } | ||||||
|
|
||||||
| // escapeCSV escapes special characters based on the provided CSVConfig. | ||||||
| func escapeCSV(s string, config CSVConfig) string { | ||||||
| if config.EscapeBackslash { | ||||||
| s = strings.ReplaceAll(s, "\\", "\\\\") | ||||||
| s = strings.ReplaceAll(s, "\b", "\\b") | ||||||
| s = strings.ReplaceAll(s, "\x1a", "\\Z") | ||||||
| s = strings.ReplaceAll(s, "\x00", "\\0") | ||||||
| // Always escape \r if we are escaping backslashes, as it's a common special char | ||||||
| s = strings.ReplaceAll(s, "\r", "\\r") | ||||||
| } | ||||||
|
|
||||||
| if config.FieldDelimiter == '\t' { | ||||||
| s = strings.ReplaceAll(s, "\t", "\\t") | ||||||
| } | ||||||
| if config.LineTerminator == '\n' { | ||||||
| s = strings.ReplaceAll(s, "\n", "\\n") | ||||||
| } | ||||||
|
|
||||||
| return s | ||||||
| } | ||||||
|
|
||||||
| // formatValueForCSV converts a Go interface{} to a string suitable for CSV/TSV, handling escaping. | ||||||
| func formatValueForCSV(val any, config CSVConfig) string { | ||||||
| if val == nil { | ||||||
| return config.NullValue | ||||||
| } | ||||||
|
|
||||||
| switch v := val.(type) { | ||||||
| case string: | ||||||
| return escapeCSV(v, config) | ||||||
| case []byte: | ||||||
| return escapeCSV(string(v), config) | ||||||
| case bool: | ||||||
| if v { | ||||||
| return "1" | ||||||
| } | ||||||
| return "0" | ||||||
| default: | ||||||
| return fmt.Sprintf("%v", v) | ||||||
| } | ||||||
| } | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be a reader-name injection/quoting risk.
`options.TableName` is user-controlled, and if the table name contains quotes or special characters, the query breaks. We can probably just use the counter without the user input (the table name).