diff --git a/cache.go b/cache.go index 7991804..3fe6e3f 100644 --- a/cache.go +++ b/cache.go @@ -2,12 +2,21 @@ package wire import ( "context" + "errors" "fmt" + "iter" "sync" "github.com/jeroenrinzema/psql-wire/pkg/buffer" + "github.com/jeroenrinzema/psql-wire/pkg/types" ) +// Limit represents the maximum number of rows to be written. +// Zero denotes "no limit". +type Limit uint32 + +const NoLimit Limit = 0 + type Statement struct { fn PreparedStatementFn parameters []uint32 @@ -77,6 +86,115 @@ type Portal struct { statement *Statement parameters []Parameter formats []FormatCode + + // The iterator state (created by iter.Pull) + next func() (struct{}, bool) + stop func() + // Filled in by dataWriter.Complete when the handler finishes. Used to + // return the tag when re-executing a completed portal. + tag string + err error + // Set to true when execution of the portal has finished. + done bool + + // pending is closed when the most recently launched async goroutine + // finishes. A new goroutine for the same portal waits on this channel + // before starting, so same-portal executes serialize while different + // portals run in parallel. + pending chan struct{} +} + +// Close tears down the portal from outside the execution goroutine (e.g. +// explicit Close message or portal re-bind). If an async goroutine is still +// running, cleanup is deferred to a background goroutine that waits for it. +func (p *Portal) Close() { + if p.pending != nil { + pending := p.pending + go func() { + <-pending + p.close() + }() + return + } + p.close() +} + +// close tears down the iterator inline. Used from within execute where we +// are already inside the execution context and don't need to wait on pending. +func (p *Portal) close() { + if p.stop != nil { + p.stop() + p.next = nil + p.stop = nil + } +} + +func portalSuspended(writer *buffer.Writer) error { + writer.Start(types.ServerPortalSuspended) + return writer.End() +} + +func (p *Portal) execute(ctx context.Context, limit Limit, reader *buffer.Reader, writer *buffer.Writer) error { + if p.done { + // Re-executing an already completed portal simply returns the tag or + // error again. + if p.err != nil { + return p.err + } + return commandComplete(writer, p.tag) + } + + if p.next == nil { + // This is the first execute call on this portal. So let's start the + // execution. Otherwise we continue from where we left off. + session, _ := GetSession(ctx) + // Create a simple push-style iterator (iter.Seq) around the + // statement.fn. + seq := func(yield func(struct{}) bool) { + dw := &dataWriter{ + ctx: ctx, + session: session, + columns: p.statement.columns, + formats: p.formats, + reader: reader, + client: writer, + yield: yield, + tag: &p.tag, + } + err := p.statement.fn(ctx, dw, p.parameters) + if err != nil && !errors.Is(err, ErrSuspendedHandlerClosed) { + p.err = err + } + } + + // Then we convert that push-style iterator into a pull-style iterator, + // so we can suspend the iterator when we reach the row limit. + p.next, p.stop = iter.Pull(seq) + } + + var count Limit + for { + if limit != NoLimit && count >= limit { + // We've reached the row limit. Suspend the portal and let the + // client know, so it can either issue a new Execute to continue or + // close the portal. + return portalSuspended(writer) + } + + // Run the handler until it has produced the next row or finishes. The + // dataWriter inside the handler will call yield to "teleport" back + // here. + _, ok := p.next() + if !ok { + // The handler has finished. 
CommandComplete was already written + // by dataWriter.Complete. + p.close() + p.done = true + return p.err + } + + count++ + } } func DefaultPortalCacheFn() PortalCache { @@ -84,8 +202,10 @@ func DefaultPortalCacheFn() PortalCache { } type DefaultPortalCache struct { - portals map[string]*Portal - mu sync.RWMutex + portals map[string]*Portal + executing *Portal + closePending bool + mu sync.RWMutex } func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []Parameter, formats []FormatCode) error { @@ -96,6 +216,10 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *St cache.portals = map[string]*Portal{} } + if existing, ok := cache.portals[name]; ok { + existing.Close() + } + cache.portals[name] = &Portal{ statement: stmt, parameters: parameters, @@ -145,13 +269,41 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, limit return nil } - session, _ := GetSession(ctx) - return portal.statement.fn(ctx, NewDataWriter(ctx, session, portal.statement.columns, portal.formats, limit, reader, writer), portal.parameters) + cache.mu.Lock() + cache.executing = portal + cache.mu.Unlock() + + err = portal.execute(ctx, limit, reader, writer) + + cache.mu.Lock() + cache.executing = nil + needsClose := cache.closePending + cache.closePending = false + cache.mu.Unlock() + if needsClose { + portal.close() + } + + return err +} + +// closePortal closes the portal immediately, unless it is the currently +// executing portal, in which case it marks the portal for deferred closing +// after the current Execute call returns. +func (cache *DefaultPortalCache) closePortal(portal *Portal) { + if portal == cache.executing { + cache.closePending = true + } else { + portal.Close() + } } func (cache *DefaultPortalCache) Delete(ctx context.Context, name string) error { cache.mu.Lock() defer cache.mu.Unlock() + if portal, ok := cache.portals[name]; ok { + cache.closePortal(portal) + } delete(cache.portals, name) return nil } @@ -161,6 +313,7 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St defer cache.mu.Unlock() for name, portal := range cache.portals { if portal.statement == stmt { + cache.closePortal(portal) delete(cache.portals, name) } } @@ -170,5 +323,8 @@ func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *St func (cache *DefaultPortalCache) Close() { cache.mu.Lock() defer cache.mu.Unlock() + for _, portal := range cache.portals { + cache.closePortal(portal) + } clear(cache.portals) } diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..412a0fb --- /dev/null +++ b/cache_test.go @@ -0,0 +1,66 @@ +package wire + +import ( + "bytes" + "context" + "log/slog" + "testing" + + "github.com/jeroenrinzema/psql-wire/pkg/buffer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newDiscardWriter() *buffer.Writer { + return buffer.NewWriter(slog.Default(), &bytes.Buffer{}) +} + +// TestCloseFromHandler verifies that calling Close from within a handler does +// not deadlock. The currently executing portal is automatically deferred and +// closed after Execute returns. 
+func TestCloseFromHandler(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	cache := &DefaultPortalCache{}
+
+	stmt := &Statement{
+		fn: func(ctx context.Context, writer DataWriter, _ []Parameter) error {
+			cache.Close()
+			return writer.Complete("OK")
+		},
+	}
+
+	require.NoError(t, cache.Bind(ctx, "", stmt, nil, nil))
+	err := cache.Execute(ctx, "", NoLimit, nil, newDiscardWriter())
+	require.NoError(t, err)
+
+	portal, err := cache.Get(ctx, "")
+	require.NoError(t, err)
+	assert.Nil(t, portal)
+}
+
+// TestDeleteByStatementFromHandler verifies that calling DeleteByStatement
+// from within a handler does not deadlock when the currently executing portal
+// is bound to the deleted statement.
+func TestDeleteByStatementFromHandler(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	cache := &DefaultPortalCache{}
+
+	var selfStmt *Statement
+	selfStmt = &Statement{
+		fn: func(ctx context.Context, writer DataWriter, _ []Parameter) error {
+			err := cache.DeleteByStatement(ctx, selfStmt)
+			require.NoError(t, err)
+			return writer.Complete("OK")
+		},
+	}
+
+	require.NoError(t, cache.Bind(ctx, "portal", selfStmt, nil, nil))
+	err := cache.Execute(ctx, "portal", NoLimit, nil, newDiscardWriter())
+	require.NoError(t, err)
+
+	portal, err := cache.Get(ctx, "portal")
+	require.NoError(t, err)
+	assert.Nil(t, portal)
+}
diff --git a/command.go b/command.go
index 8caebed..962f8d2 100644
--- a/command.go
+++ b/command.go
@@ -1,6 +1,7 @@
 package wire

 import (
+	"bytes"
 	"context"
 	"errors"
 	"fmt"
@@ -56,6 +57,7 @@ type Session struct {
 	Statements StatementCache
 	Portals    PortalCache
 	Attributes map[string]interface{}
+	reader     *buffer.Reader

 	// pipelining
 	ParallelPipeline ParallelPipelineConfig
@@ -93,6 +95,7 @@ func isExtendedQueryMessage(t types.ClientMessage) bool {
 // This method keeps consuming messages until the client issues a close message
 // or the connection is terminated.
 func (srv *Session) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
+	srv.reader = reader
 	srv.logger.Debug("ready for query... starting to consume commands")

 	err := readyForQuery(writer, types.ServerIdle)
@@ -328,7 +331,13 @@ func (srv *Session) handleSimpleQuery(ctx context.Context, reader *buffer.Reader
 		return srv.WriteError(writer, err)
 	}

-	err = statements[index].fn(ctx, NewDataWriter(ctx, srv, statements[index].columns, nil, NoLimit, reader, writer), nil)
+	portal := &Portal{
+		statement: &Statement{
+			fn:      statements[index].fn,
+			columns: statements[index].columns,
+		},
+	}
+	err = portal.execute(ctx, NoLimit, reader, writer)
 	if err != nil {
 		return srv.WriteError(writer, err)
 	}
@@ -755,7 +764,11 @@ func (srv *Session) handleClose(ctx context.Context, reader *buffer.Reader, writ
 	return writer.End()
 }

-// executePipelined handles Execute in parallel pipeline mode
+// executePipelined handles Execute in parallel pipeline mode. It queues an
+// Execute event and launches a goroutine that runs the portal and captures
+// its wire output into a buffer. Executes on the same portal are chained
+// through the portal's pending channel, since they share the portal's
+// iter.Pull state; executes on different portals run in parallel.
func (srv *Session) executePipelined(ctx context.Context, writer *buffer.Writer, name string, limit uint32) error { portal, err := srv.Portals.Get(ctx, name) if err != nil { @@ -766,42 +779,50 @@ func (srv *Session) executePipelined(ctx context.Context, writer *buffer.Writer, return srv.drainQueueAndWriteError(ctx, writer, errors.New("unknown portal")) } - // Create result channel and queue the event - resultChan := make(chan *QueuedDataWriter, 1) - srv.ResponseQueue.Enqueue(NewExecuteEvent(resultChan, portal.formats)) + resultChan := make(chan *executeResult, 1) + srv.ResponseQueue.Enqueue(NewExecuteEvent(resultChan)) - // Launch async execution - go srv.executeAsync(ctx, portal, limit, resultChan) + prev := portal.pending + done := make(chan struct{}) + portal.pending = done + go srv.executeAsync(ctx, done, portal, Limit(limit), prev, resultChan) return nil } -// executeAsync runs the portal execution in a separate goroutine -func (srv *Session) executeAsync(ctx context.Context, portal *Portal, limit uint32, resultChan chan<- *QueuedDataWriter) { +// executeAsync runs portal.execute in a separate goroutine, capturing the +// wire output into a buffer. If prev is non-nil, it waits for the previous +// goroutine on the same portal to finish before starting. +func (srv *Session) executeAsync(ctx context.Context, done chan struct{}, portal *Portal, limit Limit, prev chan struct{}, resultChan chan<- *executeResult) { + defer close(done) defer close(resultChan) defer func() { if r := recover(); r != nil { - collector := NewQueuedDataWriter(ctx, portal.statement.columns, Limit(limit)) - collector.SetError(fmt.Errorf("panic during execution: %v", r)) - resultChan <- collector + resultChan <- &executeResult{err: fmt.Errorf("panic during execution: %v", r)} } }() - srv.logger.Debug("starting async execution", - slog.Bool("has_function", portal.statement.fn != nil)) + if prev != nil { + <-prev + } - collector := NewQueuedDataWriter(ctx, portal.statement.columns, Limit(limit)) + // pgtype.Map.PlanEncode caches encoding plans internally, so concurrent + // goroutines sharing the same Map will race. Give each goroutine its own. 
+ ctx = setTypeInfo(ctx, srv.newTypeMap()) - err := portal.statement.fn(ctx, collector, portal.parameters) - if err != nil { - collector.SetError(err) - } + srv.logger.Debug("starting async execution") + + buf := &bytes.Buffer{} + w := buffer.NewWriter(srv.logger, buf) + + err := portal.execute(ctx, limit, srv.reader, w) + + result := &executeResult{buf: buf, err: err} srv.logger.Debug("async execution complete", - slog.Bool("has_error", collector.GetError() != nil), - slog.Int("rows", int(collector.Written()))) + slog.Bool("has_error", err != nil)) - resultChan <- collector + resultChan <- result } // handleSync handles the Sync message (extended query protocol) @@ -926,22 +947,16 @@ func (srv *Session) writeQueuedResponse(ctx context.Context, writer *buffer.Writ return srv.writeColumnDescription(ctx, writer, event.Formats, event.Columns) case ResponseExecute: - // Execute writes DataRows followed by CommandComplete if event.Result == nil { - // No result yet, this shouldn't happen in normal flow return errors.New("execute event has no result") } - // Check for execution error - if err := event.Result.GetError(); err != nil { - return srv.WriteError(writer, err) + if event.Result.err != nil { + return srv.WriteError(writer, event.Result.err) } - // Use DataWriter for correct encoding - // Note: We use NoLimit here because the result is already limited during execution - dataWriter := NewDataWriter(ctx, srv, event.Result.Columns(), event.Formats, NoLimit, nil, writer) - - return event.Result.Replay(ctx, dataWriter) + _, err := writer.Write(event.Result.buf.Bytes()) + return err default: return fmt.Errorf("unknown response event kind: %v", event.Kind) diff --git a/command_execute_test.go b/command_execute_test.go index 6dd8ada..b327a98 100644 --- a/command_execute_test.go +++ b/command_execute_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "testing" "github.com/jackc/pgx/v5/pgtype" @@ -296,3 +297,260 @@ func TestHandleExecute_ParallelPipeline_AsyncPanic(t *testing.T) { require.NoError(t, err) assert.Equal(t, types.ServerErrorResponse, msgType) } + +func TestHandleExecute_ParallelPipeline_CloseWhilePending(t *testing.T) { + t.Parallel() + + ctx := context.Background() + typeMap := pgtype.NewMap() + ctx = setTypeInfo(ctx, typeMap) + + logger := slogt.New(t) + + started := make(chan struct{}) + proceed := make(chan struct{}) + + stmt := &Statement{ + fn: func(ctx context.Context, writer DataWriter, params []Parameter) error { + close(started) + <-proceed + if err := writer.Row([]any{"hello"}); err != nil { + return err + } + return writer.Complete("SELECT 1") + }, + parameters: []uint32{}, + columns: Columns{ + {Name: "greeting", Oid: pgtype.TextOID}, + }, + } + + portals := &DefaultPortalCache{} + err := portals.Bind(ctx, "portal1", stmt, nil, nil) + require.NoError(t, err) + + session := &Session{ + Server: &Server{logger: logger}, + Statements: &DefaultStatementCache{}, + Portals: portals, + ParallelPipeline: ParallelPipelineConfig{Enabled: true}, + ResponseQueue: NewResponseQueue(), + inExtendedQuery: true, + } + + outBuf := &bytes.Buffer{} + writer := buffer.NewWriter(logger, outBuf) + + // Launch async execute + err = session.handleExecute(ctx, mock.NewExecuteReader(t, logger, "portal1", 0), writer) + require.NoError(t, err) + + // Wait for the handler to start + <-started + + // Close the portal while the goroutine is still running — should not block + err = session.handleClose(ctx, mock.NewCloseReader(t, logger, 'P', "portal1"), writer) + require.NoError(t, err) + 
+ // Let the handler finish + close(proceed) + + // Sync should flush the execute result and close complete + err = session.handleSync(ctx, writer) + require.NoError(t, err) + + responseReader := mock.NewReader(t, outBuf) + + // Execute result (DataRow + CommandComplete) + msgType, _, err := responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerDataRow, msgType) + + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerCommandComplete, msgType) + + // CloseComplete + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerCloseComplete, msgType) + + // ReadyForQuery + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerReady, msgType) +} + +func TestHandleExecute_ParallelPipeline_SamePortalSerializes(t *testing.T) { + t.Parallel() + + ctx := context.Background() + typeMap := pgtype.NewMap() + ctx = setTypeInfo(ctx, typeMap) + + logger := slogt.New(t) + + // Handler produces 3 rows. Two limited Execute(limit=1) calls on the same + // portal must serialize: the second waits for the first to finish before + // pulling the next row from the suspended iterator. + stmt := &Statement{ + fn: func(ctx context.Context, writer DataWriter, params []Parameter) error { + for i := 1; i <= 3; i++ { + if err := writer.Row([]any{fmt.Sprintf("row %d", i)}); err != nil { + return err + } + } + return writer.Complete("SELECT 3") + }, + parameters: []uint32{}, + columns: Columns{ + {Name: "result", Oid: pgtype.TextOID}, + }, + } + + portals := &DefaultPortalCache{} + err := portals.Bind(ctx, "portal1", stmt, nil, nil) + require.NoError(t, err) + + session := &Session{ + Server: &Server{logger: logger}, + Statements: &DefaultStatementCache{}, + Portals: portals, + ParallelPipeline: ParallelPipelineConfig{Enabled: true}, + ResponseQueue: NewResponseQueue(), + inExtendedQuery: true, + } + + outBuf := &bytes.Buffer{} + writer := buffer.NewWriter(logger, outBuf) + + // Queue two Execute(limit=1) on the same portal — second must wait for first + err = session.handleExecute(ctx, mock.NewExecuteReader(t, logger, "portal1", 1), writer) + require.NoError(t, err) + err = session.handleExecute(ctx, mock.NewExecuteReader(t, logger, "portal1", 1), writer) + require.NoError(t, err) + + assert.Equal(t, 2, session.ResponseQueue.Len()) + + err = session.handleSync(ctx, writer) + require.NoError(t, err) + + // Both executes should produce results: each gets 1 row + PortalSuspended. 
+ responseReader := mock.NewReader(t, outBuf) + + // First execute: DataRow + PortalSuspended + msgType, _, err := responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerDataRow, msgType) + + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerPortalSuspended, msgType) + + // Second execute: DataRow + PortalSuspended + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerDataRow, msgType) + + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerPortalSuspended, msgType) + + // ReadyForQuery + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerReady, msgType) +} + +func TestHandleExecute_ParallelPipeline_DifferentPortalsParallel(t *testing.T) { + t.Parallel() + + ctx := context.Background() + typeMap := pgtype.NewMap() + ctx = setTypeInfo(ctx, typeMap) + + logger := slogt.New(t) + + bothStarted := make(chan struct{}) + portal1Started := make(chan struct{}) + portal2Started := make(chan struct{}) + + go func() { + <-portal1Started + <-portal2Started + close(bothStarted) + }() + + makeStmt := func(started chan struct{}) *Statement { + return &Statement{ + fn: func(ctx context.Context, writer DataWriter, params []Parameter) error { + close(started) + // Wait until both portals have started — proves they run in parallel + <-bothStarted + if err := writer.Row([]any{"ok"}); err != nil { + return err + } + return writer.Complete("SELECT 1") + }, + parameters: []uint32{}, + columns: Columns{ + {Name: "result", Oid: pgtype.TextOID}, + }, + } + } + + portals := &DefaultPortalCache{} + err := portals.Bind(ctx, "p1", makeStmt(portal1Started), nil, nil) + require.NoError(t, err) + err = portals.Bind(ctx, "p2", makeStmt(portal2Started), nil, nil) + require.NoError(t, err) + + session := &Session{ + Server: &Server{logger: logger}, + Statements: &DefaultStatementCache{}, + Portals: portals, + ParallelPipeline: ParallelPipelineConfig{Enabled: true}, + ResponseQueue: NewResponseQueue(), + inExtendedQuery: true, + } + + outBuf := &bytes.Buffer{} + writer := buffer.NewWriter(logger, outBuf) + + err = session.handleExecute(ctx, mock.NewExecuteReader(t, logger, "p1", 0), writer) + require.NoError(t, err) + err = session.handleExecute(ctx, mock.NewExecuteReader(t, logger, "p2", 0), writer) + require.NoError(t, err) + + assert.Equal(t, 2, session.ResponseQueue.Len()) + + // Sync drains both — if they didn't run in parallel, they'd deadlock + // waiting on bothStarted + err = session.handleSync(ctx, writer) + require.NoError(t, err) + + responseReader := mock.NewReader(t, outBuf) + + // Portal 1: DataRow + CommandComplete + msgType, _, err := responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerDataRow, msgType) + + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerCommandComplete, msgType) + + // Portal 2: DataRow + CommandComplete + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerDataRow, msgType) + + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerCommandComplete, msgType) + + // ReadyForQuery + msgType, _, err = responseReader.ReadTypedMsg() + require.NoError(t, err) + assert.Equal(t, types.ServerReady, msgType) +} diff --git a/command_test.go b/command_test.go index 
ae3c7b8..1ac6b26 100644 --- a/command_test.go +++ b/command_test.go @@ -169,3 +169,131 @@ func TestServerLimit(t *testing.T) { client.Close(t) } + +func TestPortalSuspended(t *testing.T) { + t.Parallel() + + totalRows := 5 + columns := Columns{ + { + Table: 0, + Name: "id", + Oid: pgtype.Int4OID, + Width: 4, + }, + } + + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + for i := 0; i < totalRows; i++ { + if err := writer.Row([]any{int32(i)}); err != nil { + return err + } + } + return writer.Complete("SELECT 5") + } + + return Prepared(NewStatement(handle, WithColumns(columns))), nil + } + + server, err := NewServer(handler, Logger(slogt.New(t))) + require.NoError(t, err) + + address := TListenAndServe(t, server) + conn, err := net.Dial("tcp", address.String()) + require.NoError(t, err) + + client := mock.NewClient(t, conn) + client.Handshake(t) + client.Authenticate(t) + client.ReadyForQuery(t) + + for cycle := 0; cycle < 2; cycle++ { + t.Logf("cycle %d", cycle) + + client.Parse(t, "stmt1", "SELECT id") + client.ExpectMsg(t, types.ServerParseComplete) + + client.Bind(t, "portal1", "stmt1") + client.ExpectMsg(t, types.ServerBindComplete) + + // Execute with limit=2 — should get rows 0 and 1 + client.Execute(t, "portal1", 2) + rows := client.ExpectDataRows(t, 2) + assert.Equal(t, "0", string(rows[0][0])) + assert.Equal(t, "1", string(rows[1][0])) + client.ExpectMsg(t, types.ServerPortalSuspended) + + // Execute with limit=10 (more than remaining 3 rows) — should get rows 2, 3, 4 + client.Execute(t, "portal1", 10) + rows = client.ExpectDataRows(t, 3) + assert.Equal(t, "2", string(rows[0][0])) + assert.Equal(t, "3", string(rows[1][0])) + assert.Equal(t, "4", string(rows[2][0])) + client.ExpectMsg(t, types.ServerCommandComplete) + + client.Sync(t) + client.ExpectMsg(t, types.ServerReady) + } + + client.Close(t) +} + +func TestReExecuteCompletedPortal(t *testing.T) { + t.Parallel() + + columns := Columns{ + { + Table: 0, + Name: "id", + Oid: pgtype.Int4OID, + Width: 4, + }, + } + + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + if err := writer.Row([]any{int32(1)}); err != nil { + return err + } + return writer.Complete("SELECT 1") + } + + return Prepared(NewStatement(handle, WithColumns(columns))), nil + } + + server, err := NewServer(handler, Logger(slogt.New(t))) + require.NoError(t, err) + + address := TListenAndServe(t, server) + conn, err := net.Dial("tcp", address.String()) + require.NoError(t, err) + + client := mock.NewClient(t, conn) + client.Handshake(t) + client.Authenticate(t) + client.ReadyForQuery(t) + + client.Parse(t, "stmt1", "SELECT id") + client.ExpectMsg(t, types.ServerParseComplete) + + client.Bind(t, "portal1", "stmt1") + client.ExpectMsg(t, types.ServerBindComplete) + + // Execute with limit=2 — query only produces 1 row + client.Execute(t, "portal1", 2) + rows := client.ExpectDataRows(t, 1) + assert.Equal(t, "1", string(rows[0][0])) + tag := client.ExpectCommandComplete(t) + assert.Equal(t, "SELECT 1", tag) + + // Re-execute the same portal — should get CommandComplete with the same tag + client.Execute(t, "portal1", 0) + reTag := client.ExpectCommandComplete(t) + assert.Equal(t, "SELECT 1", reTag) + + client.Sync(t) + client.ExpectMsg(t, types.ServerReady) + + client.Close(t) +} diff --git a/error.go b/error.go index 
1ac0c65..afcc81f 100644 --- a/error.go +++ b/error.go @@ -28,6 +28,10 @@ const ( // a trailing ReadyForQuery. Use this in contexts where no session is available // (e.g. authentication) or where you need to control ReadyForQuery yourself. func WriteUnterminatedError(writer *buffer.Writer, err error) error { + if writer.ErrorSanitizer != nil { + err = writer.ErrorSanitizer(err) + } + desc := psqlerr.Flatten(err) writer.Start(types.ServerErrorResponse) diff --git a/error_test.go b/error_test.go index e9595bb..512b62a 100644 --- a/error_test.go +++ b/error_test.go @@ -358,3 +358,46 @@ func TestDiscardUntilSync(t *testing.T) { _, _, err = responseReader.ReadTypedMsg() require.Error(t, err) } + +// TestRowReturnsEncodeError verifies that when columns.Write fails on the pull +// side (e.g. a type the pgx TypeMap can't encode), the encoding error is +// returned from DataWriter.Row so the handler can see and wrap it. +func TestRowReturnsEncodeError(t *testing.T) { + t.Parallel() + + var rowErr error + + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + columns := Columns{{Name: "val", Oid: pgtype.Int4OID}} + + stmt := NewStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + rowErr = writer.Row([]any{struct{}{}}) + if rowErr != nil { + return rowErr + } + return writer.Complete("SELECT 1") + }, WithColumns(columns)) + + return Prepared(stmt), nil + } + + server, err := NewServer(handler, Logger(slogt.New(t))) + require.NoError(t, err) + + address := TListenAndServe(t, server) + + ctx := context.Background() + connstr := fmt.Sprintf("postgres://%s:%d", address.IP, address.Port) + conn, err := pgx.Connect(ctx, connstr) + require.NoError(t, err) + + rows, _ := conn.Query(ctx, "SELECT 1;") + rows.Close() + assert.Error(t, rows.Err()) + + require.NotNil(t, rowErr) + assert.Contains(t, rowErr.Error(), "unable to encode") + + err = conn.Close(ctx) + assert.NoError(t, err) +} diff --git a/options.go b/options.go index 076aeba..a2b2bf1 100644 --- a/options.go +++ b/options.go @@ -102,7 +102,7 @@ type PortalCache interface { // Get attempts to get the portal for the given name. An error is returned // when no portal has been found. Get(ctx context.Context, name string) (*Portal, error) - // Execute executes the prepared statement with the given name and parameters. + // Execute executes the portal with the given name and parameters. Execute(ctx context.Context, name string, limit Limit, reader *buffer.Reader, writer *buffer.Writer) error // Delete removes the portal with the given name. Deleting a nonexistent // name is not an error. @@ -195,6 +195,17 @@ func ParallelPipeline(config ParallelPipelineConfig) OptionFn { } } +// ErrorSanitizer sets a function that transforms errors before they are sent +// to the client. This hook is called before writing any ErrorResponse to the +// wire, including during authentication. It can be used to mask internal error +// details, generate error IDs, or rewrite error messages. +func ErrorSanitizer(fn func(error) error) OptionFn { + return func(srv *Server) error { + srv.ErrorSanitizer = fn + return nil + } +} + // MessageBufferSize sets the message buffer size which is allocated once a new // connection gets constructed. If a negative value or zero value is provided is // the default message buffer size used. 
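Reviewer note, not part of the patch: a minimal sketch of how a server author might wire up the new ErrorSanitizer option. The parse handler, error texts, and log sink are illustrative assumptions; NewServer and the ParseFn shape are taken from the tests above, and Server.ListenAndServe is assumed to be the usual entry point.

package main

import (
	"context"
	"errors"
	"log"

	wire "github.com/jeroenrinzema/psql-wire"
)

func main() {
	// Hypothetical parse handler that fails with an internal error, so the
	// sanitizer has something to mask.
	handler := func(ctx context.Context, query string) (wire.PreparedStatements, error) {
		return nil, errors.New("db: connection to shard 7 refused")
	}

	server, err := wire.NewServer(handler, wire.ErrorSanitizer(func(err error) error {
		log.Printf("query failed: %v", err) // keep the detail server-side
		return errors.New("internal error") // opaque message on the wire
	}))
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(server.ListenAndServe("127.0.0.1:5432"))
}

Because the hook is applied in WriteUnterminatedError, it also covers errors written before a session exists, e.g. during authentication.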
diff --git a/pkg/buffer/writer.go b/pkg/buffer/writer.go index 58a02ee..1f08dcc 100644 --- a/pkg/buffer/writer.go +++ b/pkg/buffer/writer.go @@ -12,10 +12,11 @@ import ( // Writer provides a convenient way to write pgwire protocol messages type Writer struct { io.Writer - logger *slog.Logger - frame bytes.Buffer - putbuf [64]byte // buffer used to construct messages which could be written to the writer frame buffer - err error + logger *slog.Logger + frame bytes.Buffer + putbuf [64]byte // buffer used to construct messages which could be written to the writer frame buffer + err error + ErrorSanitizer func(error) error } // NewWriter constructs a new Postgres buffered message writer for the given io.Writer diff --git a/pkg/mock/client.go b/pkg/mock/client.go index 96b893f..49e235c 100644 --- a/pkg/mock/client.go +++ b/pkg/mock/client.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/jeroenrinzema/psql-wire/pkg/types" + "github.com/stretchr/testify/require" ) func NewClient(t *testing.T, conn net.Conn) *Client { @@ -145,3 +146,98 @@ func (client *Client) Close(t *testing.T) { t.Fatal(err) } } + +// Parse sends a Parse message with the given statement name and query. +func (client *Client) Parse(t *testing.T, name, query string) { + t.Helper() + client.Start(types.ClientParse) + client.AddString(name) + client.AddNullTerminate() + client.AddString(query) + client.AddNullTerminate() + client.AddInt16(0) // no parameter types + require.NoError(t, client.End()) +} + +// Bind sends a Bind message binding the given statement to the given portal. +func (client *Client) Bind(t *testing.T, portal, statement string) { + t.Helper() + client.Start(types.ClientBind) + client.AddString(portal) + client.AddNullTerminate() + client.AddString(statement) + client.AddNullTerminate() + client.AddInt16(0) // no parameter formats + client.AddInt16(0) // no parameter values + client.AddInt16(0) // no result format codes + require.NoError(t, client.End()) +} + +// Execute sends an Execute message for the given portal with the given row limit. +// A limit of 0 means no limit. +func (client *Client) Execute(t *testing.T, portal string, limit int32) { + t.Helper() + client.Start(types.ClientExecute) + client.AddString(portal) + client.AddNullTerminate() + client.AddInt32(limit) + require.NoError(t, client.End()) +} + +// Sync sends a Sync message. +func (client *Client) Sync(t *testing.T) { + t.Helper() + client.Start(types.ClientSync) + require.NoError(t, client.End()) +} + +// ExpectMsg reads one message and asserts its type matches expected. +func (client *Client) ExpectMsg(t *testing.T, expected types.ServerMessage) { + t.Helper() + ct, _, err := client.ReadTypedMsg() + require.NoError(t, err) + require.Equal(t, expected, ct, "expected %s but got %s", expected, ct) +} + +// ExpectDataRow reads a single DataRow message and returns the column values +// as raw byte slices (nil for SQL NULL). +func (client *Client) ExpectDataRow(t *testing.T) [][]byte { + t.Helper() + client.ExpectMsg(t, types.ServerDataRow) + + numCols, err := client.GetUint16() + require.NoError(t, err) + + row := make([][]byte, numCols) + for i := 0; i < int(numCols); i++ { + length, err := client.GetInt32() + require.NoError(t, err) + if length == -1 { + row[i] = nil + continue + } + val, err := client.GetBytes(int(length)) + require.NoError(t, err) + row[i] = val + } + return row +} + +// ExpectCommandComplete reads a CommandComplete message and returns the tag string. 
+func (client *Client) ExpectCommandComplete(t *testing.T) string { + t.Helper() + client.ExpectMsg(t, types.ServerCommandComplete) + tag, err := client.GetString() + require.NoError(t, err) + return tag +} + +// ExpectDataRows reads exactly n DataRow messages and returns all rows. +func (client *Client) ExpectDataRows(t *testing.T, n int) [][][]byte { + t.Helper() + rows := make([][][]byte, n) + for i := 0; i < n; i++ { + rows[i] = client.ExpectDataRow(t) + } + return rows +} diff --git a/queued_data_writer.go b/queued_data_writer.go deleted file mode 100644 index 22a50b0..0000000 --- a/queued_data_writer.go +++ /dev/null @@ -1,95 +0,0 @@ -package wire - -import ( - "context" - "errors" -) - -// QueuedDataWriter implements DataWriter interface -// It collects query results in memory for later replay during pipelining -type QueuedDataWriter struct { - columns Columns - rows [][]any - tag string - empty bool - written uint32 - err error - limit Limit -} - -// Implement DataWriter interface - -func (rc *QueuedDataWriter) Row(values []any) error { - if rc.err != nil { - return rc.err - } - - rc.rows = append(rc.rows, values) - rc.written++ - return nil -} - -func (rc *QueuedDataWriter) Complete(tag string) error { - rc.tag = tag - return nil -} - -func (rc *QueuedDataWriter) Empty() error { - rc.empty = true - return nil -} - -func (rc *QueuedDataWriter) Columns() Columns { - return rc.columns -} - -func (rc *QueuedDataWriter) Written() uint32 { - return rc.written -} - -func (rc *QueuedDataWriter) CopyIn(format FormatCode) (*CopyReader, error) { - return nil, errors.New("CopyIn not supported in pipeline mode") -} - -func (rc *QueuedDataWriter) Limit() uint32 { - return uint32(rc.limit) -} - -// SetError sets the error state -func (rc *QueuedDataWriter) SetError(err error) { - rc.err = err -} - -// GetError gets the error state -func (rc *QueuedDataWriter) GetError() error { - return rc.err -} - -// Replay writes all collected data to a real DataWriter -func (rc *QueuedDataWriter) Replay(ctx context.Context, writer DataWriter) error { - if rc.err != nil { - return rc.err - } - - // Write all collected rows - for _, row := range rc.rows { - if err := writer.Row(row); err != nil { - return err - } - } - - // Send completion - if rc.tag != "" { - return writer.Complete(rc.tag) - } - - return nil -} - -// NewQueuedDataWriter creates a DataWriter that collects results for pipelining -func NewQueuedDataWriter(ctx context.Context, columns Columns, limit Limit) *QueuedDataWriter { - return &QueuedDataWriter{ - columns: columns, - limit: limit, - } -} diff --git a/response_queue.go b/response_queue.go index dfdfdc4..da993d2 100644 --- a/response_queue.go +++ b/response_queue.go @@ -1,6 +1,7 @@ package wire import ( + "bytes" "context" ) @@ -37,8 +38,15 @@ type ResponseEvent struct { Formats []FormatCode // For ResponseExecute: tracks completion and results - ResultChannel chan *QueuedDataWriter // channel to receive results - Result *QueuedDataWriter // cached result once received + ResultChannel chan *executeResult // channel to receive results + Result *executeResult // cached result once received +} + +// executeResult holds the raw wire bytes or error produced by an async +// portal execution in the parallel pipeline. 
+type executeResult struct { + buf *bytes.Buffer + err error } // NewParseCompleteEvent creates a ParseComplete response event @@ -81,11 +89,10 @@ func NewCloseCompleteEvent() *ResponseEvent { } // NewExecuteEvent creates an Execute response event -func NewExecuteEvent(resultChan chan *QueuedDataWriter, formats []FormatCode) *ResponseEvent { +func NewExecuteEvent(resultChan chan *executeResult) *ResponseEvent { return &ResponseEvent{ Kind: ResponseExecute, ResultChannel: resultChan, - Formats: formats, } } @@ -120,10 +127,10 @@ func (q *ResponseQueue) DrainSync(ctx context.Context) ([]*ResponseEvent, error) case res := <-event.ResultChannel: event.Result = res // Check if the result contains an error - if res != nil && res.GetError() != nil { + if res != nil && res.err != nil { // Return events processed so far,not including the error event // Events after this one won't be sent on the wire - return processedEvents, res.GetError() + return processedEvents, res.err } case <-ctx.Done(): // Context cancelled - return events processed up to this point @@ -137,7 +144,6 @@ func (q *ResponseQueue) DrainSync(ctx context.Context) ([]*ResponseEvent, error) processedEvents = append(processedEvents, event) } - // All events processed successfully return processedEvents, nil } diff --git a/response_queue_test.go b/response_queue_test.go index 53fc439..99efe44 100644 --- a/response_queue_test.go +++ b/response_queue_test.go @@ -1,6 +1,7 @@ package wire import ( + "bytes" "context" "errors" "testing" @@ -13,23 +14,21 @@ import ( // newPendingExecuteEvent creates an Execute event that blocks (not ready). func newPendingExecuteEvent() *ResponseEvent { - return NewExecuteEvent(make(chan *QueuedDataWriter), nil) + return NewExecuteEvent(make(chan *executeResult)) } // newReadyExecuteEvent creates an Execute event with a completed result. -func newReadyExecuteEvent(rows [][]any) *ResponseEvent { - ch := make(chan *QueuedDataWriter, 1) - ch <- &QueuedDataWriter{rows: rows} - return NewExecuteEvent(ch, nil) +func newReadyExecuteEvent(data []byte) *ResponseEvent { + ch := make(chan *executeResult, 1) + ch <- &executeResult{buf: bytes.NewBuffer(data)} + return NewExecuteEvent(ch) } // newErrorExecuteEvent creates an Execute event with an error result. 
func newErrorExecuteEvent(err error) *ResponseEvent { - ch := make(chan *QueuedDataWriter, 1) - result := &QueuedDataWriter{} - result.SetError(err) - ch <- result - return NewExecuteEvent(ch, nil) + ch := make(chan *executeResult, 1) + ch <- &executeResult{err: err} + return NewExecuteEvent(ch) } // TestResponseQueueBasicOperations tests enqueue and drain operations @@ -89,9 +88,9 @@ func TestDrainSyncNormalOperation(t *testing.T) { // Add some control events queue.Enqueue(NewParseCompleteEvent()) queue.Enqueue(NewBindCompleteEvent()) - queue.Enqueue(newReadyExecuteEvent([][]any{{"value1"}, {"value2"}})) + queue.Enqueue(newReadyExecuteEvent([]byte("some wire data"))) queue.Enqueue(NewParseCompleteEvent()) - queue.Enqueue(newReadyExecuteEvent([][]any{{"value3"}})) + queue.Enqueue(newReadyExecuteEvent([]byte("more wire data"))) // Drain all events events, err := queue.DrainSync(ctx) @@ -103,11 +102,9 @@ func TestDrainSyncNormalOperation(t *testing.T) { assert.Equal(t, ResponseBindComplete, events[1].Kind) assert.Equal(t, ResponseExecute, events[2].Kind) assert.NotNil(t, events[2].Result) - assert.Len(t, events[2].Result.rows, 2) assert.Equal(t, ResponseParseComplete, events[3].Kind) assert.Equal(t, ResponseExecute, events[4].Kind) assert.NotNil(t, events[4].Result) - assert.Len(t, events[4].Result.rows, 1) } // TestDrainSyncWithError tests early exit on error @@ -120,7 +117,7 @@ func TestDrainSyncWithError(t *testing.T) { // Add some control events queue.Enqueue(NewParseCompleteEvent()) queue.Enqueue(NewBindCompleteEvent()) - queue.Enqueue(newReadyExecuteEvent([][]any{{"success"}})) + queue.Enqueue(newReadyExecuteEvent([]byte("success"))) // Add Execute with error result testError := errors.New("query execution failed") @@ -142,7 +139,7 @@ func TestDrainSyncWithError(t *testing.T) { // First Execute should have its result assert.NotNil(t, events[2].Result) - assert.NoError(t, events[2].Result.GetError()) + assert.NoError(t, events[2].Result.err) } // TestDrainSyncContextCancellation tests context cancellation during drain diff --git a/wire.go b/wire.go index c179430..0cd1d94 100644 --- a/wire.go +++ b/wire.go @@ -141,6 +141,7 @@ type Server struct { TerminateConn CloseFn FlushConn FlushFn ParallelPipeline ParallelPipelineConfig + ErrorSanitizer func(error) error Version string ShutdownTimeout time.Duration typeExtension func(*pgtype.Map) @@ -215,18 +216,20 @@ func (srv *Server) Serve(listener net.Listener) error { } } +// newTypeMap creates a fresh pgtype.Map with any configured type extensions applied. 
+func (srv *Server) newTypeMap() *pgtype.Map { + m := pgtype.NewMap() + if srv.typeExtension != nil { + srv.typeExtension(m) + } + return m +} + func (srv *Server) serve(ctx context.Context, conn net.Conn) error { // Create a per-connection pgx Map to avoid concurrent map writes // Each connection gets its own type map instance to prevent race conditions // when multiple goroutines access the same map concurrently during query execution - connectionTypes := pgtype.NewMap() - - // Apply any type extension configured via ExtendTypes - if srv.typeExtension != nil { - srv.typeExtension(connectionTypes) - } - - ctx = setTypeInfo(ctx, connectionTypes) + ctx = setTypeInfo(ctx, srv.newTypeMap()) ctx = setRemoteAddress(ctx, conn.RemoteAddr()) defer conn.Close() //nolint:errcheck @@ -244,6 +247,7 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error { srv.logger.Debug("handshake successful, validating authentication") writer := buffer.NewWriter(srv.logger, conn) + writer.ErrorSanitizer = srv.ErrorSanitizer ctx, err = srv.readClientParameters(ctx, reader) if err != nil { return err diff --git a/writer.go b/writer.go index 16cb30c..9b51eab 100644 --- a/writer.go +++ b/writer.go @@ -4,18 +4,10 @@ import ( "context" "errors" - "github.com/jeroenrinzema/psql-wire/codes" - pgerror "github.com/jeroenrinzema/psql-wire/errors" "github.com/jeroenrinzema/psql-wire/pkg/buffer" "github.com/jeroenrinzema/psql-wire/pkg/types" ) -// Limit represents the maximum number of rows to be written. -// Zero denotes “no limit”. -type Limit uint32 - -const NoLimit Limit = 0 - // DataWriter represents a writer interface for writing columns and data rows // using the Postgres wire to the connected client. type DataWriter interface { @@ -26,10 +18,6 @@ type DataWriter interface { // values are encoded as NULL values. Row([]any) error - // Limit returns the maximum number of rows to be written passed within the - // wire protocol. A value of 0 indicates no limit. - Limit() uint32 - // Written returns the number of rows written to the client. Written() uint32 @@ -62,33 +50,21 @@ var ErrDataWritten = errors.New("data has already been written") // ErrClosedWriter is returned when the data writer has been closed. var ErrClosedWriter = errors.New("closed writer") -var ErrRowLimitExceeded = pgerror.WithCode(errors.New("row limit exceeded"), codes.ProgramLimitExceeded) - -// NewDataWriter constructs a new data writer using the given context and -// buffer. The returned writer should be handled with caution as it is not safe -// for concurrent use. Concurrent access to the same data without proper -// synchronization can result in unexpected behavior and data corruption. -func NewDataWriter(ctx context.Context, session *Session, columns Columns, formats []FormatCode, limit Limit, reader *buffer.Reader, writer *buffer.Writer) DataWriter { - return &dataWriter{ - ctx: ctx, - session: session, - columns: columns, - formats: formats, - limit: limit, - client: writer, - reader: reader, - } -} - -// dataWriter is a implementation of the DataWriter interface. +// dataWriter implements DataWriter for use inside an iter.Seq push +// iterator. Row encodes the row to the wire and then yields to the pull +// consumer for flow control. Complete writes CommandComplete to the wire. +// This approach allows portal suspension: when the pull consumer stops +// pulling (row limit reached), the handler goroutine blocks in yield +// until the next Execute. 
 type dataWriter struct {
 	ctx     context.Context
 	session *Session
 	columns Columns
 	formats []FormatCode
-	limit   Limit
 	client  *buffer.Writer
 	reader  *buffer.Reader
+	yield   func(struct{}) bool
+	tag     *string
 	closed  bool
 	written uint32
 }
@@ -97,27 +73,24 @@ func (writer *dataWriter) Columns() Columns {
 	return writer.columns
 }

-func (writer *dataWriter) Define(columns Columns) error {
-	if writer.closed {
-		return ErrClosedWriter
-	}
-
-	writer.columns = columns
-	return writer.columns.Define(writer.ctx, writer.client, writer.formats)
-}
-
 func (writer *dataWriter) Row(values []any) error {
 	if writer.closed {
 		return ErrClosedWriter
 	}

-	if writer.limit != 0 && Limit(writer.written) >= writer.limit {
-		return ErrRowLimitExceeded
+	err := writer.columns.Write(writer.ctx, writer.formats, writer.client, values)
+	if err != nil {
+		return err
 	}

 	writer.written++
-
-	return writer.columns.Write(writer.ctx, writer.formats, writer.client, values)
+	// The yield call "teleports" us back to the pull consumer in
+	// Portal.execute. yield returns true once the pull consumer calls next
+	// again, and false once stop has been called.
+	if !writer.yield(struct{}{}) {
+		return ErrSuspendedHandlerClosed
+	}
+	return nil
 }

 func (writer *dataWriter) CopyIn(format FormatCode) (*CopyReader, error) {
@@ -129,7 +102,6 @@
 	if err != nil {
 		return nil, err
 	}
-
 	return NewCopyReader(writer.session, writer.reader, writer.client, writer.columns), nil
 }
@@ -146,10 +118,6 @@ func (writer *dataWriter) Empty() error {
 	return nil
 }

-func (writer *dataWriter) Limit() uint32 {
-	return uint32(writer.limit)
-}
-
 func (writer *dataWriter) Written() uint32 {
 	return writer.written
 }
@@ -159,14 +127,8 @@ func (writer *dataWriter) Complete(description string) error {
 		return ErrClosedWriter
 	}

-	if writer.written == 0 && writer.columns != nil {
-		err := writer.Empty()
-		if err != nil {
-			return err
-		}
-	}
-
 	defer writer.close()
+	*writer.tag = description
 	return commandComplete(writer.client, description)
 }
@@ -183,3 +145,9 @@ func commandComplete(writer *buffer.Writer, description string) error {
 	writer.AddNullTerminate()
 	return writer.End()
 }
+
+// ErrSuspendedHandlerClosed is returned from DataWriter.Row when a suspended
+// portal is closed (or re-bound) before the handler finished producing rows.
+// Handlers can check for this error to distinguish graceful portal teardown
+// from real failures and skip error logging.
+var ErrSuspendedHandlerClosed = errors.New("suspended handler closed")
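Appendix, not part of the patch: the suspension mechanism above relies on iter.Pull parking the push-side producer inside yield between next calls. A self-contained sketch of that pattern, with plain ints standing in for encoded rows; all names here are illustrative.

package main

import (
	"fmt"
	"iter"
)

func main() {
	// Push-style producer, analogous to a statement handler calling
	// DataWriter.Row: each yield hands one row to the consumer and parks
	// until the consumer asks for more.
	seq := func(yield func(int) bool) {
		for i := 0; ; i++ {
			if !yield(i) {
				// stop() was called: the portal equivalent of
				// ErrSuspendedHandlerClosed.
				fmt.Println("producer torn down")
				return
			}
		}
	}

	next, stop := iter.Pull(seq)
	defer stop()

	// First "Execute" with limit 2: pull two rows, then suspend.
	for i := 0; i < 2; i++ {
		row, _ := next()
		fmt.Println("row", row)
	}

	// Second "Execute": the producer resumes exactly where it was parked.
	row, _ := next()
	fmt.Println("row", row)
}

Running this prints rows 0, 1, and 2 in order; the deferred stop then unparks the producer with yield returning false, mirroring how closing a suspended portal tears down its handler.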