@@ -2,12 +2,21 @@ package wire
22
33import (
44 "context"
5+ "errors"
56 "fmt"
7+ "iter"
68 "sync"
79
810 "github.com/jeroenrinzema/psql-wire/pkg/buffer"
11+ "github.com/jeroenrinzema/psql-wire/pkg/types"
912)
1013
14+ // Limit represents the maximum number of rows to be written.
15+ // Zero denotes "no limit".
16+ type Limit uint32
17+
18+ const NoLimit Limit = 0
19+
1120type Statement struct {
1221 fn PreparedStatementFn
1322 parameters []uint32
@@ -77,15 +86,126 @@ type Portal struct {
7786 statement * Statement
7887 parameters []Parameter
7988 formats []FormatCode
89+
90+ // The iterator state (created by iter.Pull)
91+ next func () (struct {}, bool )
92+ stop func ()
93+ // Filled in by dataWriter.Complete when the handler finishes. Used to
94+ // return the tag when re-executing a completed portal.
95+ tag string
96+ err error
97+ // Set to true when execution of the portal has finished.
98+ done bool
99+
100+ // pending is closed when the most recently launched async goroutine
101+ // finishes. A new goroutine for the same portal waits on this channel
102+ // before starting, so same-portal executes serialize while different
103+ // portals run in parallel.
104+ pending chan struct {}
105+ }
106+
107+ // Close tears down the portal from outside the execution goroutine (e.g.
108+ // explicit Close message or portal re-bind). If an async goroutine is still
109+ // running, cleanup is deferred to a background goroutine that waits for it.
110+ func (p * Portal ) Close () {
111+ if p .pending != nil {
112+ pending := p .pending
113+ go func () {
114+ <- pending
115+ p .close ()
116+ }()
117+ return
118+ }
119+ p .close ()
120+ }
121+
122+ // close tears down the iterator inline. Used from within execute where we
123+ // are already inside the execution context and don't need to wait on pending.
124+ func (p * Portal ) close () {
125+ if p .stop != nil {
126+ p .stop ()
127+ p .next = nil
128+ p .stop = nil
129+ }
130+ }
131+
132+ func portalSuspended (writer * buffer.Writer ) error {
133+ writer .Start (types .ServerPortalSuspended )
134+ return writer .End ()
135+ }
136+
137+ func (p * Portal ) execute (ctx context.Context , limit Limit , reader * buffer.Reader , writer * buffer.Writer ) error {
138+ if p .done {
139+ // Re-executing an already completed portal simply returns the tag or
140+ // error again.
141+ if p .err != nil {
142+ return p .err
143+ }
144+ return commandComplete (writer , p .tag )
145+ }
146+
147+ if p .next == nil {
148+ // This is the first execute call on this portal. So let's start the
149+ // execution. Otherwise we continue from where we left off.
150+ session , _ := GetSession (ctx )
151+ // Create a simple push-style iterator (iter.Seq) around the
152+ // statement.fn.
153+ seq := func (yield func (struct {}) bool ) {
154+ dw := & dataWriter {
155+ ctx : ctx ,
156+ session : session ,
157+ columns : p .statement .columns ,
158+ formats : p .formats ,
159+ reader : reader ,
160+ client : writer ,
161+ yield : yield ,
162+ tag : & p .tag ,
163+ }
164+ err := p .statement .fn (ctx , dw , p .parameters )
165+ if err != nil && ! errors .Is (err , ErrSuspendedHandlerClosed ) {
166+ p .err = err
167+ }
168+ }
169+
170+ // Then we convert that push-style iterator into a pull-style iterator,
171+ // so we can suspend the iterator when we reach the row limit.
172+ p .next , p .stop = iter .Pull (seq )
173+ }
174+
175+ var count Limit
176+ for {
177+ if limit != NoLimit && count >= limit {
178+ // We've reached the row limit. Suspend the portal and let the
179+ // client know, so it can either issue a new Execute to continue or
180+ // close the portal.
181+ return portalSuspended (writer )
182+ }
183+
184+ // Run the handler until it has produced the next row or finishes. The
185+ // dataWriter inside the handler will call yield to "teleport" back
186+ // here.
187+ _ , ok := p .next ()
188+ if ! ok {
189+ // The handler has finished. CommandComplete was already written
190+ // by dataWriter.Complete.
191+ p .close ()
192+ p .done = true
193+ return p .err
194+ }
195+
196+ count ++
197+ }
80198}
81199
82200func DefaultPortalCacheFn () PortalCache {
83201 return & DefaultPortalCache {}
84202}
85203
86204type DefaultPortalCache struct {
87- portals map [string ]* Portal
88- mu sync.RWMutex
205+ portals map [string ]* Portal
206+ executing * Portal
207+ closePending bool
208+ mu sync.RWMutex
89209}
90210
91211func (cache * DefaultPortalCache ) Bind (ctx context.Context , name string , stmt * Statement , parameters []Parameter , formats []FormatCode ) error {
@@ -96,6 +216,10 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *St
96216 cache .portals = map [string ]* Portal {}
97217 }
98218
219+ if existing , ok := cache .portals [name ]; ok {
220+ existing .Close ()
221+ }
222+
99223 cache .portals [name ] = & Portal {
100224 statement : stmt ,
101225 parameters : parameters ,
@@ -145,13 +269,41 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, limit
145269 return nil
146270 }
147271
148- session , _ := GetSession (ctx )
149- return portal .statement .fn (ctx , NewDataWriter (ctx , session , portal .statement .columns , portal .formats , limit , reader , writer ), portal .parameters )
272+ cache .mu .Lock ()
273+ cache .executing = portal
274+ cache .mu .Unlock ()
275+
276+ err = portal .execute (ctx , limit , reader , writer )
277+
278+ cache .mu .Lock ()
279+ cache .executing = nil
280+ needsClose := cache .closePending
281+ cache .closePending = false
282+ cache .mu .Unlock ()
283+ if needsClose {
284+ portal .close ()
285+ }
286+
287+ return err
288+ }
289+
290+ // closePortal closes the portal immediately, unless it is the currently
291+ // executing portal, in which case it marks the portal for deferred closing
292+ // after the current Execute call returns.
293+ func (cache * DefaultPortalCache ) closePortal (portal * Portal ) {
294+ if portal == cache .executing {
295+ cache .closePending = true
296+ } else {
297+ portal .Close ()
298+ }
150299}
151300
152301func (cache * DefaultPortalCache ) Delete (ctx context.Context , name string ) error {
153302 cache .mu .Lock ()
154303 defer cache .mu .Unlock ()
304+ if portal , ok := cache .portals [name ]; ok {
305+ cache .closePortal (portal )
306+ }
155307 delete (cache .portals , name )
156308 return nil
157309}
@@ -161,6 +313,7 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
161313 defer cache .mu .Unlock ()
162314 for name , portal := range cache .portals {
163315 if portal .statement == stmt {
316+ cache .closePortal (portal )
164317 delete (cache .portals , name )
165318 }
166319 }
@@ -170,5 +323,8 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
170323func (cache * DefaultPortalCache ) Close () {
171324 cache .mu .Lock ()
172325 defer cache .mu .Unlock ()
326+ for _ , portal := range cache .portals {
327+ cache .closePortal (portal )
328+ }
173329 clear (cache .portals )
174330}
0 commit comments