Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 160 additions & 4 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@ package wire

import (
"context"
"errors"
"fmt"
"iter"
"sync"

"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/jeroenrinzema/psql-wire/pkg/types"
)

// Limit represents the maximum number of rows to be written.
// Zero denotes "no limit".
type Limit uint32

// NoLimit requests that all rows be written without suspending the portal.
const NoLimit Limit = 0

type Statement struct {
fn PreparedStatementFn
parameters []uint32
Expand Down Expand Up @@ -77,15 +86,126 @@ type Portal struct {
statement *Statement
parameters []Parameter
formats []FormatCode

// The iterator state (created by iter.Pull)
next func() (struct{}, bool)
stop func()
// Filled in by dataWriter.Complete when the handler finishes. Used to
// return the tag when re-executing a completed portal.
tag string
err error
// Set to true when execution of the portal has finished.
done bool

// pending is closed when the most recently launched async goroutine
// finishes. A new goroutine for the same portal waits on this channel
// before starting, so same-portal executes serialize while different
// portals run in parallel.
pending chan struct{}
}

// Close tears down the portal from outside the execution goroutine (for
// example on an explicit Close message or when the portal is re-bound).
// While an async goroutine is still running, cleanup is handed off to a
// background goroutine that waits for that run to signal completion.
func (p *Portal) Close() {
	done := p.pending
	if done == nil {
		// No in-flight execution: tear the iterator down inline.
		p.close()
		return
	}
	// Defer the teardown until the currently running goroutine closes the
	// pending channel.
	go func() {
		<-done
		p.close()
	}()
}

// close stops the pull iterator and clears its state. It is intended for use
// from within the execution path itself, where we are already inside the
// execution context and waiting on pending is unnecessary.
func (p *Portal) close() {
	if p.stop == nil {
		return
	}
	p.stop()
	p.stop = nil
	p.next = nil
}

// portalSuspended sends a PortalSuspended message to the client, indicating
// that the row limit was reached and the portal can be resumed with another
// Execute (or closed).
func portalSuspended(w *buffer.Writer) error {
	w.Start(types.ServerPortalSuspended)
	return w.End()
}

// execute drives the portal's statement handler until it finishes or the row
// limit is reached. The first call lazily wraps the handler in a pull-style
// iterator (iter.Pull); later calls resume it from where it was suspended.
// Re-executing an already finished portal replays its final outcome (the
// stored tag or error) without running the handler again.
func (p *Portal) execute(ctx context.Context, limit Limit, reader *buffer.Reader, writer *buffer.Writer) error {
	if p.done {
		// Re-executing an already completed portal simply returns the tag or
		// error again.
		if p.err != nil {
			return p.err
		}
		return commandComplete(writer, p.tag)
	}

	if p.next == nil {
		// This is the first execute call on this portal. So let's start the
		// execution. Otherwise we continue from where we left off.
		session, _ := GetSession(ctx)
		// Create a simple push-style iterator (iter.Seq) around the
		// statement.fn. Each yield from the dataWriter corresponds to one
		// produced row.
		seq := func(yield func(struct{}) bool) {
			dw := &dataWriter{
				ctx:     ctx,
				session: session,
				columns: p.statement.columns,
				formats: p.formats,
				reader:  reader,
				client:  writer,
				yield:   yield,
				// tag points back into the portal so dataWriter.Complete can
				// store the command tag for replay on re-execution.
				tag: &p.tag,
			}
			err := p.statement.fn(ctx, dw, p.parameters)
			// NOTE(review): ErrSuspendedHandlerClosed appears to signal that
			// the iterator was stopped from the outside (portal closed
			// mid-execution) rather than a handler failure, so it is
			// deliberately not recorded — confirm against its declaration.
			if err != nil && !errors.Is(err, ErrSuspendedHandlerClosed) {
				p.err = err
			}
		}

		// Then we convert that push-style iterator into a pull-style iterator,
		// so we can suspend the iterator when we reach the row limit.
		p.next, p.stop = iter.Pull(seq)
	}

	// count tracks rows produced during THIS Execute call only; the limit is
	// per-Execute, matching the Execute message's row-count semantics.
	var count Limit
	for {
		if limit != NoLimit && count >= limit {
			// We've reached the row limit. Suspend the portal and let the
			// client know, so it can either issue a new Execute to continue or
			// close the portal.
			return portalSuspended(writer)
		}

		// Run the handler until it has produced the next row or finishes. The
		// dataWriter inside the handler will call yield to "teleport" back
		// here.
		_, ok := p.next()
		if !ok {
			// The handler has finished. CommandComplete was already written
			// by dataWriter.Complete.
			p.close()
			p.done = true
			return p.err
		}

		count++
	}
}

// DefaultPortalCacheFn constructs the default in-memory portal cache.
func DefaultPortalCacheFn() PortalCache {
	cache := &DefaultPortalCache{}
	return cache
}

type DefaultPortalCache struct {
portals map[string]*Portal
mu sync.RWMutex
portals map[string]*Portal
executing *Portal
closePending bool
mu sync.RWMutex
}

func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []Parameter, formats []FormatCode) error {
Expand All @@ -96,6 +216,10 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *St
cache.portals = map[string]*Portal{}
}

if existing, ok := cache.portals[name]; ok {
existing.Close()
}

cache.portals[name] = &Portal{
statement: stmt,
parameters: parameters,
Expand Down Expand Up @@ -145,13 +269,41 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, limit
return nil
}

session, _ := GetSession(ctx)
return portal.statement.fn(ctx, NewDataWriter(ctx, session, portal.statement.columns, portal.formats, limit, reader, writer), portal.parameters)
cache.mu.Lock()
cache.executing = portal
cache.mu.Unlock()

err = portal.execute(ctx, limit, reader, writer)

cache.mu.Lock()
cache.executing = nil
needsClose := cache.closePending
cache.closePending = false
cache.mu.Unlock()
if needsClose {
portal.close()
}

return err
}

// closePortal releases the given portal. The currently executing portal must
// not be torn down mid-flight, so it is flagged for deferred closing and
// cleaned up by Execute once the in-progress call returns; any other portal
// is closed immediately. Callers must hold cache.mu.
func (cache *DefaultPortalCache) closePortal(portal *Portal) {
	if portal != cache.executing {
		portal.Close()
		return
	}
	cache.closePending = true
}

// Delete removes the named portal from the cache, closing it first (or
// deferring the close when it is the portal currently executing). Deleting an
// unknown name is a no-op and still returns nil.
func (cache *DefaultPortalCache) Delete(ctx context.Context, name string) error {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	portal, ok := cache.portals[name]
	if ok {
		cache.closePortal(portal)
	}
	delete(cache.portals, name)
	return nil
}
Expand All @@ -161,6 +313,7 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
defer cache.mu.Unlock()
for name, portal := range cache.portals {
if portal.statement == stmt {
cache.closePortal(portal)
delete(cache.portals, name)
}
}
Expand All @@ -170,5 +323,8 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St
// Close tears down every cached portal and empties the cache. The currently
// executing portal, if any, is flagged for deferred closing rather than being
// torn down inline.
func (cache *DefaultPortalCache) Close() {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	for _, cached := range cache.portals {
		cache.closePortal(cached)
	}
	clear(cache.portals)
}
66 changes: 66 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package wire

import (
"bytes"
"context"
"log/slog"
"testing"

"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// newDiscardWriter returns a protocol writer backed by an in-memory buffer,
// effectively discarding everything written to it.
func newDiscardWriter() *buffer.Writer {
	var sink bytes.Buffer
	return buffer.NewWriter(slog.Default(), &sink)
}

// TestCloseFromHandler verifies that calling Close from within a handler does
// not deadlock: the currently executing portal is deferred and closed only
// after Execute returns.
func TestCloseFromHandler(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	cache := &DefaultPortalCache{}

	stmt := &Statement{
		fn: func(ctx context.Context, writer DataWriter, _ []Parameter) error {
			// Closing the whole cache from inside the handler must not block.
			cache.Close()
			return writer.Complete("OK")
		},
	}

	require.NoError(t, cache.Bind(ctx, "", stmt, nil, nil))
	require.NoError(t, cache.Execute(ctx, "", NoLimit, nil, newDiscardWriter()))

	portal, err := cache.Get(ctx, "")
	require.NoError(t, err)
	assert.Nil(t, portal)
}

// TestDeleteByStatementFromHandler verifies that calling DeleteByStatement
// from within a handler does not deadlock when the currently executing portal
// is bound to the deleted statement.
func TestDeleteByStatementFromHandler(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	cache := &DefaultPortalCache{}

	var selfStmt *Statement
	selfStmt = &Statement{
		fn: func(ctx context.Context, writer DataWriter, _ []Parameter) error {
			// The handler runs on the iterator goroutine created by iter.Pull,
			// not the test goroutine. testify's require calls t.FailNow, which
			// must only run on the test goroutine, so propagate the error via
			// the return value instead; Execute surfaces it below.
			if err := cache.DeleteByStatement(ctx, selfStmt); err != nil {
				return err
			}
			return writer.Complete("OK")
		},
	}

	require.NoError(t, cache.Bind(ctx, "portal", selfStmt, nil, nil))
	err := cache.Execute(ctx, "portal", NoLimit, nil, newDiscardWriter())
	require.NoError(t, err)

	portal, err := cache.Get(ctx, "portal")
	require.NoError(t, err)
	assert.Nil(t, portal)
}
Loading
Loading