From 33a28fe631ca813cf4849827363060c8ea47524f Mon Sep 17 00:00:00 2001 From: d h Date: Fri, 27 Dec 2024 20:38:33 +0800 Subject: [PATCH 1/6] Migrate to DuckDB Arrow for Query Execution --- compatibility/flightsql/go/flightsql_test.go | 28 ++++- .../flightsql_test.cpython-313.pyc | Bin 0 -> 4276 bytes flightsqlserver/sqlite_server.go | 113 +++++++----------- flightsqltest/driver_test.go | 13 +- main.go | 9 +- 5 files changed, 86 insertions(+), 77 deletions(-) create mode 100644 compatibility/flightsql/python/__pycache__/flightsql_test.cpython-313.pyc diff --git a/compatibility/flightsql/go/flightsql_test.go b/compatibility/flightsql/go/flightsql_test.go index 2171a4c..e41a145 100644 --- a/compatibility/flightsql/go/flightsql_test.go +++ b/compatibility/flightsql/go/flightsql_test.go @@ -87,18 +87,36 @@ func executeQueryAndVerify(cnxn adbc.Connection, query string, expectedResults [ record := rows.Record() numRows := record.NumRows() - id := record.Column(0).(*array.Int64) - name := record.Column(1).(*array.String) - value := record.Column(2).(*array.Int64) for i := 0; i < int(numRows); i++ { + var id, value int64 + switch idCol := record.Column(0).(type) { + case *array.Int32: + id = int64(idCol.Value(i)) + case *array.Int64: + id = idCol.Value(i) + default: + t.Fatalf("unexpected type for id column: %T", record.Column(0)) + } + + name := record.Column(1).(*array.String) + + switch valueCol := record.Column(2).(type) { + case *array.Int32: + value = int64(valueCol.Value(i)) + case *array.Int64: + value = valueCol.Value(i) + default: + t.Fatalf("unexpected type for value column: %T", record.Column(2)) + } + actualResults = append(actualResults, struct { id int64 name string value int64 }{ - id: id.Value(i), + id: id, name: name.Value(i), - value: value.Value(i), + value: value, }) } } diff --git a/compatibility/flightsql/python/__pycache__/flightsql_test.cpython-313.pyc b/compatibility/flightsql/python/__pycache__/flightsql_test.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8994384ef02774bcbf9a8be6f007a24215c5c213 GIT binary patch literal 4276 zcmd55G0p{oc)gKm~WlNbCS(G1qG{5_i3NLL~Aaw4;dO9Rtvn_yJ~)5u0A0soqv6y2@VV z9Bu$P?D$cHim*rF$d;WN*rm}p37uY#u*3Op? zJKzbG*&IqlO0#0oa)!*2RHM!NmO@m$Pk;%POSQl_v#tD48^ERWTQB#~nI%{@nQik9SE?2Ulp)vhvN_ z3duX^M1;_LCDaUFTV_g4pok!c-iEf1H(yeF?ri2H=Tl_>Pe-b ztEOGlXhY5$ZgOcwE2=j^a9Jlw)m|znu8}oN!(B<{e1CFonaRS6yJ%bW0h!)fe__RI zhAWyYjF^h!c(|bt_SYpMa5CpMRG*-YFxtKU!LgtA---OS^T5;C(Z9F%thXmu+mq|< zgRAX>kCrz$Zg1D-S0`98&A`1Dm1XL5y4AqRz#NYLMjuNjSkeLe zdLdsf!4L>e$f0<}q>E0(U0org_HC0Nexd4cw!6-*CwQl1nS+4k~Q_xADSr+x0Vr6 zNM(fjiu6oo=PG>SwN+Guq2Y;iM)*Jw@vJy8oe~%Dq_EJ}P}WkGbbK`>j*O?o6N3YZ z#!`Nup-NxVk=QnpPH+_u468-r*u;wjfil5*yrY{rg24n1H46@0GUQ|4JJ{G2{mFfF z&HGS&f}X`Y?k(J1SdSlCjUQTz_k8~Cr?Fl4(s$G4gF|bv_kzv_3U{}^MB#YHs}O3% z>+zmH$9ukUT0q6aBdG}YyGV0tPh=agDIxiudlG!(f9bhz@NKT{+aN)sxp(oNGa+1rOt4-6zq!hfja4CqK%HqbQ*)SRJex+Gyr7|Gz1cubs` zt{o=9JqNcyUbMRHa}ByMNBrKu|VGS_2zU}T?7vBT7VtlB86KJaS}W^470wR$5TySTn`k>v|HovN_6jX;Zo#^ zC1kLGT(N-HM^F7)!q?we`Y7z^d6}EPh!2C|mweEuYcW zV*0DSb--@M;1~~6c$nvY&o>Y6+K$*7*ZdzK_N{nqjKk|R8u?|e##_KyTk!aXND!(= z0x3+NEr<=j1+n1(v60d^^#`mGyHKd|6#R~ukXY2nji#}cF%>8-3pp)-nfkL+R8;_N zu4tMo{iFvzpdmDQ zg;FWyvmvezV`tGWnwozEdSEyLjV*v&NRJFZ@N6gAQe_}L1W z8b*F8eZKqQ!mpX~x2(089V~uuVJ&v-RZATvDRS-iqw=eC zu;E<4mQ@Od|3_7c)&A5+HU{= literal 0 HcmV?d00001 diff --git a/flightsqlserver/sqlite_server.go b/flightsqlserver/sqlite_server.go index 4669b1a..1a74f05 100644 --- a/flightsqlserver/sqlite_server.go +++ b/flightsqlserver/sqlite_server.go @@ -55,10 +55,8 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/apache/arrow-go/v18/arrow/scalar" "github.com/marcboeker/go-duckdb" - - "google.golang.org/grpc" + _ "github.com/marcboeker/go-duckdb" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" _ "modernc.org/sqlite" ) @@ -210,19 +208,21 @@ func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) { type Statement struct { stmt *sql.Stmt + query string params [][]interface{} } type SQLiteFlightSQLServer struct { flightsql.BaseServer - db *sql.DB + db *sql.DB + conn *duckdb.Conn prepared sync.Map openTransactions sync.Map } -func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) { - ret := &SQLiteFlightSQLServer{db: db} +func NewSQLiteFlightSQLServer(db *sql.DB, conn *duckdb.Conn) (*SQLiteFlightSQLServer, error) { + ret := &SQLiteFlightSQLServer{db: db, conn: conn} ret.Alloc = memory.DefaultAllocator for k, v := range SqlInfoResultMap() { ret.RegisterSqlInfo(flightsql.SqlInfo(k), v) @@ -260,17 +260,20 @@ func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsq if err != nil { return nil, nil, err } - - var db dbQueryCtx = s.db if txnid != "" { - tx, loaded := s.openTransactions.Load(txnid) - if !loaded { - return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid) - } - db = tx.(*sql.Tx) + return nil, nil, fmt.Errorf("transactions not yet supported with DuckDB") } - return doGetQuery(ctx, s.Alloc, db, query, nil) + // var db dbQueryCtx = s.db + // if txnid != "" { + // tx, loaded := s.openTransactions.Load(txnid) + // if !loaded { + // return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid) + // } + // db = tx.(*sql.Tx) + // } + + return doGetQuery(ctx, s.conn, query, nil) } func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -345,14 +348,11 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd fligh func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { query := prepareQueryForGetTables(cmd) - rows, err := s.db.QueryContext(ctx, query) + arrow, err := duckdb.NewArrowFromConn(s.conn) if err != nil { return nil, nil, err } - - var rdr array.RecordReader - - rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema_ref.Tables, rows) + rdr, err := arrow.QueryContext(ctx, query) if err != nil { return nil, nil, err } @@ -394,7 +394,7 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { query := "SELECT DISTINCT type AS table_type FROM sqlite_master" - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.TableTypes) + return doGetQuery(ctx, s.conn, query, schema_ref.TableTypes) } func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) { @@ -422,6 +422,7 @@ func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { var stmt *sql.Stmt + query := req.GetQuery() if len(req.GetTransactionId()) > 0 { tx, loaded := s.openTransactions.Load(string(req.GetTransactionId())) @@ -438,7 +439,10 @@ func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req } handle := genRandomString() - s.prepared.Store(string(handle), Statement{stmt: stmt}) + s.prepared.Store(string(handle), Statement{ + stmt: stmt, + query: query, + }) result.Handle = handle // no way to get the dataset or parameter schemas from sql.DB @@ -473,29 +477,17 @@ type dbQueryCtx interface { QueryContext(context.Context, string, ...any) (*sql.Rows, error) } -func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { - rows, err := db.QueryContext(ctx, query, args...) +func doGetQuery(ctx context.Context, conn *duckdb.Conn, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { + + arrow, err := duckdb.NewArrowFromConn(conn) if err != nil { - // Not really useful except for testing Flight SQL clients - trailers := metadata.Pairs("afsql-sqlite-query", query) - grpc.SetTrailer(ctx, trailers) return nil, nil, err } - - var rdr *SqlBatchReader - if schema != nil { - rdr, err = NewSqlBatchReaderWithSchema(mem, schema, rows) - } else { - rdr, err = NewSqlBatchReader(mem, rows) - if err == nil { - schema = rdr.schema - } - } - + rdr, err := arrow.QueryContext(ctx, query, args...) if err != nil { return nil, nil, err } - + schema = rdr.Schema() ch := make(chan flight.StreamChunk) go flight.StreamChunksFromReader(rdr, ch) return schema, ch, nil @@ -508,19 +500,18 @@ func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd } stmt := val.(Statement) + arrow, err := duckdb.NewArrowFromConn(s.conn) + if err != nil { + return nil, nil, err + } + readers := make([]array.RecordReader, 0, len(stmt.params)) if len(stmt.params) == 0 { - rows, err := stmt.stmt.QueryContext(ctx) - if err != nil { - return nil, nil, err - } - - rdr, err := NewSqlBatchReader(s.Alloc, rows) + rdr, err := arrow.QueryContext(ctx, stmt.query) if err != nil { return nil, nil, err } - - schema = rdr.schema + schema = rdr.Schema() readers = append(readers, rdr) } else { defer func() { @@ -530,35 +521,17 @@ func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd } } }() - var ( - rows *sql.Rows - rdr *SqlBatchReader - ) // if we have multiple rows of bound params, execute the query // multiple times and concatenate the result sets. for _, p := range stmt.params { - rows, err = stmt.stmt.QueryContext(ctx, p...) + rdr, err := arrow.QueryContext(ctx, stmt.query, p...) if err != nil { return nil, nil, err } - - if schema == nil { - rdr, err = NewSqlBatchReader(s.Alloc, rows) - if err != nil { - return nil, nil, err - } - schema = rdr.schema - } else { - rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows) - if err != nil { - return nil, nil, err - } - } - + schema = rdr.Schema() readers = append(readers, rdr) } } - ch := make(chan flight.StreamChunk) go flight.ConcatenateReaders(readers, ch) out = ch @@ -715,7 +688,7 @@ func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flight fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table) - return doGetQuery(ctx, s.Alloc, s.db, b.String(), schema_ref.PrimaryKeys) + return doGetQuery(ctx, s.conn, b.String(), schema_ref.PrimaryKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -731,7 +704,7 @@ func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref fligh filter += " AND fk_schema_name = '" + *ref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ImportedKeys) + return doGetQuery(ctx, s.conn, query, schema_ref.ImportedKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -747,7 +720,7 @@ func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref fligh filter += " AND pk_schema_name = '" + *ref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys) + return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -773,7 +746,7 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd fli filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys) + return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys) } func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) { diff --git a/flightsqltest/driver_test.go b/flightsqltest/driver_test.go index 8a79504..280ed1c 100644 --- a/flightsqltest/driver_test.go +++ b/flightsqltest/driver_test.go @@ -24,6 +24,7 @@ import ( "database/sql" "errors" "fmt" + "log" "math/rand" "os" "strings" @@ -32,6 +33,7 @@ import ( "time" "github.com/apecloud/myduckserver/flightsqlserver" + "github.com/marcboeker/go-duckdb" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -96,7 +98,16 @@ func (s *SqlTestSuite) SetupSuite() { if err != nil { return nil, "", err } - sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage()) + conn, err := provider.Connector().Connect(context.Background()) + if err != nil { + log.Fatal(err) + } + + duckConn, ok := conn.(*duckdb.Conn) + if !ok { + log.Fatal("Failed to get DuckDB connection") + } + sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage(), duckConn) if err != nil { return nil, "", err } diff --git a/main.go b/main.go index 6b6fec3..685d3cc 100644 --- a/main.go +++ b/main.go @@ -40,6 +40,7 @@ import ( "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" + "github.com/marcboeker/go-duckdb" _ "github.com/marcboeker/go-duckdb" "github.com/sirupsen/logrus" ) @@ -192,12 +193,18 @@ func main() { if flightsqlPort > 0 { db := provider.Storage() + conn, err := provider.Connector().Connect(context.Background()) if err != nil { log.Fatal(err) } defer db.Close() - srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db) + duckConn, ok := conn.(*duckdb.Conn) + if !ok { + log.Fatal("Failed to get DuckDB connection") + } + + srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db, duckConn) if err != nil { log.Fatal(err) } From 4e9278eb18498e4c283ca83a21b7432066d18246 Mon Sep 17 00:00:00 2001 From: d h Date: Fri, 27 Dec 2024 21:16:14 +0800 Subject: [PATCH 2/6] feat:Migrate to DuckDB Arrow for Query Execution --- .../__pycache__/flightsql_test.cpython-313.pyc | Bin 4276 -> 0 bytes compatibility/flightsql/python/flightsql_test.py | 1 - 2 files changed, 1 deletion(-) delete mode 100644 compatibility/flightsql/python/__pycache__/flightsql_test.cpython-313.pyc diff --git a/compatibility/flightsql/python/__pycache__/flightsql_test.cpython-313.pyc b/compatibility/flightsql/python/__pycache__/flightsql_test.cpython-313.pyc deleted file mode 100644 index 8994384ef02774bcbf9a8be6f007a24215c5c213..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4276 zcmd55G0p{oc)gKm~WlNbCS(G1qG{5_i3NLL~Aaw4;dO9Rtvn_yJ~)5u0A0soqv6y2@VV z9Bu$P?D$cHim*rF$d;WN*rm}p37uY#u*3Op? zJKzbG*&IqlO0#0oa)!*2RHM!NmO@m$Pk;%POSQl_v#tD48^ERWTQB#~nI%{@nQik9SE?2Ulp)vhvN_ z3duX^M1;_LCDaUFTV_g4pok!c-iEf1H(yeF?ri2H=Tl_>Pe-b ztEOGlXhY5$ZgOcwE2=j^a9Jlw)m|znu8}oN!(B<{e1CFonaRS6yJ%bW0h!)fe__RI zhAWyYjF^h!c(|bt_SYpMa5CpMRG*-YFxtKU!LgtA---OS^T5;C(Z9F%thXmu+mq|< zgRAX>kCrz$Zg1D-S0`98&A`1Dm1XL5y4AqRz#NYLMjuNjSkeLe zdLdsf!4L>e$f0<}q>E0(U0org_HC0Nexd4cw!6-*CwQl1nS+4k~Q_xADSr+x0Vr6 zNM(fjiu6oo=PG>SwN+Guq2Y;iM)*Jw@vJy8oe~%Dq_EJ}P}WkGbbK`>j*O?o6N3YZ z#!`Nup-NxVk=QnpPH+_u468-r*u;wjfil5*yrY{rg24n1H46@0GUQ|4JJ{G2{mFfF z&HGS&f}X`Y?k(J1SdSlCjUQTz_k8~Cr?Fl4(s$G4gF|bv_kzv_3U{}^MB#YHs}O3% z>+zmH$9ukUT0q6aBdG}YyGV0tPh=agDIxiudlG!(f9bhz@NKT{+aN)sxp(oNGa+1rOt4-6zq!hfja4CqK%HqbQ*)SRJex+Gyr7|Gz1cubs` zt{o=9JqNcyUbMRHa}ByMNBrKu|VGS_2zU}T?7vBT7VtlB86KJaS}W^470wR$5TySTn`k>v|HovN_6jX;Zo#^ zC1kLGT(N-HM^F7)!q?we`Y7z^d6}EPh!2C|mweEuYcW zV*0DSb--@M;1~~6c$nvY&o>Y6+K$*7*ZdzK_N{nqjKk|R8u?|e##_KyTk!aXND!(= z0x3+NEr<=j1+n1(v60d^^#`mGyHKd|6#R~ukXY2nji#}cF%>8-3pp)-nfkL+R8;_N zu4tMo{iFvzpdmDQ zg;FWyvmvezV`tGWnwozEdSEyLjV*v&NRJFZ@N6gAQe_}L1W z8b*F8eZKqQ!mpX~x2(089V~uuVJ&v-RZATvDRS-iqw=eC zu;E<4mQ@Od|3_7c)&A5+HU{= diff --git a/compatibility/flightsql/python/flightsql_test.py b/compatibility/flightsql/python/flightsql_test.py index 4d1c7d0..1eccffe 100644 --- a/compatibility/flightsql/python/flightsql_test.py +++ b/compatibility/flightsql/python/flightsql_test.py @@ -55,7 +55,6 @@ def test_drop_table(self): cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='intTable'") # Check if the table exists rows = cursor.fetchall() self.assertEqual(len(rows), 0, "Table 'intTable' should be dropped and not exist in the database.") - cursor.execute("COMMIT;") if __name__ == "__main__": unittest.main() \ No newline at end of file From 1871283682f6ea798626e706914a4b33c823512b Mon Sep 17 00:00:00 2001 From: d h Date: Fri, 27 Dec 2024 21:29:16 +0800 Subject: [PATCH 3/6] feat:Migrate to DuckDB Arrow for Query Execution --- compatibility/flightsql/python/flightsql_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compatibility/flightsql/python/flightsql_test.py b/compatibility/flightsql/python/flightsql_test.py index 1eccffe..47d3406 100644 --- a/compatibility/flightsql/python/flightsql_test.py +++ b/compatibility/flightsql/python/flightsql_test.py @@ -37,6 +37,7 @@ def test_insert_and_select(self): """Test inserting data and selecting it back to verify correctness.""" with self.conn.cursor() as cursor: # Insert sample data + cursor.execute("SET search_path TO myduck") cursor.execute("INSERT INTO intTable (id, name, value) VALUES (1, 'TestName', 100)") cursor.execute("INSERT INTO intTable (id, name, value) VALUES (2, 'AnotherName', 200)") From 8374d985e4da1104ab6cb703ed4bfbb1bb64ea6f Mon Sep 17 00:00:00 2001 From: d h Date: Fri, 27 Dec 2024 21:33:47 +0800 Subject: [PATCH 4/6] feat:Migrate to DuckDB Arrow for Query Execution --- compatibility/flightsql/python/flightsql_test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/compatibility/flightsql/python/flightsql_test.py b/compatibility/flightsql/python/flightsql_test.py index 47d3406..3860840 100644 --- a/compatibility/flightsql/python/flightsql_test.py +++ b/compatibility/flightsql/python/flightsql_test.py @@ -37,12 +37,11 @@ def test_insert_and_select(self): """Test inserting data and selecting it back to verify correctness.""" with self.conn.cursor() as cursor: # Insert sample data - cursor.execute("SET search_path TO myduck") - cursor.execute("INSERT INTO intTable (id, name, value) VALUES (1, 'TestName', 100)") - cursor.execute("INSERT INTO intTable (id, name, value) VALUES (2, 'AnotherName', 200)") + cursor.execute("INSERT INTO myduck.intTable (id, name, value) VALUES (1, 'TestName', 100)") + cursor.execute("INSERT INTO myduck.intTable (id, name, value) VALUES (2, 'AnotherName', 200)") # Select data from the table - cursor.execute("SELECT * FROM intTable") + cursor.execute("SELECT * FROM myduck.intTable") rows = cursor.fetchall() # Expected result after insertions From cd9c772c41cd9ee85bf4ea9ac9d81c18e992dafa Mon Sep 17 00:00:00 2001 From: d h Date: Fri, 27 Dec 2024 21:55:23 +0800 Subject: [PATCH 5/6] feat:Migrate to DuckDB Arrow for Query Execution --- compatibility/flightsql/python/flightsql_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/compatibility/flightsql/python/flightsql_test.py b/compatibility/flightsql/python/flightsql_test.py index 3860840..576f486 100644 --- a/compatibility/flightsql/python/flightsql_test.py +++ b/compatibility/flightsql/python/flightsql_test.py @@ -32,16 +32,17 @@ def setUp(self): value INT ) """) # Create the table + self.conn.commit() def test_insert_and_select(self): """Test inserting data and selecting it back to verify correctness.""" with self.conn.cursor() as cursor: # Insert sample data - cursor.execute("INSERT INTO myduck.intTable (id, name, value) VALUES (1, 'TestName', 100)") - cursor.execute("INSERT INTO myduck.intTable (id, name, value) VALUES (2, 'AnotherName', 200)") + cursor.execute("INSERT INTO intTable (id, name, value) VALUES (1, 'TestName', 100)") + cursor.execute("INSERT INTO intTable (id, name, value) VALUES (2, 'AnotherName', 200)") # Select data from the table - cursor.execute("SELECT * FROM myduck.intTable") + cursor.execute("SELECT * FROM intTable") rows = cursor.fetchall() # Expected result after insertions From 4f76af4ead4095b66f116168e7ee746f630de3b5 Mon Sep 17 00:00:00 2001 From: d h Date: Fri, 3 Jan 2025 16:45:06 +0800 Subject: [PATCH 6/6] feat:Migrate to DuckDB Arrow for Query Execution --- flightsqlserver/sqlite_server.go | 50 ++++++++++++++++++++++++-------- flightsqltest/driver_test.go | 12 +------- main.go | 9 +----- 3 files changed, 40 insertions(+), 31 deletions(-) diff --git a/flightsqlserver/sqlite_server.go b/flightsqlserver/sqlite_server.go index 1a74f05..927dcd1 100644 --- a/flightsqlserver/sqlite_server.go +++ b/flightsqlserver/sqlite_server.go @@ -221,8 +221,8 @@ type SQLiteFlightSQLServer struct { openTransactions sync.Map } -func NewSQLiteFlightSQLServer(db *sql.DB, conn *duckdb.Conn) (*SQLiteFlightSQLServer, error) { - ret := &SQLiteFlightSQLServer{db: db, conn: conn} +func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) { + ret := &SQLiteFlightSQLServer{db: db} ret.Alloc = memory.DefaultAllocator for k, v := range SqlInfoResultMap() { ret.RegisterSqlInfo(flightsql.SqlInfo(k), v) @@ -273,7 +273,7 @@ func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsq // db = tx.(*sql.Tx) // } - return doGetQuery(ctx, s.conn, query, nil) + return doGetQuery(ctx, s.db, query, nil) } func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -348,7 +348,16 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd fligh func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { query := prepareQueryForGetTables(cmd) - arrow, err := duckdb.NewArrowFromConn(s.conn) + conn, err := s.db.Conn(ctx) + var duckConn *duckdb.Conn + err = conn.Raw(func(driverConn any) error { + duckConn = driverConn.(*duckdb.Conn) + return nil + }) + if err != nil { + return nil, nil, err + } + arrow, err := duckdb.NewArrowFromConn(duckConn) if err != nil { return nil, nil, err } @@ -394,7 +403,7 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { query := "SELECT DISTINCT type AS table_type FROM sqlite_master" - return doGetQuery(ctx, s.conn, query, schema_ref.TableTypes) + return doGetQuery(ctx, s.db, query, schema_ref.TableTypes) } func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) { @@ -477,9 +486,15 @@ type dbQueryCtx interface { QueryContext(context.Context, string, ...any) (*sql.Rows, error) } -func doGetQuery(ctx context.Context, conn *duckdb.Conn, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { +func doGetQuery(ctx context.Context, db *sql.DB, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { - arrow, err := duckdb.NewArrowFromConn(conn) + conn, err := db.Conn(ctx) + var duckConn *duckdb.Conn + err = conn.Raw(func(driverConn any) error { + duckConn = driverConn.(*duckdb.Conn) + return nil + }) + arrow, err := duckdb.NewArrowFromConn(duckConn) if err != nil { return nil, nil, err } @@ -495,12 +510,23 @@ func doGetQuery(ctx context.Context, conn *duckdb.Conn, query string, schema *ar func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle())) + if !ok { return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found") } + conn, err := s.db.Conn(ctx) + var duckConn *duckdb.Conn + err = conn.Raw(func(driverConn any) error { + duckConn = driverConn.(*duckdb.Conn) + return nil + }) + if err != nil { + return nil, nil, err + } + stmt := val.(Statement) - arrow, err := duckdb.NewArrowFromConn(s.conn) + arrow, err := duckdb.NewArrowFromConn(duckConn) if err != nil { return nil, nil, err } @@ -688,7 +714,7 @@ func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flight fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table) - return doGetQuery(ctx, s.conn, b.String(), schema_ref.PrimaryKeys) + return doGetQuery(ctx, s.db, b.String(), schema_ref.PrimaryKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -704,7 +730,7 @@ func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref fligh filter += " AND fk_schema_name = '" + *ref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.conn, query, schema_ref.ImportedKeys) + return doGetQuery(ctx, s.db, query, schema_ref.ImportedKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -720,7 +746,7 @@ func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref fligh filter += " AND pk_schema_name = '" + *ref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys) + return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -746,7 +772,7 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd fli filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys) + return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys) } func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) { diff --git a/flightsqltest/driver_test.go b/flightsqltest/driver_test.go index 280ed1c..78ccfea 100644 --- a/flightsqltest/driver_test.go +++ b/flightsqltest/driver_test.go @@ -24,7 +24,6 @@ import ( "database/sql" "errors" "fmt" - "log" "math/rand" "os" "strings" @@ -33,7 +32,6 @@ import ( "time" "github.com/apecloud/myduckserver/flightsqlserver" - "github.com/marcboeker/go-duckdb" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -98,16 +96,8 @@ func (s *SqlTestSuite) SetupSuite() { if err != nil { return nil, "", err } - conn, err := provider.Connector().Connect(context.Background()) - if err != nil { - log.Fatal(err) - } - duckConn, ok := conn.(*duckdb.Conn) - if !ok { - log.Fatal("Failed to get DuckDB connection") - } - sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage(), duckConn) + sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage()) if err != nil { return nil, "", err } diff --git a/main.go b/main.go index 685d3cc..6b6fec3 100644 --- a/main.go +++ b/main.go @@ -40,7 +40,6 @@ import ( "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" - "github.com/marcboeker/go-duckdb" _ "github.com/marcboeker/go-duckdb" "github.com/sirupsen/logrus" ) @@ -193,18 +192,12 @@ func main() { if flightsqlPort > 0 { db := provider.Storage() - conn, err := provider.Connector().Connect(context.Background()) if err != nil { log.Fatal(err) } defer db.Close() - duckConn, ok := conn.(*duckdb.Conn) - if !ok { - log.Fatal("Failed to get DuckDB connection") - } - - srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db, duckConn) + srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db) if err != nil { log.Fatal(err) }