Skip to content

Commit 876091c

Browse files
committed
protocol negotiation test
1 parent d535afe commit 876091c

File tree

3 files changed

+213
-1
lines changed

3 files changed

+213
-1
lines changed

conn_test.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,9 @@ type newTestServerOpts struct {
10541054
addr string
10551055
protocol uint8
10561056
recvHook func(*framer)
1057+
1058+
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
1059+
dontFailOnProtocolMismatch bool
10571060
}
10581061

10591062
func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer {
@@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS
10781081
cancel: cancel,
10791082

10801083
onRecv: nts.recvHook,
1084+
1085+
customRequestHandler: nts.customRequestHandler,
1086+
dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch,
10811087
}
10821088

10831089
go srv.closeWatch()
@@ -1142,6 +1148,10 @@ type TestServer struct {
11421148

11431149
// onRecv is a hook point for tests, called in receive loop.
11441150
onRecv func(*framer)
1151+
1152+
// customRequestHandler allows overriding the default request handling for testing purposes.
1153+
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
1154+
dontFailOnProtocolMismatch bool
11451155
}
11461156

11471157
func (srv *TestServer) closeWatch() {
@@ -1221,7 +1231,22 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
12211231
srv.errorLocked("process frame with a nil header")
12221232
return
12231233
}
1224-
respFrame := newFramer(nil, reqFrame.proto, GlobalTypes)
1234+
respFrame := newFramer(nil, byte(head.version), GlobalTypes)
1235+
1236+
if srv.customRequestHandler != nil {
1237+
if err := srv.customRequestHandler(srv, reqFrame, respFrame); err != nil {
1238+
srv.errorLocked(err)
1239+
return
1240+
}
1241+
respFrame.buf[0] |= 0x80
1242+
if err := respFrame.finish(); err != nil {
1243+
srv.errorLocked(err)
1244+
}
1245+
if err := respFrame.writeTo(conn); err != nil {
1246+
srv.errorLocked(err)
1247+
}
1248+
return
1249+
}
12251250

12261251
switch head.op {
12271252
case opStartup:

frame.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,6 +2370,14 @@ func (f *framer) writeStringMap(m map[string]string) {
23702370
}
23712371
}
23722372

2373+
func (f *framer) writeStringMultiMap(m map[string][]string) {
2374+
f.writeShort(uint16(len(m)))
2375+
for k, v := range m {
2376+
f.writeString(k)
2377+
f.writeStringList(v)
2378+
}
2379+
}
2380+
23732381
func (f *framer) writeBytesMap(m map[string][]byte) {
23742382
f.writeShort(uint16(len(m)))
23752383
for k, v := range m {

protocol_negotiation_test.go

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
//go:build all || unit
2+
// +build all unit
3+
4+
package gocql
5+
6+
import (
7+
"context"
8+
"encoding/binary"
9+
"fmt"
10+
"slices"
11+
"testing"
12+
"time"
13+
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
type requestHandlerForProtocolNegotiationTest struct {
18+
supportedProtocolVersions []protoVersion
19+
supportedBetaProtocols []protoVersion
20+
}
21+
22+
func (r *requestHandlerForProtocolNegotiationTest) supportsBetaProtocol(version protoVersion) bool {
23+
return slices.Contains(r.supportedBetaProtocols, version)
24+
}
25+
26+
func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version protoVersion) bool {
27+
return slices.Contains(r.supportedProtocolVersions, version)
28+
}
29+
30+
func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header *frameHeader) bool {
31+
return header.flags&flagBetaProtocol == flagBetaProtocol
32+
}
33+
34+
func (r *requestHandlerForProtocolNegotiationTest) createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string {
35+
return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta), but USE_BETA flag is unset", version, version)
36+
}
37+
38+
func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer, reqFrame, respFrame *framer) error {
39+
// If a client uses beta protocol, but the USE_BETA flag is not set, we respond with an error
40+
if r.supportsBetaProtocol(reqFrame.header.version) && !r.hasBetaFlag(reqFrame.header) {
41+
respFrame.writeHeader(0, opError, reqFrame.header.stream)
42+
respFrame.writeInt(ErrCodeProtocol)
43+
respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version))
44+
return nil
45+
}
46+
47+
// if a client uses an unsupported protocol version, we respond with an error
48+
if !r.supportsProtocol(reqFrame.header.version) {
49+
respFrame.writeHeader(0, opError, reqFrame.header.stream)
50+
respFrame.writeInt(ErrCodeProtocol)
51+
respFrame.writeString(fmt.Sprintf("NEGOTITATION TEST: Unsupported protocol version %d", reqFrame.header.version))
52+
return nil
53+
}
54+
55+
stream := reqFrame.header.stream
56+
57+
switch reqFrame.header.op {
58+
case opStartup, opRegister:
59+
respFrame.writeHeader(0, opReady, stream)
60+
case opOptions:
61+
respFrame.writeHeader(0, opSupported, stream)
62+
var supportedVersionsWithDesc []string
63+
for _, supportedVersion := range r.supportedProtocolVersions {
64+
supportedVersionsWithDesc = append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d", supportedVersion, supportedVersion))
65+
}
66+
for _, betaProtocol := range r.supportedBetaProtocols {
67+
supportedVersionsWithDesc = append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d-beta", betaProtocol, betaProtocol))
68+
}
69+
supported := map[string][]string{
70+
"PROTOCOL_VERSIONS": supportedVersionsWithDesc,
71+
}
72+
respFrame.writeStringMultiMap(supported)
73+
case opQuery:
74+
respFrame.writeHeader(0, opResult, stream)
75+
respFrame.writeInt(resultKindRows)
76+
respFrame.writeInt(int32(flagGlobalTableSpec))
77+
respFrame.writeInt(1)
78+
respFrame.writeString("system")
79+
respFrame.writeString("local")
80+
respFrame.writeString("rack")
81+
respFrame.writeShort(uint16(TypeVarchar))
82+
respFrame.writeInt(1)
83+
respFrame.writeInt(int32(len("rack-1")))
84+
respFrame.writeString("rack-1")
85+
case opPrepare:
86+
// This doesn't really make any sense, but it's enough to test the protocol negotiation
87+
respFrame.writeHeader(0, opResult, stream)
88+
respFrame.writeInt(resultKindPrepared)
89+
// <id>
90+
respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 111))
91+
if respFrame.proto >= protoVersion5 {
92+
respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 222))
93+
}
94+
// <metadata>
95+
respFrame.writeInt(0) // <flags>
96+
respFrame.writeInt(0) // <columns_count>
97+
if reqFrame.header.version >= protoVersion4 {
98+
respFrame.writeInt(0) // <pk_count>
99+
}
100+
// <result_metadata>
101+
respFrame.writeInt(int32(flagGlobalTableSpec)) // <flags>
102+
respFrame.writeInt(1) // <columns_count>
103+
// <global_table_spec>
104+
respFrame.writeString("system")
105+
respFrame.writeString("keyspaces")
106+
// <col_spec_0>
107+
respFrame.writeString("col0") // <name>
108+
respFrame.writeShort(uint16(TypeBoolean)) // <type>
109+
case opExecute:
110+
// This doesn't really make any sense, but it's enough to test the protocol negotiation
111+
respFrame.writeHeader(0, opResult, stream)
112+
respFrame.writeInt(resultKindRows)
113+
// <metadata>
114+
respFrame.writeInt(0) // <flags>
115+
respFrame.writeInt(0) // <columns_count>
116+
// <rows_count>
117+
respFrame.writeInt(0)
118+
}
119+
120+
return nil
121+
}
122+
123+
func TestProtocolNegotiation(t *testing.T) {
124+
testCases := []struct {
125+
name string
126+
supportedVersions []protoVersion
127+
supportedBetaVersions []protoVersion
128+
expectedVersion protoVersion
129+
}{
130+
{
131+
name: "all supported versions",
132+
supportedVersions: []protoVersion{protoVersion3, protoVersion4, protoVersion5},
133+
expectedVersion: protoVersion5,
134+
},
135+
{
136+
name: "v5-beta is supported",
137+
supportedVersions: []protoVersion{protoVersion3, protoVersion4},
138+
supportedBetaVersions: []protoVersion{protoVersion5},
139+
expectedVersion: protoVersion4,
140+
},
141+
{
142+
name: "v5 is unsupported",
143+
supportedVersions: []protoVersion{protoVersion3, protoVersion4},
144+
expectedVersion: protoVersion4,
145+
},
146+
}
147+
148+
for _, tc := range testCases {
149+
t.Run(tc.name, func(t *testing.T) {
150+
handler := &requestHandlerForProtocolNegotiationTest{
151+
supportedProtocolVersions: tc.supportedVersions,
152+
supportedBetaProtocols: tc.supportedBetaVersions,
153+
}
154+
155+
srv := newTestServerOpts{
156+
addr: "127.0.0.1:0",
157+
protocol: 5,
158+
customRequestHandler: handler.handle,
159+
dontFailOnProtocolMismatch: true,
160+
}.newServer(t, context.Background())
161+
162+
go srv.serve()
163+
defer srv.Stop()
164+
165+
cluster := NewCluster(srv.Address)
166+
cluster.Compressor = nil
167+
cluster.ProtoVersion = 0
168+
cluster.Logger = NewLogger(LogLevelDebug)
169+
cluster.ConnectTimeout = time.Hour
170+
cluster.Timeout = time.Hour
171+
cluster.DisableInitialHostLookup = true
172+
173+
s, err := cluster.CreateSession()
174+
require.NoError(t, err)
175+
176+
require.Equal(t, tc.expectedVersion, protoVersion(s.cfg.ProtoVersion))
177+
})
178+
}
179+
}

0 commit comments

Comments
 (0)