Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions auth/grpctransport/grpctransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"os"
"strconv"
"strings"

"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials"
Expand Down Expand Up @@ -368,7 +370,7 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
// Add tracing, but before the other options, so that clients can override the
// gRPC stats handler.
// This assumes that gRPC options are processed in order, left to right.
grpcOpts = addOpenTelemetryStatsHandler(grpcOpts, opts)
grpcOpts = addOpenTelemetryStatsHandler(grpcOpts, opts, transportCreds.Endpoint)
grpcOpts = append(grpcOpts, opts.GRPCDialOpts...)

return grpc.DialContext(ctx, transportCreds.Endpoint, grpcOpts...)
Expand Down Expand Up @@ -456,10 +458,14 @@ func (c *grpcCredentialsProvider) RequireTransportSecurity() bool {
return c.secure
}

func addOpenTelemetryStatsHandler(dialOpts []grpc.DialOption, opts *Options) []grpc.DialOption {
func addOpenTelemetryStatsHandler(dialOpts []grpc.DialOption, opts *Options, endpoint string) []grpc.DialOption {
if opts.DisableTelemetry {
return dialOpts
}
if gax.IsFeatureEnabled("METRICS") {
host, port := extractHostPort(endpoint)
dialOpts = append(dialOpts, grpc.WithChainUnaryInterceptor(openTelemetryUnaryClientInterceptor(host, port)))
}
if !gax.IsFeatureEnabled("TRACING") && !gax.IsFeatureEnabled("LOGGING") {
return append(dialOpts, grpc.WithStatsHandler(otelgrpc.NewClientHandler()))
}
Expand All @@ -475,15 +481,56 @@ func addOpenTelemetryStatsHandler(dialOpts []grpc.DialOption, opts *Options) []g
scopedLogger = opts.Logger.With(staticLogAttrs...)
}
}
otelOpts := []otelgrpc.Option{
otelgrpc.WithSpanAttributes(staticAttrs...),
var otelOpts []otelgrpc.Option
if gax.IsFeatureEnabled("TRACING") {
otelOpts = append(otelOpts, otelgrpc.WithSpanAttributes(staticAttrs...))
}
return append(dialOpts, grpc.WithStatsHandler(&otelHandler{
Handler: otelgrpc.NewClientHandler(otelOpts...),
logger: scopedLogger,
}))
}

// Extract the host and port from a target address
func extractHostPort(target string) (string, int) {
if idx := strings.Index(target, "://"); idx != -1 {
target = target[idx+3:]
// Ensure any leading slashes from the scheme suffix are stripped
for strings.HasPrefix(target, "/") {
target = target[1:]
}
}
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return target, 0
}
port, err := strconv.Atoi(portStr)
if err != nil {
return host, 0
}
return host, port
}

// openTelemetryUnaryClientInterceptor returns an interceptor that populates
// TransportTelemetryData with the server peer address.
func openTelemetryUnaryClientInterceptor(host string, port int) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
transportData := gax.ExtractTransportTelemetry(ctx)
if transportData != nil {
if host != "" {
transportData.SetServerAddress(host)
}
if port != 0 {
transportData.SetServerPort(port)
}
}

err := invoker(ctx, method, req, reply, cc, opts...)

return err
}
}

// otelHandler is a wrapper around the OpenTelemetry gRPC client handler that
// adds custom Google Cloud-specific attributes to spans and metrics.
type otelHandler struct {
Expand Down
112 changes: 99 additions & 13 deletions auth/grpctransport/grpctransport_otel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ func TestHandleRPC_ActionableErrors(t *testing.T) {
}
}

func TestDial_TracingAndLogging_Combinations(t *testing.T) {
func TestDial_Telemetry_Combinations(t *testing.T) {
// Ensure any lingering HTTP/2 connections are closed to avoid goroutine leaks.
defer http.DefaultTransport.(*http.Transport).CloseIdleConnections()

Expand All @@ -867,36 +867,82 @@ func TestDial_TracingAndLogging_Combinations(t *testing.T) {
name string
logging bool
tracing bool
metrics bool
wantLog bool
wantTracingAttrs bool
wantMetricsAttrs bool
}{
{
name: "both disabled",
name: "all disabled",
logging: false,
tracing: false,
metrics: false,
wantLog: false,
wantTracingAttrs: false,
wantMetricsAttrs: false,
},
{
name: "tracing enabled, logging disabled",
name: "tracing enabled",
logging: false,
tracing: true,
metrics: false,
wantLog: false,
wantTracingAttrs: true,
wantMetricsAttrs: false,
},
{
name: "tracing disabled, logging enabled",
name: "logging enabled",
logging: true,
tracing: false,
metrics: false,
wantLog: true,
wantTracingAttrs: false,
wantMetricsAttrs: false,
},
{
name: "metrics enabled",
logging: false,
tracing: false,
metrics: true,
wantLog: false,
wantTracingAttrs: false,
wantMetricsAttrs: true,
},
{
name: "tracing and logging enabled",
logging: true,
tracing: true,
metrics: false,
wantLog: true,
wantTracingAttrs: true,
wantMetricsAttrs: false,
},
{
name: "tracing and metrics enabled",
logging: false,
tracing: true,
metrics: true,
wantLog: false,
wantTracingAttrs: true,
wantMetricsAttrs: true,
},
{
name: "logging and metrics enabled",
logging: true,
tracing: false,
metrics: true,
wantLog: true,
wantTracingAttrs: false,
wantMetricsAttrs: true,
},
{
name: "both enabled",
name: "all enabled",
logging: true,
tracing: true,
metrics: true,
wantLog: true,
wantTracingAttrs: true,
wantMetricsAttrs: true,
},
}

Expand All @@ -916,6 +962,11 @@ func TestDial_TracingAndLogging_Combinations(t *testing.T) {
} else {
t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING", "false")
}
if tt.metrics {
t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_METRICS", "true")
} else {
t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_METRICS", "false")
}

l, err := net.Listen("tcp", "localhost:0")
if err != nil {
Expand Down Expand Up @@ -951,8 +1002,11 @@ func TestDial_TracingAndLogging_Combinations(t *testing.T) {
}
defer pool.Close()

data := &gax.TransportTelemetryData{}
ctx := gax.InjectTransportTelemetry(context.Background(), data)

client := echo.NewEchoerClient(pool)
_, _ = client.Echo(context.Background(), &echo.EchoRequest{Message: "hello"})
_, _ = client.Echo(ctx, &echo.EchoRequest{Message: "hello"})

logOutput := logBuf.String()
hasLog := strings.TrimSpace(logOutput) != ""
Expand All @@ -961,16 +1015,19 @@ func TestDial_TracingAndLogging_Combinations(t *testing.T) {
t.Errorf("got log: %v, want: %v\noutput: %s", hasLog, tt.wantLog, logOutput)
}

spans := exporter.GetSpans()
if len(spans) != 1 {
t.Fatalf("len(spans) = %d, want 1", len(spans))
hasMetricsAttrs := data.ServerAddress() != ""
if hasMetricsAttrs != tt.wantMetricsAttrs {
t.Errorf("got metrics attrs: %v, want: %v", hasMetricsAttrs, tt.wantMetricsAttrs)
}

spans := exporter.GetSpans()
hasTracingAttrs := false
for _, attr := range spans[0].Attributes {
if attr.Key == "gcp.client.version" && attr.Value.AsString() == "1.2.3" {
hasTracingAttrs = true
break
for _, span := range spans {
for _, attr := range span.Attributes {
if attr.Key == "gcp.client.version" && attr.Value.AsString() == "1.2.3" {
hasTracingAttrs = true
break
}
}
}

Expand All @@ -994,3 +1051,32 @@ func (m *mockStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo)
}

func (m *mockStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {}

func TestExtractHostPort(t *testing.T) {
tests := []struct {
target string
wantHost string
wantPort int
}{
{"localhost:8080", "localhost", 8080},
{"[::1]:443", "::1", 443},
{"google.com", "google.com", 0},
{"dns:///localhost:8080", "localhost", 8080},
{"dns:///google.com:443", "google.com", 443},
{"xds:///my-service:80", "my-service", 80},
{"dns:///[::1]:8080", "::1", 8080},
{"google.com:foo", "google.com", 0},
}

for _, tt := range tests {
t.Run(tt.target, func(t *testing.T) {
gotHost, gotPort := extractHostPort(tt.target)
if gotHost != tt.wantHost {
t.Errorf("extractHostPort(%q) host = %q, want %q", tt.target, gotHost, tt.wantHost)
}
if gotPort != tt.wantPort {
t.Errorf("extractHostPort(%q) port = %v, want %v", tt.target, gotPort, tt.wantPort)
}
})
}
}
Loading
Loading