Skip to content

Commit 606ab3d

Browse files
authored
fix: implement portal suspension for Execute row limits (#138)
PR #115 incorrectly implemented the Execute message's limit field by throwing an error when the row limit was reached. Per the PostgreSQL extended query protocol, Execute should instead send PortalSuspended and allow the client to issue further Execute messages against the same portal to continue reading the result set in batches. This implements correct portal suspension without changing the handler API by using iter.Pull to convert the push-style iteration into a pull-style iterator that can be paused and resumed across Execute calls. For compatibility with the parallel pipeline execution from #123, same-portal Execute messages are serialized via a pending channel that chains goroutines. Close messages arriving in the same pipeline batch defer portal teardown until the in-flight Execute completes. Fixes #137
1 parent d056e41 commit 606ab3d

13 files changed

Lines changed: 860 additions & 220 deletions

cache.go

Lines changed: 160 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,21 @@ package wire
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"iter"
68
"sync"
79

810
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
11+
"github.com/jeroenrinzema/psql-wire/pkg/types"
912
)
1013

14+
// Limit represents the maximum number of rows to be written.
15+
// Zero denotes "no limit".
16+
type Limit uint32
17+
18+
const NoLimit Limit = 0
19+
1120
type Statement struct {
1221
fn PreparedStatementFn
1322
parameters []uint32
@@ -77,15 +86,126 @@ type Portal struct {
7786
statement *Statement
7887
parameters []Parameter
7988
formats []FormatCode
89+
90+
// The iterator state (created by iter.Pull)
91+
next func() (struct{}, bool)
92+
stop func()
93+
// Filled in by dataWriter.Complete when the handler finishes. Used to
94+
// return the tag when re-executing a completed portal.
95+
tag string
96+
err error
97+
// Set to true when execution of the portal has finished.
98+
done bool
99+
100+
// pending is closed when the most recently launched async goroutine
101+
// finishes. A new goroutine for the same portal waits on this channel
102+
// before starting, so same-portal executes serialize while different
103+
// portals run in parallel.
104+
pending chan struct{}
105+
}
106+
107+
// Close tears down the portal from outside the execution goroutine (e.g.
108+
// explicit Close message or portal re-bind). If an async goroutine is still
109+
// running, cleanup is deferred to a background goroutine that waits for it.
110+
func (p *Portal) Close() {
111+
if p.pending != nil {
112+
pending := p.pending
113+
go func() {
114+
<-pending
115+
p.close()
116+
}()
117+
return
118+
}
119+
p.close()
120+
}
121+
122+
// close tears down the iterator inline. Used from within execute where we
123+
// are already inside the execution context and don't need to wait on pending.
124+
func (p *Portal) close() {
125+
if p.stop != nil {
126+
p.stop()
127+
p.next = nil
128+
p.stop = nil
129+
}
130+
}
131+
132+
func portalSuspended(writer *buffer.Writer) error {
133+
writer.Start(types.ServerPortalSuspended)
134+
return writer.End()
135+
}
136+
137+
func (p *Portal) execute(ctx context.Context, limit Limit, reader *buffer.Reader, writer *buffer.Writer) error {
138+
if p.done {
139+
// Re-executing an already completed portal simply returns the tag or
140+
// error again.
141+
if p.err != nil {
142+
return p.err
143+
}
144+
return commandComplete(writer, p.tag)
145+
}
146+
147+
if p.next == nil {
148+
// This is the first execute call on this portal. So let's start the
149+
// execution. Otherwise we continue from where we left off.
150+
session, _ := GetSession(ctx)
151+
// Create a simple push-style iterator (iter.Seq) around the
152+
// statement.fn.
153+
seq := func(yield func(struct{}) bool) {
154+
dw := &dataWriter{
155+
ctx: ctx,
156+
session: session,
157+
columns: p.statement.columns,
158+
formats: p.formats,
159+
reader: reader,
160+
client: writer,
161+
yield: yield,
162+
tag: &p.tag,
163+
}
164+
err := p.statement.fn(ctx, dw, p.parameters)
165+
if err != nil && !errors.Is(err, ErrSuspendedHandlerClosed) {
166+
p.err = err
167+
}
168+
}
169+
170+
// Then we convert that push-style iterator into a pull-style iterator,
171+
// so we can suspend the iterator when we reach the row limit.
172+
p.next, p.stop = iter.Pull(seq)
173+
}
174+
175+
var count Limit
176+
for {
177+
if limit != NoLimit && count >= limit {
178+
// We've reached the row limit. Suspend the portal and let the
179+
// client know, so it can either issue a new Execute to continue or
180+
// close the portal.
181+
return portalSuspended(writer)
182+
}
183+
184+
// Run the handler until it has produced the next row or finishes. The
185+
// dataWriter inside the handler will call yield to "teleport" back
186+
// here.
187+
_, ok := p.next()
188+
if !ok {
189+
// The handler has finished. CommandComplete was already written
190+
// by dataWriter.Complete.
191+
p.close()
192+
p.done = true
193+
return p.err
194+
}
195+
196+
count++
197+
}
80198
}
81199

82200
func DefaultPortalCacheFn() PortalCache {
83201
return &DefaultPortalCache{}
84202
}
85203

86204
type DefaultPortalCache struct {
87-
portals map[string]*Portal
88-
mu sync.RWMutex
205+
portals map[string]*Portal
206+
executing *Portal
207+
closePending bool
208+
mu sync.RWMutex
89209
}
90210

91211
func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []Parameter, formats []FormatCode) error {
@@ -96,6 +216,10 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *St
96216
cache.portals = map[string]*Portal{}
97217
}
98218

219+
if existing, ok := cache.portals[name]; ok {
220+
existing.Close()
221+
}
222+
99223
cache.portals[name] = &Portal{
100224
statement: stmt,
101225
parameters: parameters,
@@ -145,13 +269,41 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, limit
145269
return nil
146270
}
147271

148-
session, _ := GetSession(ctx)
149-
return portal.statement.fn(ctx, NewDataWriter(ctx, session, portal.statement.columns, portal.formats, limit, reader, writer), portal.parameters)
272+
cache.mu.Lock()
273+
cache.executing = portal
274+
cache.mu.Unlock()
275+
276+
err = portal.execute(ctx, limit, reader, writer)
277+
278+
cache.mu.Lock()
279+
cache.executing = nil
280+
needsClose := cache.closePending
281+
cache.closePending = false
282+
cache.mu.Unlock()
283+
if needsClose {
284+
portal.close()
285+
}
286+
287+
return err
288+
}
289+
290+
// closePortal closes the portal immediately, unless it is the currently
291+
// executing portal, in which case it marks the portal for deferred closing
292+
// after the current Execute call returns.
293+
func (cache *DefaultPortalCache) closePortal(portal *Portal) {
294+
if portal == cache.executing {
295+
cache.closePending = true
296+
} else {
297+
portal.Close()
298+
}
150299
}
151300

152301
func (cache *DefaultPortalCache) Delete(ctx context.Context, name string) error {
153302
cache.mu.Lock()
154303
defer cache.mu.Unlock()
304+
if portal, ok := cache.portals[name]; ok {
305+
cache.closePortal(portal)
306+
}
155307
delete(cache.portals, name)
156308
return nil
157309
}
@@ -161,6 +313,7 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
161313
defer cache.mu.Unlock()
162314
for name, portal := range cache.portals {
163315
if portal.statement == stmt {
316+
cache.closePortal(portal)
164317
delete(cache.portals, name)
165318
}
166319
}
@@ -170,5 +323,8 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
170323
func (cache *DefaultPortalCache) Close() {
171324
cache.mu.Lock()
172325
defer cache.mu.Unlock()
326+
for _, portal := range cache.portals {
327+
cache.closePortal(portal)
328+
}
173329
clear(cache.portals)
174330
}

cache_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package wire
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"log/slog"
7+
"testing"
8+
9+
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func newDiscardWriter() *buffer.Writer {
15+
return buffer.NewWriter(slog.Default(), &bytes.Buffer{})
16+
}
17+
18+
// TestCloseFromHandler verifies that calling Close from within a handler does
19+
// not deadlock. The currently executing portal is automatically deferred and
20+
// closed after Execute returns.
21+
func TestCloseFromHandler(t *testing.T) {
22+
t.Parallel()
23+
ctx := context.Background()
24+
cache := &DefaultPortalCache{}
25+
26+
stmt := &Statement{
27+
fn: func(ctx context.Context, writer DataWriter, _ []Parameter) error {
28+
cache.Close()
29+
return writer.Complete("OK")
30+
},
31+
}
32+
33+
require.NoError(t, cache.Bind(ctx, "", stmt, nil, nil))
34+
err := cache.Execute(ctx, "", NoLimit, nil, newDiscardWriter())
35+
require.NoError(t, err)
36+
37+
portal, err := cache.Get(ctx, "")
38+
require.NoError(t, err)
39+
assert.Nil(t, portal)
40+
}
41+
42+
// TestDeleteByStatementFromHandler verifies that calling DeleteByStatement
43+
// from within a handler does not deadlock when the currently executing portal
44+
// is bound to the deleted statement.
45+
func TestDeleteByStatementFromHandler(t *testing.T) {
46+
t.Parallel()
47+
ctx := context.Background()
48+
cache := &DefaultPortalCache{}
49+
50+
var selfStmt *Statement
51+
selfStmt = &Statement{
52+
fn: func(ctx context.Context, writer DataWriter, _ []Parameter) error {
53+
err := cache.DeleteByStatement(ctx, selfStmt)
54+
require.NoError(t, err)
55+
return writer.Complete("OK")
56+
},
57+
}
58+
59+
require.NoError(t, cache.Bind(ctx, "portal", selfStmt, nil, nil))
60+
err := cache.Execute(ctx, "portal", NoLimit, nil, newDiscardWriter())
61+
require.NoError(t, err)
62+
63+
portal, err := cache.Get(ctx, "portal")
64+
require.NoError(t, err)
65+
assert.Nil(t, portal)
66+
}

0 commit comments

Comments
 (0)