Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion go/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@ import (
"database/sql"
"errors"
"fmt"
"io"
"strings"
"sync/atomic"

"github.com/adbc-drivers/driverbase-go/driverbase"
sqlwrapper "github.com/adbc-drivers/driverbase-go/sqlwrapper"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
gomysql "github.com/go-sql-driver/mysql"
)

var loadReaderCounter atomic.Uint64

const (
// Default num of rows per batch for batched INSERT
MySQLDefaultIngestBatchSize = 1000
Expand Down Expand Up @@ -214,13 +219,17 @@ func (c *mysqlConnectionImpl) GetPlaceholder(field *arrow.Field, index int) stri
// Ensure mysqlConnectionImpl implements BulkIngester
var _ sqlwrapper.BulkIngester = (*mysqlConnectionImpl)(nil)

// ExecuteBulkIngest performs MySQL bulk ingest using batched INSERT statements.
// ExecuteBulkIngest performs MySQL bulk ingest using LOAD DATA LOCAL INFILE with a fallback to batched INSERTs.
func (c *mysqlConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) {
schema := stream.Schema()
if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options); err != nil {
return -1, c.ErrorHelper.WrapIO(err, "failed to create table")
}

if c.isLoadDataEnabled(ctx, conn) {
return c.executeLoadDataIngest(ctx, conn, options, stream)
}

// Validate MySQL-specific options
if options.MaxQuerySizeBytes > 0 {
return -1, c.ErrorHelper.InvalidArgument(
Expand All @@ -246,6 +255,73 @@ func (c *mysqlConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwr
)
}

// isLoadDataEnabled reports whether the connected MySQL server permits
// LOAD DATA LOCAL INFILE, i.e. whether its local_infile system variable
// is switched on. Any query error is treated as "not enabled".
func (c *mysqlConnectionImpl) isLoadDataEnabled(ctx context.Context, conn *sqlwrapper.LoggingConn) bool {
	var enabled int
	if err := conn.QueryRowContext(ctx, "SELECT @@local_infile").Scan(&enabled); err != nil {
		return false
	}
	return enabled == 1
}

// executeLoadDataIngest performs bulk ingestion using the LOAD DATA LOCAL INFILE command.
func (c *mysqlConnectionImpl) executeLoadDataIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (int64, error) {
r, w := io.Pipe()
readerId := loadReaderCounter.Add(1)
readerName := fmt.Sprintf("adbc_ingest_%s_%d", options.TableName, readerId)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a reader name injection/quoting risk. options.TableName is user-controlled and if the table name contains quotes or special characters, the query breaks. We can probably just use the counter without the user input (table name).


gomysql.RegisterReaderHandler(readerName, func() io.Reader {
return r
})
defer gomysql.DeregisterReaderHandler(readerName)
batchSize := options.IngestBatchSize
if batchSize <= 0 {
batchSize = 10000 // Default batch size for streaming chunks
}

it, err := sqlwrapper.NewRowBufferIterator(stream, batchSize, c.TypeConverter)
if err != nil {
return -1, c.ErrorHelper.WrapIO(err, "failed to create row buffer iterator")
}

numCols := len(stream.Schema().Fields())
go func() {
config := CSVConfig{
FieldDelimiter: '\t',
LineTerminator: '\n',
NullValue: "\\N",
EscapeBackslash: true,
}
err := arrowToCSV(ctx, w, it, numCols, config)
if err != nil {
_ = w.CloseWithError(err)
} else {
_ = w.Close()
}
}()

var colNames []string
for _, field := range stream.Schema().Fields() {
colNames = append(colNames, quoteIdentifier(field.Name))
}
colsList := strings.Join(colNames, ", ")

query := fmt.Sprintf(
"LOAD DATA LOCAL INFILE 'Reader::%s' INTO TABLE %s CHARACTER SET utf8mb4 FIELDS TERMINATED BY '\\t' ESCAPED BY '\\\\' LINES TERMINATED BY '\\n' (%s)",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHARACTER SET utf8mb4 makes MySQL treat the incoming stream as UTF‑8. If we ever ingest non‑UTF8 bytes (or binary-ish data), could this cause conversion issues? Do we need this here, or should we drop/make it optional?

readerName, c.QuoteIdentifier(options.TableName), colsList,
)

res, err := conn.ExecContext(ctx, query)
if err != nil {
return -1, c.ErrorHelper.WrapIO(err, "failed to execute LOAD DATA statement")
}
Comment on lines +312 to +315
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ExecContext fail, the functions returns but the CSV writer goroutine keeps running and will block forever when the pipe buffer fills which could result in resource leaks. We can add r.Close() before returning to unblock the writer, and wrap the goroutine with a cancellable context.


rowCount, err := res.RowsAffected()
if err != nil {
return -1, c.ErrorHelper.WrapIO(err, "failed to get rows affected")
}

return rowCount, nil
}

// createTableIfNeeded creates the table based on the ingest mode
func (c *mysqlConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions) error {
switch options.Mode {
Expand Down
107 changes: 107 additions & 0 deletions go/csv_helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) 2025 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"strings"

	"github.com/adbc-drivers/driverbase-go/sqlwrapper"
)

// CSVConfig defines the configuration for Arrow-to-CSV/TSV conversion.
type CSVConfig struct {
	// FieldDelimiter is the byte written between fields within a row (e.g. '\t').
	FieldDelimiter byte
	// LineTerminator is the byte written after each row (e.g. '\n').
	LineTerminator byte
	// NullValue is the literal emitted in place of nil values (e.g. "\\N" for MySQL LOAD DATA).
	NullValue string
	// EscapeBackslash enables backslash-escaping of backslashes and of the
	// control characters \b, \x1a, \x00 and \r in string values.
	EscapeBackslash bool
}

// arrowToCSV reads from a RowBufferIterator and streams data in CSV/TSV format into the provided io.Writer.
func arrowToCSV(ctx context.Context, w io.Writer, it *sqlwrapper.RowBufferIterator, numCols int, config CSVConfig) error {
var buf strings.Builder

for it.Next() {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

rows, rowCount := it.CurrentBatch()

buf.Reset()
for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
for rowIdx := range rowCount {

for colIdx := 0; colIdx < numCols; colIdx++ {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for colIdx := 0; colIdx < numCols; colIdx++ {
for colIdx := range numCols {

if colIdx > 0 {
buf.WriteByte(config.FieldDelimiter)
}

val := rows[rowIdx*numCols+colIdx]
buf.WriteString(formatValueForCSV(val, config))
}
buf.WriteByte(config.LineTerminator)
}
if _, err := io.WriteString(w, buf.String()); err != nil {
return fmt.Errorf("failed to write batch to pipe: %w", err)
}
}

return it.Err()
}

// escapeCSV escapes special characters in s according to config.
//
// When EscapeBackslash is set, the backslash itself is rewritten first so
// that the backslashes introduced for the other characters are not doubled
// up by a later pass. The field delimiter and line terminator are only
// escaped for the tab/newline dialect used by LOAD DATA.
func escapeCSV(s string, config CSVConfig) string {
	if config.EscapeBackslash {
		// Order matters: the backslash rule must run before the rules that
		// insert backslashes of their own.
		pairs := [][2]string{
			{"\\", "\\\\"},
			{"\b", "\\b"},
			{"\x1a", "\\Z"},
			{"\x00", "\\0"},
			// Always escape \r if we are escaping backslashes, as it's a
			// common special char.
			{"\r", "\\r"},
		}
		for _, p := range pairs {
			s = strings.ReplaceAll(s, p[0], p[1])
		}
	}

	if config.FieldDelimiter == '\t' {
		s = strings.ReplaceAll(s, "\t", "\\t")
	}
	if config.LineTerminator == '\n' {
		s = strings.ReplaceAll(s, "\n", "\\n")
	}

	return s
}

// formatValueForCSV converts a Go value to a string suitable for CSV/TSV
// output, applying escaping to textual values.
//
// nil maps to the configured null literal, strings and byte slices are
// escaped, booleans become "1"/"0", and everything else uses the default
// fmt formatting.
func formatValueForCSV(val any, config CSVConfig) string {
	switch v := val.(type) {
	case nil:
		return config.NullValue
	case string:
		return escapeCSV(v, config)
	case []byte:
		return escapeCSV(string(v), config)
	case bool:
		if v {
			return "1"
		}
		return "0"
	default:
		return fmt.Sprintf("%v", v)
	}
}
Loading
Loading