@@ -13,33 +13,86 @@ import (
1313 "github.com/stretchr/testify/require"
1414)
1515
16- func TestHandleClose_Statement (t * testing.T ) {
16+ func TestHandleClose_StatementRemovesFromCache (t * testing.T ) {
1717 t .Parallel ()
1818
1919 ctx := context .Background ()
2020 logger := slogt .New (t )
2121
22+ stmtCache := & DefaultStatementCache {}
23+ portalCache := & DefaultPortalCache {}
24+
25+ stmt := & PreparedStatement {
26+ fn : func (ctx context.Context , writer DataWriter , parameters []Parameter ) error {
27+ return writer .Complete ("OK" )
28+ },
29+ }
30+
31+ err := stmtCache .Set (ctx , "stmt1" , stmt )
32+ require .NoError (t , err )
33+
34+ // Bind two portals to the same statement
35+ cached , err := stmtCache .Get (ctx , "stmt1" )
36+ require .NoError (t , err )
37+ require .NotNil (t , cached )
38+
39+ err = portalCache .Bind (ctx , "p1" , cached , nil , nil )
40+ require .NoError (t , err )
41+ err = portalCache .Bind (ctx , "p2" , cached , nil , nil )
42+ require .NoError (t , err )
43+
44+ // Bind a portal to a different statement so we can verify it survives
45+ otherStmt := & PreparedStatement {
46+ fn : func (ctx context.Context , writer DataWriter , parameters []Parameter ) error {
47+ return writer .Complete ("OK" )
48+ },
49+ }
50+ err = stmtCache .Set (ctx , "other" , otherStmt )
51+ require .NoError (t , err )
52+ otherCached , err := stmtCache .Get (ctx , "other" )
53+ require .NoError (t , err )
54+ err = portalCache .Bind (ctx , "p3" , otherCached , nil , nil )
55+ require .NoError (t , err )
56+
2257 session := & Session {
2358 Server : & Server {logger : logger },
24- Statements : & DefaultStatementCache {} ,
25- Portals : & DefaultPortalCache {} ,
59+ Statements : stmtCache ,
60+ Portals : portalCache ,
2661 }
2762
2863 outBuf := & bytes.Buffer {}
2964 writer := buffer .NewWriter (logger , outBuf )
3065
31- reader := mock .NewCloseReader (t , logger , 'S' , "stmt1" )
66+ err = session .handleClose (ctx , mock .NewCloseReader (t , logger , 'S' , "stmt1" ), writer )
67+ require .NoError (t , err )
3268
33- err := session .handleClose (ctx , reader , writer )
69+ // Statement should be removed
70+ got , err := stmtCache .Get (ctx , "stmt1" )
71+ require .NoError (t , err )
72+ assert .Nil (t , got )
73+
74+ // Portals bound to the closed statement should be removed
75+ p1 , err := portalCache .Get (ctx , "p1" )
76+ require .NoError (t , err )
77+ assert .Nil (t , p1 )
78+
79+ p2 , err := portalCache .Get (ctx , "p2" )
3480 require .NoError (t , err )
81+ assert .Nil (t , p2 )
3582
83+ // Portal bound to a different statement should still exist
84+ p3 , err := portalCache .Get (ctx , "p3" )
85+ require .NoError (t , err )
86+ assert .NotNil (t , p3 )
87+
88+ // CloseComplete should still be sent
3689 responseReader := mock .NewReader (t , outBuf )
3790 msgType , _ , err := responseReader .ReadTypedMsg ()
3891 require .NoError (t , err )
3992 assert .Equal (t , types .ServerCloseComplete , msgType )
4093}
4194
42- func TestHandleClose_Portal (t * testing.T ) {
95+ func TestHandleClose_NonexistentNameSendsCloseComplete (t * testing.T ) {
4396 t .Parallel ()
4497
4598 ctx := context .Background ()
@@ -54,10 +107,70 @@ func TestHandleClose_Portal(t *testing.T) {
54107 outBuf := & bytes.Buffer {}
55108 writer := buffer .NewWriter (logger , outBuf )
56109
57- reader := mock .NewCloseReader (t , logger , 'P' , "portal1" )
110+ // Close a statement that doesn't exist
111+ err := session .handleClose (ctx , mock .NewCloseReader (t , logger , 'S' , "nonexistent" ), writer )
112+ require .NoError (t , err )
58113
59- err := session .handleClose (ctx , reader , writer )
114+ responseReader := mock .NewReader (t , outBuf )
115+ msgType , _ , err := responseReader .ReadTypedMsg ()
116+ require .NoError (t , err )
117+ assert .Equal (t , types .ServerCloseComplete , msgType )
118+
119+ // Close a portal that doesn't exist
120+ outBuf .Reset ()
121+ err = session .handleClose (ctx , mock .NewCloseReader (t , logger , 'P' , "nonexistent" ), writer )
122+ require .NoError (t , err )
123+
124+ responseReader = mock .NewReader (t , outBuf )
125+ msgType , _ , err = responseReader .ReadTypedMsg ()
126+ require .NoError (t , err )
127+ assert .Equal (t , types .ServerCloseComplete , msgType )
128+ }
129+
130+ func TestHandleClose_PortalRemovesFromCache (t * testing.T ) {
131+ t .Parallel ()
132+
133+ ctx := context .Background ()
134+ logger := slogt .New (t )
135+
136+ stmtCache := & DefaultStatementCache {}
137+ portalCache := & DefaultPortalCache {}
138+
139+ stmt := & PreparedStatement {
140+ fn : func (ctx context.Context , writer DataWriter , parameters []Parameter ) error {
141+ return writer .Complete ("OK" )
142+ },
143+ }
144+
145+ err := stmtCache .Set (ctx , "stmt1" , stmt )
146+ require .NoError (t , err )
147+ cached , err := stmtCache .Get (ctx , "stmt1" )
148+ require .NoError (t , err )
149+
150+ err = portalCache .Bind (ctx , "portal1" , cached , nil , nil )
151+ require .NoError (t , err )
152+
153+ session := & Session {
154+ Server : & Server {logger : logger },
155+ Statements : stmtCache ,
156+ Portals : portalCache ,
157+ }
158+
159+ outBuf := & bytes.Buffer {}
160+ writer := buffer .NewWriter (logger , outBuf )
161+
162+ err = session .handleClose (ctx , mock .NewCloseReader (t , logger , 'P' , "portal1" ), writer )
163+ require .NoError (t , err )
164+
165+ // Portal should be removed
166+ portal , err := portalCache .Get (ctx , "portal1" )
167+ require .NoError (t , err )
168+ assert .Nil (t , portal )
169+
170+ // Statement should still exist
171+ got , err := stmtCache .Get (ctx , "stmt1" )
60172 require .NoError (t , err )
173+ assert .NotNil (t , got )
61174
62175 responseReader := mock .NewReader (t , outBuf )
63176 msgType , _ , err := responseReader .ReadTypedMsg ()
0 commit comments