Skip to content

Commit 64e53ce

Browse files
committed
Implement cache cleanup behaviour
This cleans up statement and portal caches when receiving a Close message, and cleans up the portal cache when a statement in the cache gets overwritten.
1 parent 14d10e2 commit 64e53ce

6 files changed

Lines changed: 253 additions & 9 deletions

File tree

cache.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (*Stat
6060
return stmt, nil
6161
}
6262

63+
func (cache *DefaultStatementCache) Delete(ctx context.Context, name string) error {
64+
cache.mu.Lock()
65+
defer cache.mu.Unlock()
66+
delete(cache.statements, name)
67+
return nil
68+
}
69+
6370
func (cache *DefaultStatementCache) Close() {}
6471

6572
type Portal struct {
@@ -134,4 +141,22 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, limit
134141
return portal.statement.fn(ctx, NewDataWriter(ctx, session, portal.statement.columns, portal.formats, limit, reader, writer), portal.parameters)
135142
}
136143

144+
func (cache *DefaultPortalCache) Delete(ctx context.Context, name string) error {
145+
cache.mu.Lock()
146+
defer cache.mu.Unlock()
147+
delete(cache.portals, name)
148+
return nil
149+
}
150+
151+
func (cache *DefaultPortalCache) DeleteByStatement(ctx context.Context, stmt *Statement) error {
152+
cache.mu.Lock()
153+
defer cache.mu.Unlock()
154+
for name, portal := range cache.portals {
155+
if portal.statement == stmt {
156+
delete(cache.portals, name)
157+
}
158+
}
159+
return nil
160+
}
161+
137162
func (cache *DefaultPortalCache) Close() {}

command.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ type Session struct {
6363

6464
// inExtendedQuery is true when the current message being handled is an
6565
// extended query protocol message (Parse, Bind, Describe, Execute, Close,
66-
// Flush, Sync). This lets Session.ErrorCode behave correctly for both
66+
// Flush, Sync). This lets Session.WriteError behave correctly for both
6767
// protocols.
6868
inExtendedQuery bool
6969

@@ -373,6 +373,18 @@ func (srv *Session) handleParse(ctx context.Context, reader *buffer.Reader, writ
373373
// `reader.GetUint32()`
374374
}
375375

376+
existing, err := srv.Statements.Get(ctx, name)
377+
if err != nil {
378+
if srv.ParallelPipeline.Enabled {
379+
return srv.drainQueueAndWriteError(ctx, writer, err)
380+
}
381+
return srv.WriteError(writer, err)
382+
}
383+
384+
if existing != nil {
385+
srv.Portals.DeleteByStatement(ctx, existing) //nolint:errcheck
386+
}
387+
376388
if srv.ParallelPipeline.Enabled {
377389
return srv.parsePipelined(ctx, writer, name, query)
378390
}
@@ -721,6 +733,17 @@ func (srv *Session) handleClose(ctx context.Context, reader *buffer.Reader, writ
721733

722734
srv.logger.Debug("incoming close request", slog.String("type", string(d[0])), slog.String("name", name))
723735

736+
switch types.DescribeMessage(d[0]) {
737+
case types.DescribeStatement:
738+
stmt, _ := srv.Statements.Get(ctx, name)
739+
srv.Statements.Delete(ctx, name) //nolint:errcheck
740+
if stmt != nil {
741+
srv.Portals.DeleteByStatement(ctx, stmt) //nolint:errcheck
742+
}
743+
case types.DescribePortal:
744+
srv.Portals.Delete(ctx, name) //nolint:errcheck
745+
}
746+
724747
if srv.ParallelPipeline.Enabled {
725748
srv.ResponseQueue.Enqueue(NewCloseCompleteEvent())
726749
return nil

command_close_test.go

Lines changed: 121 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,86 @@ import (
1313
"github.com/stretchr/testify/require"
1414
)
1515

16-
func TestHandleClose_Statement(t *testing.T) {
16+
func TestHandleClose_StatementRemovesFromCache(t *testing.T) {
1717
t.Parallel()
1818

1919
ctx := context.Background()
2020
logger := slogt.New(t)
2121

22+
stmtCache := &DefaultStatementCache{}
23+
portalCache := &DefaultPortalCache{}
24+
25+
stmt := &PreparedStatement{
26+
fn: func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
27+
return writer.Complete("OK")
28+
},
29+
}
30+
31+
err := stmtCache.Set(ctx, "stmt1", stmt)
32+
require.NoError(t, err)
33+
34+
// Bind two portals to the same statement
35+
cached, err := stmtCache.Get(ctx, "stmt1")
36+
require.NoError(t, err)
37+
require.NotNil(t, cached)
38+
39+
err = portalCache.Bind(ctx, "p1", cached, nil, nil)
40+
require.NoError(t, err)
41+
err = portalCache.Bind(ctx, "p2", cached, nil, nil)
42+
require.NoError(t, err)
43+
44+
// Bind a portal to a different statement so we can verify it survives
45+
otherStmt := &PreparedStatement{
46+
fn: func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
47+
return writer.Complete("OK")
48+
},
49+
}
50+
err = stmtCache.Set(ctx, "other", otherStmt)
51+
require.NoError(t, err)
52+
otherCached, err := stmtCache.Get(ctx, "other")
53+
require.NoError(t, err)
54+
err = portalCache.Bind(ctx, "p3", otherCached, nil, nil)
55+
require.NoError(t, err)
56+
2257
session := &Session{
2358
Server: &Server{logger: logger},
24-
Statements: &DefaultStatementCache{},
25-
Portals: &DefaultPortalCache{},
59+
Statements: stmtCache,
60+
Portals: portalCache,
2661
}
2762

2863
outBuf := &bytes.Buffer{}
2964
writer := buffer.NewWriter(logger, outBuf)
3065

31-
reader := mock.NewCloseReader(t, logger, 'S', "stmt1")
66+
err = session.handleClose(ctx, mock.NewCloseReader(t, logger, 'S', "stmt1"), writer)
67+
require.NoError(t, err)
3268

33-
err := session.handleClose(ctx, reader, writer)
69+
// Statement should be removed
70+
got, err := stmtCache.Get(ctx, "stmt1")
71+
require.NoError(t, err)
72+
assert.Nil(t, got)
73+
74+
// Portals bound to the closed statement should be removed
75+
p1, err := portalCache.Get(ctx, "p1")
76+
require.NoError(t, err)
77+
assert.Nil(t, p1)
78+
79+
p2, err := portalCache.Get(ctx, "p2")
3480
require.NoError(t, err)
81+
assert.Nil(t, p2)
3582

83+
// Portal bound to a different statement should still exist
84+
p3, err := portalCache.Get(ctx, "p3")
85+
require.NoError(t, err)
86+
assert.NotNil(t, p3)
87+
88+
// CloseComplete should still be sent
3689
responseReader := mock.NewReader(t, outBuf)
3790
msgType, _, err := responseReader.ReadTypedMsg()
3891
require.NoError(t, err)
3992
assert.Equal(t, types.ServerCloseComplete, msgType)
4093
}
4194

42-
func TestHandleClose_Portal(t *testing.T) {
95+
func TestHandleClose_NonexistentNameSendsCloseComplete(t *testing.T) {
4396
t.Parallel()
4497

4598
ctx := context.Background()
@@ -54,10 +107,70 @@ func TestHandleClose_Portal(t *testing.T) {
54107
outBuf := &bytes.Buffer{}
55108
writer := buffer.NewWriter(logger, outBuf)
56109

57-
reader := mock.NewCloseReader(t, logger, 'P', "portal1")
110+
// Close a statement that doesn't exist
111+
err := session.handleClose(ctx, mock.NewCloseReader(t, logger, 'S', "nonexistent"), writer)
112+
require.NoError(t, err)
58113

59-
err := session.handleClose(ctx, reader, writer)
114+
responseReader := mock.NewReader(t, outBuf)
115+
msgType, _, err := responseReader.ReadTypedMsg()
116+
require.NoError(t, err)
117+
assert.Equal(t, types.ServerCloseComplete, msgType)
118+
119+
// Close a portal that doesn't exist
120+
outBuf.Reset()
121+
err = session.handleClose(ctx, mock.NewCloseReader(t, logger, 'P', "nonexistent"), writer)
122+
require.NoError(t, err)
123+
124+
responseReader = mock.NewReader(t, outBuf)
125+
msgType, _, err = responseReader.ReadTypedMsg()
126+
require.NoError(t, err)
127+
assert.Equal(t, types.ServerCloseComplete, msgType)
128+
}
129+
130+
func TestHandleClose_PortalRemovesFromCache(t *testing.T) {
131+
t.Parallel()
132+
133+
ctx := context.Background()
134+
logger := slogt.New(t)
135+
136+
stmtCache := &DefaultStatementCache{}
137+
portalCache := &DefaultPortalCache{}
138+
139+
stmt := &PreparedStatement{
140+
fn: func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
141+
return writer.Complete("OK")
142+
},
143+
}
144+
145+
err := stmtCache.Set(ctx, "stmt1", stmt)
146+
require.NoError(t, err)
147+
cached, err := stmtCache.Get(ctx, "stmt1")
148+
require.NoError(t, err)
149+
150+
err = portalCache.Bind(ctx, "portal1", cached, nil, nil)
151+
require.NoError(t, err)
152+
153+
session := &Session{
154+
Server: &Server{logger: logger},
155+
Statements: stmtCache,
156+
Portals: portalCache,
157+
}
158+
159+
outBuf := &bytes.Buffer{}
160+
writer := buffer.NewWriter(logger, outBuf)
161+
162+
err = session.handleClose(ctx, mock.NewCloseReader(t, logger, 'P', "portal1"), writer)
163+
require.NoError(t, err)
164+
165+
// Portal should be removed
166+
portal, err := portalCache.Get(ctx, "portal1")
167+
require.NoError(t, err)
168+
assert.Nil(t, portal)
169+
170+
// Statement should still exist
171+
got, err := stmtCache.Get(ctx, "stmt1")
60172
require.NoError(t, err)
173+
assert.NotNil(t, got)
61174

62175
responseReader := mock.NewReader(t, outBuf)
63176
msgType, _, err := responseReader.ReadTypedMsg()

command_parse_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,71 @@ func TestHandleParse_ParallelPipeline_Error(t *testing.T) {
148148
require.NoError(t, err)
149149
assert.Equal(t, types.ServerErrorResponse, msgType)
150150
}
151+
152+
func TestHandleParse_OverwriteRemovesRelatedPortals(t *testing.T) {
153+
t.Parallel()
154+
155+
ctx := context.Background()
156+
typeMap := pgtype.NewMap()
157+
ctx = setTypeInfo(ctx, typeMap)
158+
logger := slogt.New(t)
159+
160+
mockParse := func(ctx context.Context, query string) (PreparedStatements, error) {
161+
return PreparedStatements{NewStatement(
162+
func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
163+
return writer.Complete("SELECT 1")
164+
},
165+
)}, nil
166+
}
167+
168+
portalCache := &DefaultPortalCache{}
169+
session := &Session{
170+
Server: &Server{
171+
logger: logger,
172+
parse: mockParse,
173+
},
174+
Statements: &DefaultStatementCache{},
175+
Portals: portalCache,
176+
}
177+
178+
outBuf := &bytes.Buffer{}
179+
writer := buffer.NewWriter(logger, outBuf)
180+
181+
// Parse a statement and bind two portals to it
182+
err := session.handleParse(ctx, mock.NewParseReader(t, logger, "s1", "SELECT 1", 0), writer)
183+
require.NoError(t, err)
184+
err = session.handleBind(ctx, mock.NewBindReader(t, logger, "p1", "s1", 0, 0, 0), writer)
185+
require.NoError(t, err)
186+
err = session.handleBind(ctx, mock.NewBindReader(t, logger, "p2", "s1", 0, 0, 0), writer)
187+
require.NoError(t, err)
188+
189+
// Parse a different statement and bind a portal to it (should survive)
190+
err = session.handleParse(ctx, mock.NewParseReader(t, logger, "s2", "SELECT 2", 0), writer)
191+
require.NoError(t, err)
192+
err = session.handleBind(ctx, mock.NewBindReader(t, logger, "p3", "s2", 0, 0, 0), writer)
193+
require.NoError(t, err)
194+
195+
// Verify all portals exist
196+
for _, name := range []string{"p1", "p2", "p3"} {
197+
portal, err := portalCache.Get(ctx, name)
198+
require.NoError(t, err)
199+
require.NotNil(t, portal, "portal %s should exist before overwrite", name)
200+
}
201+
202+
// Re-parse s1 — portals p1 and p2 should be removed
203+
err = session.handleParse(ctx, mock.NewParseReader(t, logger, "s1", "SELECT 3", 0), writer)
204+
require.NoError(t, err)
205+
206+
p1, err := portalCache.Get(ctx, "p1")
207+
require.NoError(t, err)
208+
assert.Nil(t, p1, "portal bound to overwritten statement should be removed")
209+
210+
p2, err := portalCache.Get(ctx, "p2")
211+
require.NoError(t, err)
212+
assert.Nil(t, p2, "portal bound to overwritten statement should be removed")
213+
214+
// p3 is bound to s2, should still exist
215+
p3, err := portalCache.Get(ctx, "p3")
216+
require.NoError(t, err)
217+
assert.NotNil(t, p3, "portal bound to a different statement should survive")
218+
}

error_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,12 @@ func TestDiscardUntilSync(t *testing.T) {
289289
// Second cycle: Close, Sync (deallocate the failed statement)
290290
err = session.handleClose(ctx, mock.NewCloseReader(t, logger, 'S', "stmt1"), writer)
291291
require.NoError(t, err)
292+
293+
// The statement should be removed from the cache after Close
294+
stmt, err := session.Statements.Get(ctx, "stmt1")
295+
require.NoError(t, err)
296+
assert.Nil(t, stmt, "statement should be removed from cache after Close")
297+
292298
err = session.handleSync(ctx, writer)
293299
require.NoError(t, err)
294300

options.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ type StatementCache interface {
8585
// Get attempts to get the prepared statement for the given name. An error
8686
// is returned when no statement has been found.
8787
Get(ctx context.Context, name string) (*Statement, error)
88+
// Delete removes the prepared statement with the given name. Deleting a
89+
// nonexistent name is not an error.
90+
Delete(ctx context.Context, name string) error
8891
// Close is called at the end of a connection. Close releases all resources
8992
// held by the statement cache.
9093
Close()
@@ -101,6 +104,12 @@ type PortalCache interface {
101104
Get(ctx context.Context, name string) (*Portal, error)
102105
// Execute executes the prepared statement with the given name and parameters.
103106
Execute(ctx context.Context, name string, limit Limit, reader *buffer.Reader, writer *buffer.Writer) error
107+
// Delete removes the portal with the given name. Deleting a nonexistent
108+
// name is not an error.
109+
Delete(ctx context.Context, name string) error
110+
// DeleteByStatement removes all portals that were bound to the given
111+
// statement.
112+
DeleteByStatement(ctx context.Context, stmt *Statement) error
104113
// Close is called at the end of a connection. Close releases all resources
105114
// held by the portal cache.
106115
Close()

0 commit comments

Comments
 (0)