Skip to content

Commit 2966574

Browse files
committed
Handle Execute limit correctly by sending PortalSuspended
1 parent 64e53ce commit 2966574

11 files changed

Lines changed: 740 additions & 217 deletions

cache.go

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,21 @@ package wire
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"iter"
68
"sync"
79

810
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
11+
"github.com/jeroenrinzema/psql-wire/pkg/types"
912
)
1013

14+
// Limit represents the maximum number of rows to be written.
15+
// Zero denotes "no limit".
16+
type Limit uint32
17+
18+
const NoLimit Limit = 0
19+
1120
type Statement struct {
1221
fn PreparedStatementFn
1322
parameters []uint32
@@ -73,6 +82,112 @@ type Portal struct {
7382
statement *Statement
7483
parameters []Parameter
7584
formats []FormatCode
85+
86+
// The iterator state (created by iter.Pull)
87+
next func() (struct{}, bool)
88+
stop func()
89+
// Filled in by dataWriter.Complete when the handler finishes. Used to
90+
// return the tag when re-executing a completed portal.
91+
tag string
92+
err error
93+
// Set to true when execution of the portal has finished.
94+
done bool
95+
96+
// pending is closed when the most recently launched async goroutine
97+
// finishes. A new goroutine for the same portal waits on this channel
98+
// before starting, so same-portal executes serialize while different
99+
// portals run in parallel.
100+
pending chan struct{}
101+
}
102+
103+
// Close tears down the portal from outside the execution goroutine (e.g.
104+
// explicit Close message or portal re-bind). If an async goroutine is still
105+
// running, cleanup is deferred to a background goroutine that waits for it.
106+
func (p *Portal) Close() {
107+
if p.pending != nil {
108+
pending := p.pending
109+
go func() {
110+
<-pending
111+
p.close()
112+
}()
113+
return
114+
}
115+
p.close()
116+
}
117+
118+
// close tears down the iterator inline. Used from within execute where we
119+
// are already inside the execution context and don't need to wait on pending.
120+
func (p *Portal) close() {
121+
if p.stop != nil {
122+
p.stop()
123+
p.next = nil
124+
p.stop = nil
125+
}
126+
}
127+
128+
func portalSuspended(writer *buffer.Writer) error {
129+
writer.Start(types.ServerPortalSuspended)
130+
return writer.End()
131+
}
132+
133+
func (p *Portal) execute(ctx context.Context, limit Limit, reader *buffer.Reader, writer *buffer.Writer) error {
134+
if p.done {
135+
// Re-executing an already completed portal simply returns the tag or
136+
// error again.
137+
if p.err != nil {
138+
return p.err
139+
}
140+
return commandComplete(writer, p.tag)
141+
}
142+
143+
if p.next == nil {
144+
// This is the first execute call on this portal. So let's start the
145+
// execution. Otherwise we continue from where we left off.
146+
session, _ := GetSession(ctx)
147+
// Create a simple push-style iterator (iter.Seq) around the
148+
// statement.fn.
149+
seq := func(yield func(struct{}) bool) {
150+
dw := &dataWriter{
151+
ctx: ctx,
152+
session: session,
153+
columns: p.statement.columns,
154+
formats: p.formats,
155+
reader: reader,
156+
client: writer,
157+
yield: yield,
158+
tag: &p.tag,
159+
}
160+
err := p.statement.fn(ctx, dw, p.parameters)
161+
if err != nil && !errors.Is(err, ErrSuspendedHandlerClosed) {
162+
p.err = err
163+
}
164+
}
165+
166+
// Then we convert that push-style iterator into a pull-style iterator,
167+
// so we can suspend the iterator when we reach the row limit.
168+
p.next, p.stop = iter.Pull(seq)
169+
}
170+
171+
var count Limit
172+
for {
173+
if limit != NoLimit && count >= limit {
174+
// We've reached the row limit. Suspend the portal and let the
175+
// client know, so it can either issue a new Execute to continue or
176+
// close the portal.
177+
return portalSuspended(writer)
178+
}
179+
180+
_, ok := p.next()
181+
if !ok {
182+
// The handler has finished. CommandComplete was already written
183+
// by dataWriter.Complete.
184+
p.close()
185+
p.done = true
186+
return p.err
187+
}
188+
189+
count++
190+
}
76191
}
77192

78193
func DefaultPortalCacheFn() PortalCache {
@@ -92,6 +207,10 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *St
92207
cache.portals = map[string]*Portal{}
93208
}
94209

210+
if existing, ok := cache.portals[name]; ok {
211+
existing.Close()
212+
}
213+
95214
cache.portals[name] = &Portal{
96215
statement: stmt,
97216
parameters: parameters,
@@ -137,13 +256,15 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, limit
137256
return nil
138257
}
139258

140-
session, _ := GetSession(ctx)
141-
return portal.statement.fn(ctx, NewDataWriter(ctx, session, portal.statement.columns, portal.formats, limit, reader, writer), portal.parameters)
259+
return portal.execute(ctx, limit, reader, writer)
142260
}
143261

144262
func (cache *DefaultPortalCache) Delete(ctx context.Context, name string) error {
145263
cache.mu.Lock()
146264
defer cache.mu.Unlock()
265+
if portal, ok := cache.portals[name]; ok {
266+
portal.Close()
267+
}
147268
delete(cache.portals, name)
148269
return nil
149270
}
@@ -153,10 +274,17 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
153274
defer cache.mu.Unlock()
154275
for name, portal := range cache.portals {
155276
if portal.statement == stmt {
277+
portal.Close()
156278
delete(cache.portals, name)
157279
}
158280
}
159281
return nil
160282
}
161283

162-
func (cache *DefaultPortalCache) Close() {}
284+
func (cache *DefaultPortalCache) Close() {
285+
cache.mu.Lock()
286+
defer cache.mu.Unlock()
287+
for _, portal := range cache.portals {
288+
portal.Close()
289+
}
290+
}

command.go

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package wire
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
67
"fmt"
@@ -56,6 +57,7 @@ type Session struct {
5657
Statements StatementCache
5758
Portals PortalCache
5859
Attributes map[string]interface{}
60+
reader *buffer.Reader
5961

6062
// pipelining
6163
ParallelPipeline ParallelPipelineConfig
@@ -93,6 +95,7 @@ func isExtendedQueryMessage(t types.ClientMessage) bool {
9395
// This method keeps consuming messages until the client issues a close message
9496
// or the connection is terminated.
9597
func (srv *Session) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
98+
srv.reader = reader
9699
srv.logger.Debug("ready for query... starting to consume commands")
97100

98101
err := readyForQuery(writer, types.ServerIdle)
@@ -326,7 +329,13 @@ func (srv *Session) handleSimpleQuery(ctx context.Context, reader *buffer.Reader
326329
return srv.WriteError(writer, err)
327330
}
328331

329-
err = statements[index].fn(ctx, NewDataWriter(ctx, srv, statements[index].columns, nil, NoLimit, reader, writer), nil)
332+
portal := &Portal{
333+
statement: &Statement{
334+
fn: statements[index].fn,
335+
columns: statements[index].columns,
336+
},
337+
}
338+
err = portal.execute(ctx, NoLimit, reader, writer)
330339
if err != nil {
331340
return srv.WriteError(writer, err)
332341
}
@@ -753,7 +762,11 @@ func (srv *Session) handleClose(ctx context.Context, reader *buffer.Reader, writ
753762
return writer.End()
754763
}
755764

756-
// executePipelined handles Execute in parallel pipeline mode
765+
// executePipelined handles Execute in parallel pipeline mode. When the
766+
// Execute carries a row limit (or the portal already has a suspended
767+
// iterator) it is queued as a synchronous event that runs inline during
768+
// drain, because it needs the portal's iter.Pull2 state. Otherwise it
769+
// launches a goroutine for parallel execution.
757770
func (srv *Session) executePipelined(ctx context.Context, writer *buffer.Writer, name string, limit uint32) error {
758771
portal, err := srv.Portals.Get(ctx, name)
759772
if err != nil {
@@ -764,42 +777,46 @@ func (srv *Session) executePipelined(ctx context.Context, writer *buffer.Writer,
764777
return srv.drainQueueAndWriteError(ctx, writer, errors.New("unknown portal"))
765778
}
766779

767-
// Create result channel and queue the event
768-
resultChan := make(chan *QueuedDataWriter, 1)
769-
srv.ResponseQueue.Enqueue(NewExecuteEvent(resultChan, portal.formats))
780+
resultChan := make(chan *executeResult, 1)
781+
srv.ResponseQueue.Enqueue(NewExecuteEvent(resultChan))
770782

771-
// Launch async execution
772-
go srv.executeAsync(ctx, portal, limit, resultChan)
783+
prev := portal.pending
784+
done := make(chan struct{})
785+
portal.pending = done
786+
go srv.executeAsync(ctx, done, portal, Limit(limit), prev, resultChan)
773787

774788
return nil
775789
}
776790

777-
// executeAsync runs the portal execution in a separate goroutine
778-
func (srv *Session) executeAsync(ctx context.Context, portal *Portal, limit uint32, resultChan chan<- *QueuedDataWriter) {
791+
// executeAsync runs portal.execute in a separate goroutine, capturing the
792+
// wire output into a buffer. If prev is non-nil, it waits for the previous
793+
// goroutine on the same portal to finish before starting.
794+
func (srv *Session) executeAsync(ctx context.Context, done chan struct{}, portal *Portal, limit Limit, prev chan struct{}, resultChan chan<- *executeResult) {
795+
defer close(done)
779796
defer close(resultChan)
780797
defer func() {
781798
if r := recover(); r != nil {
782-
collector := NewQueuedDataWriter(ctx, portal.statement.columns, Limit(limit))
783-
collector.SetError(fmt.Errorf("panic during execution: %v", r))
784-
resultChan <- collector
799+
resultChan <- &executeResult{err: fmt.Errorf("panic during execution: %v", r)}
785800
}
786801
}()
787802

788-
srv.logger.Debug("starting async execution",
789-
slog.Bool("has_function", portal.statement.fn != nil))
803+
if prev != nil {
804+
<-prev
805+
}
790806

791-
collector := NewQueuedDataWriter(ctx, portal.statement.columns, Limit(limit))
807+
srv.logger.Debug("starting async execution")
792808

793-
err := portal.statement.fn(ctx, collector, portal.parameters)
794-
if err != nil {
795-
collector.SetError(err)
796-
}
809+
buf := &bytes.Buffer{}
810+
w := buffer.NewWriter(srv.logger, buf)
811+
812+
err := portal.execute(ctx, limit, srv.reader, w)
813+
814+
result := &executeResult{buf: buf, err: err}
797815

798816
srv.logger.Debug("async execution complete",
799-
slog.Bool("has_error", collector.GetError() != nil),
800-
slog.Int("rows", int(collector.Written())))
817+
slog.Bool("has_error", err != nil))
801818

802-
resultChan <- collector
819+
resultChan <- result
803820
}
804821

805822
// handleSync handles the Sync message (extended query protocol)
@@ -924,22 +941,16 @@ func (srv *Session) writeQueuedResponse(ctx context.Context, writer *buffer.Writ
924941
return srv.writeColumnDescription(ctx, writer, event.Formats, event.Columns)
925942

926943
case ResponseExecute:
927-
// Execute writes DataRows followed by CommandComplete
928944
if event.Result == nil {
929-
// No result yet, this shouldn't happen in normal flow
930945
return errors.New("execute event has no result")
931946
}
932947

933-
// Check for execution error
934-
if err := event.Result.GetError(); err != nil {
935-
return srv.WriteError(writer, err)
948+
if event.Result.err != nil {
949+
return srv.WriteError(writer, event.Result.err)
936950
}
937951

938-
// Use DataWriter for correct encoding
939-
// Note: We use NoLimit here because the result is already limited during execution
940-
dataWriter := NewDataWriter(ctx, srv, event.Result.Columns(), event.Formats, NoLimit, nil, writer)
941-
942-
return event.Result.Replay(ctx, dataWriter)
952+
_, err := writer.Write(event.Result.buf.Bytes())
953+
return err
943954

944955
default:
945956
return fmt.Errorf("unknown response event kind: %v", event.Kind)

0 commit comments

Comments
 (0)