Skip to content

Commit

Permalink
Add proxy dialer support
Browse files Browse the repository at this point in the history
  • Loading branch information
nmische committed Sep 1, 2024
1 parent 46c38b3 commit 8f9ab68
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 1 deletion.
1 change: 1 addition & 0 deletions cmd/grpcurl/grpcurl.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ func main() {
grpcurlUA = *userAgent + " " + grpcurlUA
}
opts = append(opts, grpc.WithUserAgent(grpcurlUA))
grpcurl.GrpcurlUA = grpcurlUA

blockingDialTiming := dialTiming.Child("BlockingDial")
defer blockingDialTiming.Done()
Expand Down
4 changes: 3 additions & 1 deletion grpcurl.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

var GrpcurlUA string

// ListServices uses the given descriptor source to return a sorted list of fully-qualified
// service names.
func ListServices(source DescriptorSource) ([]string, error) {
Expand Down Expand Up @@ -653,7 +655,7 @@ func BlockingDial(ctx context.Context, network, address string, creds credential
// handshake). And that would mean that the library would send the
// wrong ":scheme" metaheader to servers: it would send "http" instead
// of "https" because it is unaware that TLS is actually in use.
conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
conn, err := proxyDial(ctx, address, GrpcurlUA)
if err != nil {
writeResult(err)
}
Expand Down
156 changes: 156 additions & 0 deletions proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// NOTE: This source file contains the internal grpc-go proxy implementation
// found in google.golang.org/grpc/internal/transport, with minor
// modifications for use in grpcurl. Below is the original license:

/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcurl

import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
// "google.golang.org/grpc/internal"
)

const proxyAuthHeaderKey = "Proxy-Authorization"

var (
// The following variable will be overwritten in the tests.
httpProxyFromEnvironment = http.ProxyFromEnvironment
)

func mapAddress(address string) (*url.URL, error) {
req := &http.Request{
URL: &url.URL{
Scheme: "https",
Host: address,
},
}
url, err := httpProxyFromEnvironment(req)
if err != nil {
return nil, err
}
return url, nil
}

// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
// It's possible that this reader reads more than what's need for the response and stores
// those bytes in the buffer.
// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the
// bytes in the buffer.
type bufConn struct {
net.Conn
r io.Reader
}

func (c *bufConn) Read(b []byte) (int, error) {
return c.r.Read(b)
}

func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}

func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) {
defer func() {
if err != nil {
conn.Close()
}
}()

req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: backendAddr},
Header: map[string][]string{"User-Agent": {grpcUA}},
}
if t := proxyURL.User; t != nil {
u := t.Username()
p, _ := t.Password()
req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
}

if err := sendHTTPRequest(ctx, req, conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
}

r := bufio.NewReader(conn)
resp, err := http.ReadResponse(r, req)
if err != nil {
return nil, fmt.Errorf("reading server HTTP response: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
dump, err := httputil.DumpResponse(resp, true)
if err != nil {
return nil, fmt.Errorf("failed to do connect handshake, status code: %s", resp.Status)
}
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
}
// The buffer could contain extra bytes from the target server, so we can't
// discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
}

// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
// is necessary, dials, does the HTTP CONNECT handshake, and returns the
// connection.
func proxyDial(ctx context.Context, addr string, grpcUA string) (net.Conn, error) {
newAddr := addr
proxyURL, err := mapAddress(addr)
if err != nil {
return nil, err
}
if proxyURL != nil {
newAddr = proxyURL.Host
}

// NOTE: Use net.Dialer to avoid dependency on grpc-go's internal package
// conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", newAddr)

conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", newAddr)
if err != nil {
return nil, err
}
if proxyURL == nil {
// proxy is disabled if proxyURL is nil.
return conn, err
}
return doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA)
}

func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
req = req.WithContext(ctx)
if err := req.Write(conn); err != nil {
return fmt.Errorf("failed to write the HTTP request: %v", err)
}
return nil
}

0 comments on commit 8f9ab68

Please sign in to comment.