diff --git a/arrow.go b/arrow.go index 659aca08..5f9de4bc 100644 --- a/arrow.go +++ b/arrow.go @@ -64,6 +64,7 @@ import ( "database/sql/driver" "errors" "fmt" + "sync/atomic" "unsafe" "github.com/apache/arrow-go/v18/arrow" @@ -91,6 +92,86 @@ func NewArrowFromConn(driverConn driver.Conn) (*Arrow, error) { return &Arrow{c: dbConn}, nil } +// arrowStreamReader implements array.RecordReader for streaming DuckDB results. +type arrowStreamReader struct { + refCount int64 + + ctx context.Context + stmt *Stmt + res *C.duckdb_arrow + schema *arrow.Schema + rowCount uint64 + readCount uint64 + currentRec arrow.Record + err error +} + +// Retain increases the reference count by 1. +// Retain may be called simultaneously from multiple goroutines. +func (r *arrowStreamReader) Retain() { + atomic.AddInt64(&r.refCount, 1) +} + +// Release decreases the reference count by 1. +// When the reference count goes to zero, the memory is freed. +// Release may be called simultaneously from multiple goroutines. +func (r *arrowStreamReader) Release() { + if atomic.AddInt64(&r.refCount, -1) == 0 { + if r.currentRec != nil { + r.currentRec.Release() + r.currentRec = nil + } + if r.res != nil { + C.duckdb_destroy_arrow(r.res) + r.res = nil + } + if r.stmt != nil { + r.stmt.Close() + r.stmt = nil + } + } +} + +func (r *arrowStreamReader) Schema() *arrow.Schema { + return r.schema +} + +func (r *arrowStreamReader) Next() bool { + if r.currentRec != nil { + r.currentRec.Release() + r.currentRec = nil + } + + if r.readCount >= r.rowCount { + return false + } + + select { + case <-r.ctx.Done(): + r.err = r.ctx.Err() + return false + default: + } + + rec, err := queryArrowArray(r.res, r.schema) + if err != nil { + r.err = err + return false + } + + r.currentRec = rec + r.readCount += uint64(rec.NumRows()) + return true +} + +func (r *arrowStreamReader) Record() arrow.Record { + return r.currentRec +} + +func (r *arrowStreamReader) Err() error { + return r.err +} + // QueryContext prepares statements, executes them, returns Apache Arrow array.RecordReader as a result of the last // executed statement. Arguments are bound to the last statement. func (a *Arrow) QueryContext(ctx context.Context, query string, args ...any) (array.RecordReader, error) { @@ -123,48 +204,28 @@ func (a *Arrow) QueryContext(ctx context.Context, query string, args ...any) (ar if err != nil { return nil, err } - defer stmt.Close() res, err := a.execute(stmt, a.anyArgsToNamedArgs(args)) if err != nil { + stmt.Close() return nil, err } - defer C.duckdb_destroy_arrow(res) sc, err := a.queryArrowSchema(res) if err != nil { + C.duckdb_destroy_arrow(res) + stmt.Close() return nil, err } - var recs []arrow.Record - defer func() { - for _, r := range recs { - r.Release() - } - }() - - rowCount := uint64(C.duckdb_arrow_row_count(*res)) - - var retrievedRows uint64 - - for retrievedRows < rowCount { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - rec, err := a.queryArrowArray(res, sc) - if err != nil { - return nil, err - } - - recs = append(recs, rec) - - retrievedRows += uint64(rec.NumRows()) - } - - return array.NewRecordReader(sc, recs) + return &arrowStreamReader{ + refCount: 1, + ctx: ctx, + stmt: stmt, + res: res, + schema: sc, + rowCount: uint64(C.duckdb_arrow_row_count(*res)), + }, nil } // queryArrowSchema fetches the internal arrow schema from the arrow result. @@ -194,7 +255,7 @@ func (a *Arrow) queryArrowSchema(res *C.duckdb_arrow) (*arrow.Schema, error) { // // This function can be called multiple time to get next chunks, // which will free the previous out_array. -func (a *Arrow) queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Record, error) { +func queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Record, error) { arr := C.calloc(1, C.sizeof_struct_ArrowArray) defer func() { cdata.ReleaseCArrowArray((*cdata.CArrowArray)(arr))