Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,105 @@ func TestClientDeadlineHandling(t *testing.T) {
})
}

// TestClientBidiOverHTTP1 tests bidi streaming over real TCP connections
// (not in-memory pipes) to verify the behavior with real kernel socket buffers.
// The server is a plain http.Server without TLS or HTTP/2 configuration, so
// all connections are HTTP/1.1.
func TestClientBidiOverHTTP1(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{checkMetadata: true}))

listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
// Plain http.Server without Protocols or TLS: HTTP/1.1 only.
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
go server.Serve(listener) //nolint:errcheck
t.Cleanup(func() { server.Close() })

addr := listener.Addr().String()
transport := &http.Transport{
ForceAttemptHTTP2: false,
DisableKeepAlives: true,
}
httpClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(httpClient, "http://"+addr)

t.Run("full_duplex", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()

stream := client.CumSum(ctx)
for _, el := range expectedHeaderValues {
stream.RequestHeader().Add(clientHeader, el)
}
// Interleave send and receive -- true full-duplex.
send := []int64{3, 5, 1}
expect := []int64{3, 8, 9}
for i, num := range send {
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: num}))
resp, err := stream.Receive()
assert.Nil(t, err)
assert.NotNil(t, resp)
assert.Equal(t, resp.GetSum(), expect[i])
}
assert.Nil(t, stream.CloseRequest())
_, err := stream.Receive()
assert.ErrorIs(t, err, io.EOF)
assert.Nil(t, stream.CloseResponse())
})

t.Run("half_duplex", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()

stream := client.CumSum(ctx)
for _, el := range expectedHeaderValues {
stream.RequestHeader().Add(clientHeader, el)
}
// Send all, close request, then receive all.
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 10}))
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 20}))
assert.Nil(t, stream.CloseRequest())

resp1, err := stream.Receive()
assert.Nil(t, err)
assert.Equal(t, resp1.GetSum(), int64(10))

resp2, err := stream.Receive()
assert.Nil(t, err)
assert.Equal(t, resp2.GetSum(), int64(30))

_, err = stream.Receive()
assert.ErrorIs(t, err, io.EOF)
assert.Nil(t, stream.CloseResponse())
})

t.Run("cumsum_error", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()

stream := client.CumSum(ctx)
// Don't set required headers -- server should return InvalidArgument.
if err := stream.Send(&pingv1.CumSumRequest{Number: 42}); err != nil {
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
}
_, err := stream.Receive()
assert.NotNil(t, err)
t.Log("cumsum_error", err)
assert.Equal(t, connect.CodeOf(err), connect.CodeInvalidArgument)
assert.True(t, connect.IsWireError(err))
})
}

func testClientDeadlineBruteForceLoop(
t *testing.T,
duration time.Duration,
Expand Down
178 changes: 82 additions & 96 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func TestCallInfo(t *testing.T) {
assertResponseHeadersAndTrailers(t, wrapper)
})
t.Run("bidi_stream", func(t *testing.T) {
testBidiStreamGenerics(t, client, true)
testBidiStreamGenerics(t, client)
})
t.Run("bidi_stream_simple_server", func(t *testing.T) {
mux := http.NewServeMux()
Expand All @@ -335,7 +335,7 @@ func TestCallInfo(t *testing.T) {
))
server := memhttptest.NewServer(t, mux)
genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL())
testBidiStreamGenerics(t, genericsClient, true)
testBidiStreamGenerics(t, genericsClient)
})
t.Run("bidi_stream_no_callinfo", func(t *testing.T) {
send := []int64{3, 5, 1}
Expand Down Expand Up @@ -503,16 +503,12 @@ func TestServer(t *testing.T) {
assert.Nil(t, stream.Close())
})
}
testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper
testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper
t.Run("cumsum", func(t *testing.T) {
testBidiStreamGenerics(t, client, expectSuccess)
testBidiStreamGenerics(t, client)
})
t.Run("cumsum_error", func(t *testing.T) {
stream := client.CumSum(t.Context())
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
return
}
if err := stream.Send(&pingv1.CumSumRequest{Number: 42}); err != nil {
assert.ErrorIs(t, err, io.EOF)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
Expand All @@ -528,10 +524,6 @@ func TestServer(t *testing.T) {
for _, el := range expectedHeaderValues {
stream.RequestHeader().Add(clientHeader, el)
}
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
return
}
// Deliberately closing with calling Send to test the behavior of Receive.
// This test case is based on the grpc interop tests.
assert.Nil(t, stream.CloseRequest())
Expand All @@ -547,11 +539,6 @@ func TestServer(t *testing.T) {
for _, el := range expectedHeaderValues {
stream.RequestHeader().Add(clientHeader, el)
}
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
cancel()
return
}
var got []int64
expect := []int64{42}
if err := stream.Send(&pingv1.CumSumRequest{Number: 42}); err != nil {
Expand All @@ -571,11 +558,6 @@ func TestServer(t *testing.T) {
t.Run("cumsum_cancel_before_send", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
stream := client.CumSum(ctx)
if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
cancel()
return
}
for _, el := range expectedHeaderValues {
stream.RequestHeader().Add(clientHeader, el)
}
Expand Down Expand Up @@ -646,14 +628,14 @@ func TestServer(t *testing.T) {
assertIsHTTPMiddlewareError(t, stream.Err())
})
}
testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper
testMatrix := func(t *testing.T, client *http.Client, url string) { //nolint:thelper
run := func(t *testing.T, opts ...connect.ClientOption) {
t.Helper()
client := pingv1connect.NewPingServiceClient(client, url, opts...)
testPing(t, client)
testSum(t, client)
testCountUp(t, client)
testCumSum(t, client, bidi)
testCumSum(t, client)
testErrors(t, client)
}
t.Run("connect", func(t *testing.T) {
Expand Down Expand Up @@ -759,13 +741,13 @@ func TestServer(t *testing.T) {
t.Parallel()
server := memhttptest.NewServer(t, mux)
client := &http.Client{Transport: server.TransportHTTP1()}
testMatrix(t, client, server.URL(), false /* bidi */)
testMatrix(t, client, server.URL())
})
t.Run("http2", func(t *testing.T) {
t.Parallel()
server := memhttptest.NewServer(t, mux)
client := server.Client()
testMatrix(t, client, server.URL(), true /* bidi */)
testMatrix(t, client, server.URL())
})
}

Expand Down Expand Up @@ -1220,35 +1202,6 @@ func TestUnavailableIfHostInvalid(t *testing.T) {
assert.Equal(t, connect.CodeOf(err), connect.CodeUnavailable)
}

func TestBidiRequiresHTTP2(t *testing.T) {
t.Parallel()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/connect+proto")
_, err := io.WriteString(w, "hello world")
assert.Nil(t, err)
})
server := memhttptest.NewServer(t, handler)
client := pingv1connect.NewPingServiceClient(
&http.Client{Transport: server.TransportHTTP1()},
server.URL(),
)
stream := client.CumSum(t.Context())
// Stream creates an async request, can error on Send or Receive.
if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(t, err, io.EOF)
}
assert.Nil(t, stream.CloseRequest())
_, err := stream.Receive()
assert.NotNil(t, err)
var connectErr *connect.Error
assert.True(t, errors.As(err, &connectErr))
assert.Equal(t, connectErr.Code(), connect.CodeUnimplemented)
assert.True(
t,
strings.HasSuffix(connectErr.Message(), ": bidi streams require at least HTTP/2"),
)
}

func TestCompressMinBytesClient(t *testing.T) {
t.Parallel()
assertContentType := func(tb testing.TB, text, expect string) {
Expand Down Expand Up @@ -2443,28 +2396,82 @@ func TestWebXUserAgent(t *testing.T) {

func TestBidiOverHTTP1(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
server := memhttptest.NewServer(t, mux)

// Clients expecting a full-duplex connection that end up with a simplex
// HTTP/1.1 connection shouldn't hang. Instead, the server should close the
// TCP connection.
client := pingv1connect.NewPingServiceClient(
&http.Client{Transport: server.TransportHTTP1()},
server.URL(),
)
stream := client.CumSum(t.Context())
// Stream creates an async request, can error on Send or Receive.
if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil {
t.Run("full_duplex_with_enable_full_duplex", func(t *testing.T) {
t.Parallel()
// Go's net/http supports EnableFullDuplex, so bidi works over HTTP/1.
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
server := memhttptest.NewServer(t, mux)
client := pingv1connect.NewPingServiceClient(
&http.Client{Transport: server.TransportHTTP1()},
server.URL(),
)
stream := client.CumSum(t.Context())
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 1}))

response, err := stream.Receive()
assert.Nil(t, err)
assert.NotNil(t, response)
assert.Equal(t, response.GetSum(), 1)

assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 2}))

response, err = stream.Receive()
assert.Nil(t, err)
assert.NotNil(t, response)
assert.Equal(t, response.GetSum(), 3)

assert.Nil(t, stream.CloseRequest())
_, err = stream.Receive()
assert.ErrorIs(t, err, io.EOF)
assert.Nil(t, stream.CloseResponse())
})

t.Run("rejected_without_enable_full_duplex", func(t *testing.T) {
t.Parallel()
// When the ResponseWriter does not support EnableFullDuplex,
// the server rejects the bidi request with 505.
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
// Wrap the handler to strip the ResponseController's Unwrap method,
// preventing EnableFullDuplex from succeeding.
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mux.ServeHTTP(&noFullDuplexResponseWriter{ResponseWriter: w}, r)
})
server := memhttptest.NewServer(t, handler)
client := pingv1connect.NewPingServiceClient(
&http.Client{Transport: server.TransportHTTP1()},
server.URL(),
)
stream := client.CumSum(t.Context())
if err := stream.Send(&pingv1.CumSumRequest{Number: 1}); err != nil {
assert.ErrorIs(t, err, io.EOF)
}
assert.Nil(t, stream.CloseRequest())
_, err := stream.Receive()
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
assert.True(
t,
strings.Contains(err.Error(), "HTTP status 505"),
assert.Sprintf("expected 505 error, got: %v", err),
)
assert.Nil(t, stream.CloseResponse())
})
}

// noFullDuplexResponseWriter wraps an http.ResponseWriter without exposing
// the Unwrap method. This prevents http.ResponseController.EnableFullDuplex
// from reaching the underlying implementation.
type noFullDuplexResponseWriter struct {
http.ResponseWriter
}

func (w *noFullDuplexResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
_, err := stream.Receive()
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
assert.Equal(t, err.Error(), "unknown: HTTP status 505 HTTP Version Not Supported")
assert.Nil(t, stream.CloseRequest())
assert.Nil(t, stream.CloseResponse())
}

func TestHandlerReturnsNilResponse(t *testing.T) {
Expand Down Expand Up @@ -3915,23 +3922,6 @@ func handleCumSum(
}
}

func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) {
tb.Helper()
if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(tb, err, io.EOF)
assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown)
}
assert.Nil(tb, stream.CloseRequest())
_, err := stream.Receive()
assert.NotNil(tb, err) // should be 505
assert.True(
tb,
strings.Contains(err.Error(), "HTTP status 505"),
assert.Sprintf("expected 505, got %v", err),
)
assert.Nil(tb, stream.CloseResponse())
}

func testUnaryGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper
num := int64(42)
request := connect.NewRequest(&pingv1.PingRequest{Number: num})
Expand Down Expand Up @@ -4198,7 +4188,7 @@ func testBidiStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceCl
assertResponseHeadersAndTrailers(t, stream)
}

func testBidiStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper
func testBidiStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper
send := []int64{3, 5, 1}
expect := []int64{3, 8, 9}
var got []int64
Expand All @@ -4209,10 +4199,6 @@ func testBidiStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient
stream := client.CumSum(ctx)
stream.RequestHeader().Add(clientHeader, "bar")

if !expectSuccess { // server doesn't support HTTP/2
failNoHTTP2(t, stream)
return
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
Expand Down
Loading
Loading