@@ -1172,9 +1172,26 @@ func (srv *TestServer) serve() {
11721172 }
11731173
11741174 go func (conn net.Conn ) {
1175+ var startupCompleted bool
1176+ var useProtoV5 bool
1177+
11751178 defer conn .Close ()
11761179 for ! srv .isClosed () {
1177- framer , err := srv .readFrame (conn )
1180+ var reader io.Reader = conn
1181+
1182+ if useProtoV5 && startupCompleted {
1183+ frame , _ , err := readUncompressedSegment (conn )
1184+ if err != nil {
1185+ if errors .Is (err , io .EOF ) {
1186+ return
1187+ }
1188+ srv .errorLocked (err )
1189+ return
1190+ }
1191+ reader = bytes .NewReader (frame )
1192+ }
1193+
1194+ framer , err := srv .readFrame (reader )
11781195 if err != nil {
11791196 if err == io .EOF {
11801197 return
@@ -1187,7 +1204,7 @@ func (srv *TestServer) serve() {
11871204 srv .onRecv (framer )
11881205 }
11891206
1190- go srv .process (conn , framer )
1207+ srv .process (conn , framer , & useProtoV5 , & startupCompleted )
11911208 }
11921209 }(conn )
11931210 }
@@ -1225,7 +1242,7 @@ func (srv *TestServer) errorLocked(err interface{}) {
12251242 srv .t .Error (err )
12261243}
12271244
1228- func (srv * TestServer ) process (conn net.Conn , reqFrame * framer ) {
1245+ func (srv * TestServer ) process (conn net.Conn , reqFrame * framer , useProtoV5 , startupCompleted * bool ) {
12291246 head := reqFrame .header
12301247 if head == nil {
12311248 srv .errorLocked ("process frame with a nil header" )
@@ -1238,14 +1255,8 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
12381255 srv .errorLocked (err )
12391256 return
12401257 }
1241- respFrame .buf [0 ] |= 0x80
1242- if err := respFrame .finish (); err != nil {
1243- srv .errorLocked (err )
1244- }
1245- if err := respFrame .writeTo (conn ); err != nil {
1246- srv .errorLocked (err )
1247- }
1248- return
1258+ // Dont like this but...
1259+ goto finish
12491260 }
12501261
12511262 switch head .op {
@@ -1437,34 +1448,54 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
14371448 respFrame .writeString ("not supported" )
14381449 }
14391450
1440- respFrame .buf [0 ] = srv .protocol | 0x80
1451+ finish:
1452+
1453+ respFrame .buf [0 ] |= 0x80
14411454
14421455 if err := respFrame .finish (); err != nil {
14431456 srv .errorLocked (err )
14441457 }
14451458
1446- if err := respFrame .writeTo (conn ); err != nil {
1447- srv .errorLocked (err )
1459+ if * useProtoV5 && * startupCompleted {
1460+ segment , err := newUncompressedSegment (respFrame .buf , true )
1461+ if err == nil {
1462+ _ , err = conn .Write (segment )
1463+ }
1464+ if err != nil {
1465+ srv .errorLocked (err )
1466+ return
1467+ }
1468+ } else {
1469+ if err := respFrame .writeTo (conn ); err != nil {
1470+ srv .errorLocked (err )
1471+ }
1472+
1473+ if reqFrame .header .op == opStartup {
1474+ * startupCompleted = true
1475+ if head .version == protoVersion5 {
1476+ * useProtoV5 = true
1477+ }
1478+ }
14481479 }
14491480}
14501481
1451- func (srv * TestServer ) readFrame (conn net. Conn ) (* framer , error ) {
1482+ func (srv * TestServer ) readFrame (reader io. Reader ) (* framer , error ) {
14521483 buf := make ([]byte , srv .headerSize )
1453- head , err := readHeader (conn , buf )
1484+ head , err := readHeader (reader , buf )
14541485 if err != nil {
14551486 return nil , err
14561487 }
14571488 framer := newFramer (nil , srv .protocol , GlobalTypes )
14581489
1459- err = framer .readFrame (conn , & head )
1490+ err = framer .readFrame (reader , & head )
14601491 if err != nil {
14611492 return nil , err
14621493 }
14631494
14641495 // should be a request frame
14651496 if head .version .response () {
14661497 return nil , fmt .Errorf ("expected to read a request frame got version: %v" , head .version )
1467- } else if head .version .version () != srv .protocol {
1498+ } else if ! srv . dontFailOnProtocolMismatch && head .version .version () != srv .protocol {
14681499 return nil , fmt .Errorf ("expected to read protocol version 0x%x got 0x%x" , srv .protocol , head .version .version ())
14691500 }
14701501
0 commit comments