Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions compatibility/flightsql/go/flightsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,36 @@ func executeQueryAndVerify(cnxn adbc.Connection, query string, expectedResults [
record := rows.Record()
numRows := record.NumRows()

id := record.Column(0).(*array.Int64)
name := record.Column(1).(*array.String)
value := record.Column(2).(*array.Int64)
for i := 0; i < int(numRows); i++ {
var id, value int64
switch idCol := record.Column(0).(type) {
case *array.Int32:
id = int64(idCol.Value(i))
case *array.Int64:
id = idCol.Value(i)
default:
t.Fatalf("unexpected type for id column: %T", record.Column(0))
}

name := record.Column(1).(*array.String)

switch valueCol := record.Column(2).(type) {
case *array.Int32:
value = int64(valueCol.Value(i))
case *array.Int64:
value = valueCol.Value(i)
default:
t.Fatalf("unexpected type for value column: %T", record.Column(2))
}

actualResults = append(actualResults, struct {
id int64
name string
value int64
}{
id: id.Value(i),
id: id,
name: name.Value(i),
value: value.Value(i),
value: value,
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion compatibility/flightsql/python/flightsql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def setUp(self):
value INT
)
""") # Create the table
self.conn.commit()

def test_insert_and_select(self):
"""Test inserting data and selecting it back to verify correctness."""
Expand All @@ -55,7 +56,6 @@ def test_drop_table(self):
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='intTable'") # Check if the table exists
rows = cursor.fetchall()
self.assertEqual(len(rows), 0, "Table 'intTable' should be dropped and not exist in the database.")
cursor.execute("COMMIT;")

if __name__ == "__main__":
unittest.main()
113 changes: 43 additions & 70 deletions flightsqlserver/sqlite_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ import (
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
"github.com/marcboeker/go-duckdb"

"google.golang.org/grpc"
_ "github.com/marcboeker/go-duckdb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
_ "modernc.org/sqlite"
)
Expand Down Expand Up @@ -210,19 +208,21 @@ func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) {

type Statement struct {
stmt *sql.Stmt
query string
params [][]interface{}
}

type SQLiteFlightSQLServer struct {
flightsql.BaseServer
db *sql.DB
db *sql.DB
conn *duckdb.Conn

prepared sync.Map
openTransactions sync.Map
}

func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) {
ret := &SQLiteFlightSQLServer{db: db}
func NewSQLiteFlightSQLServer(db *sql.DB, conn *duckdb.Conn) (*SQLiteFlightSQLServer, error) {
ret := &SQLiteFlightSQLServer{db: db, conn: conn}
ret.Alloc = memory.DefaultAllocator
for k, v := range SqlInfoResultMap() {
ret.RegisterSqlInfo(flightsql.SqlInfo(k), v)
Expand Down Expand Up @@ -260,17 +260,20 @@ func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsq
if err != nil {
return nil, nil, err
}

var db dbQueryCtx = s.db
if txnid != "" {
tx, loaded := s.openTransactions.Load(txnid)
if !loaded {
return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid)
}
db = tx.(*sql.Tx)
return nil, nil, fmt.Errorf("transactions not yet supported with DuckDB")
}

return doGetQuery(ctx, s.Alloc, db, query, nil)
// var db dbQueryCtx = s.db
// if txnid != "" {
// tx, loaded := s.openTransactions.Load(txnid)
// if !loaded {
// return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid)
// }
// db = tx.(*sql.Tx)
// }

return doGetQuery(ctx, s.conn, query, nil)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand Down Expand Up @@ -345,14 +348,11 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd fligh
func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
query := prepareQueryForGetTables(cmd)

rows, err := s.db.QueryContext(ctx, query)
arrow, err := duckdb.NewArrowFromConn(s.conn)
if err != nil {
return nil, nil, err
}

var rdr array.RecordReader

rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema_ref.Tables, rows)
rdr, err := arrow.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -394,7 +394,7 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc

func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) {
query := "SELECT DISTINCT type AS table_type FROM sqlite_master"
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.TableTypes)
return doGetQuery(ctx, s.conn, query, schema_ref.TableTypes)
}

func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) {
Expand Down Expand Up @@ -422,6 +422,7 @@ func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context,

func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
var stmt *sql.Stmt
query := req.GetQuery()

if len(req.GetTransactionId()) > 0 {
tx, loaded := s.openTransactions.Load(string(req.GetTransactionId()))
Expand All @@ -438,7 +439,10 @@ func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req
}

handle := genRandomString()
s.prepared.Store(string(handle), Statement{stmt: stmt})
s.prepared.Store(string(handle), Statement{
stmt: stmt,
query: query,
})

result.Handle = handle
// no way to get the dataset or parameter schemas from sql.DB
Expand Down Expand Up @@ -473,29 +477,17 @@ type dbQueryCtx interface {
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
}

func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) {
rows, err := db.QueryContext(ctx, query, args...)
func doGetQuery(ctx context.Context, conn *duckdb.Conn, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) {

arrow, err := duckdb.NewArrowFromConn(conn)
if err != nil {
// Not really useful except for testing Flight SQL clients
trailers := metadata.Pairs("afsql-sqlite-query", query)
grpc.SetTrailer(ctx, trailers)
return nil, nil, err
}

var rdr *SqlBatchReader
if schema != nil {
rdr, err = NewSqlBatchReaderWithSchema(mem, schema, rows)
} else {
rdr, err = NewSqlBatchReader(mem, rows)
if err == nil {
schema = rdr.schema
}
}

rdr, err := arrow.QueryContext(ctx, query, args...)
if err != nil {
return nil, nil, err
}

schema = rdr.Schema()
ch := make(chan flight.StreamChunk)
go flight.StreamChunksFromReader(rdr, ch)
return schema, ch, nil
Expand All @@ -508,19 +500,18 @@ func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd
}

stmt := val.(Statement)
arrow, err := duckdb.NewArrowFromConn(s.conn)
if err != nil {
return nil, nil, err
}

readers := make([]array.RecordReader, 0, len(stmt.params))
if len(stmt.params) == 0 {
rows, err := stmt.stmt.QueryContext(ctx)
if err != nil {
return nil, nil, err
}

rdr, err := NewSqlBatchReader(s.Alloc, rows)
rdr, err := arrow.QueryContext(ctx, stmt.query)
if err != nil {
return nil, nil, err
}

schema = rdr.schema
schema = rdr.Schema()
readers = append(readers, rdr)
} else {
defer func() {
Expand All @@ -530,35 +521,17 @@ func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd
}
}
}()
var (
rows *sql.Rows
rdr *SqlBatchReader
)
// if we have multiple rows of bound params, execute the query
// multiple times and concatenate the result sets.
for _, p := range stmt.params {
rows, err = stmt.stmt.QueryContext(ctx, p...)
rdr, err := arrow.QueryContext(ctx, stmt.query, p...)
if err != nil {
return nil, nil, err
}

if schema == nil {
rdr, err = NewSqlBatchReader(s.Alloc, rows)
if err != nil {
return nil, nil, err
}
schema = rdr.schema
} else {
rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows)
if err != nil {
return nil, nil, err
}
}

schema = rdr.Schema()
readers = append(readers, rdr)
}
}

ch := make(chan flight.StreamChunk)
go flight.ConcatenateReaders(readers, ch)
out = ch
Expand Down Expand Up @@ -715,7 +688,7 @@ func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flight

fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table)

return doGetQuery(ctx, s.Alloc, s.db, b.String(), schema_ref.PrimaryKeys)
return doGetQuery(ctx, s.conn, b.String(), schema_ref.PrimaryKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -731,7 +704,7 @@ func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref fligh
filter += " AND fk_schema_name = '" + *ref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ImportedKeys)
return doGetQuery(ctx, s.conn, query, schema_ref.ImportedKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -747,7 +720,7 @@ func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref fligh
filter += " AND pk_schema_name = '" + *ref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys)
return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -773,7 +746,7 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd fli
filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys)
return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys)
}

func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) {
Expand Down
13 changes: 12 additions & 1 deletion flightsqltest/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"math/rand"
"os"
"strings"
Expand All @@ -32,6 +33,7 @@ import (
"time"

"github.com/apecloud/myduckserver/flightsqlserver"
"github.com/marcboeker/go-duckdb"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

Expand Down Expand Up @@ -96,7 +98,16 @@ func (s *SqlTestSuite) SetupSuite() {
if err != nil {
return nil, "", err
}
sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage())
conn, err := provider.Connector().Connect(context.Background())
if err != nil {
log.Fatal(err)
}

duckConn, ok := conn.(*duckdb.Conn)
if !ok {
log.Fatal("Failed to get DuckDB connection")
}
sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage(), duckConn)
if err != nil {
return nil, "", err
}
Expand Down
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
"github.com/marcboeker/go-duckdb"
_ "github.com/marcboeker/go-duckdb"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -192,12 +193,18 @@ func main() {
if flightsqlPort > 0 {

db := provider.Storage()
conn, err := provider.Connector().Connect(context.Background())
if err != nil {
log.Fatal(err)
}
defer db.Close()

srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db)
duckConn, ok := conn.(*duckdb.Conn)
if !ok {
log.Fatal("Failed to get DuckDB connection")
}

srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db, duckConn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this connection serves all flightsql queries. While this is a valid approach, it is not good for concurrency. You can get a new duckdb.Conn whenever needed using the following method:

conn, err := s.db.Conn(context.Background())
// error handling...
var duckConn *duckdb.Conn
err = conn.Raw(func (driverConn any) error {
    duckConn = driverConn.(*duckdb.Conn)
    return nil
})
...

This can be defined as a helper method.

if err != nil {
log.Fatal(err)
}
Expand Down
Loading