Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions driverbase/record_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"io"
"sync"
"sync/atomic"

"github.com/apache/arrow-go/v18/arrow"
Expand Down Expand Up @@ -74,6 +75,9 @@ type BaseRecordReader struct {

keepReading chan bool
hasBatch chan bool
// concurrent usage is not actually allowed, but try to prevent weird
// things from happening
mu sync.Mutex
}

type BaseRecordReaderOptions struct {
Expand Down Expand Up @@ -113,7 +117,7 @@ func (rr *BaseRecordReader) Init(ctx context.Context, alloc memory.Allocator, pa
// Initialize the builder and get the first result set
if rr.params != nil {
if !rr.advanceParams() {
rr.Close()
rr.closeUnlocked()
if rr.err != nil {
return rr.err
}
Expand All @@ -129,7 +133,7 @@ func (rr *BaseRecordReader) Init(ctx context.Context, alloc memory.Allocator, pa
rr.schema, err = rr.impl.NextResultSet(rr.ctx, rr.paramBatch, rr.paramIndex)
if err != nil {
rr.err = err
rr.Close()
rr.closeUnlocked()
return err
}

Expand Down Expand Up @@ -166,6 +170,14 @@ func (rr *BaseRecordReader) Init(ctx context.Context, alloc memory.Allocator, pa
}

func (rr *BaseRecordReader) Close() {
rr.mu.Lock()
defer rr.mu.Unlock()
rr.closeUnlocked()
}

func (rr *BaseRecordReader) closeUnlocked() {
// NOTE: we assume you hold the lock.

if rr.nextBatch != nil {
rr.nextBatch.Release()
rr.nextBatch = nil
Expand Down Expand Up @@ -242,7 +254,7 @@ func (rr *BaseRecordReader) readBatch() bool {
if rows == 0 && rr.done {
// N.B. I believe rows == 0 implies rr.done here
// Clean up eagerly since we will return false below
rr.Close()
rr.closeUnlocked()
}
return rows > 0
}
Expand All @@ -251,12 +263,16 @@ func (rr *BaseRecordReader) Next() bool {
if rr.impl == nil || rr.err != nil {
return false
}

rr.mu.Lock()
defer rr.mu.Unlock()

if rr.nextBatch != nil {
rr.nextBatch.Release()
rr.nextBatch = nil
}
if rr.done {
rr.Close()
rr.closeUnlocked()
return false
}

Expand All @@ -269,6 +285,8 @@ func (rr *BaseRecordReader) Next() bool {
// more are available. If there are no bind parameters, it returns false
// immediately.
func (rr *BaseRecordReader) advanceParams() bool {
// NOTE: we assume you hold the lock.

if rr.params == nil {
return false
}
Expand All @@ -290,7 +308,9 @@ func (rr *BaseRecordReader) advanceParams() bool {
func (rr *BaseRecordReader) Release() {
newCount := atomic.AddInt64(&rr.refCount, -1)
if newCount == 0 {
rr.Close()
rr.mu.Lock()
defer rr.mu.Unlock()
rr.closeUnlocked()
}
DebugAssert(newCount >= 0, "refCount went negative in BaseRecordReader")
}
Expand All @@ -304,10 +324,14 @@ func (rr *BaseRecordReader) Schema() *arrow.Schema {
}

func (rr *BaseRecordReader) Record() arrow.RecordBatch {
rr.mu.Lock()
defer rr.mu.Unlock()
return rr.nextBatch
}

func (rr *BaseRecordReader) RecordBatch() arrow.RecordBatch {
rr.mu.Lock()
defer rr.mu.Unlock()
return rr.nextBatch
}

Expand Down
Loading