diff --git a/go.mod b/go.mod index 9627c2c0..9c4ca20a 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( github.com/itchyny/gojq v0.12.19 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 + github.com/spf13/pflag v1.0.9 github.com/stretchr/testify v1.11.1 github.com/tetratelabs/wazero v1.11.0 go.opentelemetry.io/otel v1.43.0 @@ -34,7 +35,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.1.3 // indirect github.com/segmentio/encoding v0.5.4 // indirect - github.com/spf13/pflag v1.0.9 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect diff --git a/internal/cmd/flags_tracing.go b/internal/cmd/flags_tracing.go index bd90b8db..7caf3ccb 100644 --- a/internal/cmd/flags_tracing.go +++ b/internal/cmd/flags_tracing.go @@ -3,8 +3,6 @@ package cmd // Tracing-related flags for OpenTelemetry OTLP trace export. import ( - "github.com/github/gh-aw-mcpg/internal/config" - "github.com/github/gh-aw-mcpg/internal/envutil" "github.com/spf13/cobra" ) @@ -17,11 +15,9 @@ var ( func init() { RegisterFlag(func(cmd *cobra.Command) { - cmd.Flags().StringVar(&otlpEndpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""), - "OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Defaults from OTEL_EXPORTER_OTLP_ENDPOINT when set. Tracing is disabled when empty.") - cmd.Flags().StringVar(&otlpServiceName, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName), - "Service name reported in traces. Defaults from OTEL_SERVICE_NAME when set.") - cmd.Flags().Float64Var(&otlpSampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate, + registerTracingFlags(cmd.Flags(), &otlpEndpoint, &otlpServiceName, &otlpSampleRate, + "OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Defaults from OTEL_EXPORTER_OTLP_ENDPOINT when set. Tracing is disabled when empty.", + "Service name reported in traces. Defaults from OTEL_SERVICE_NAME when set.", "Fraction of traces to sample and export (0.0–1.0). Default 1.0 samples everything.") }) } diff --git a/internal/cmd/proxy.go b/internal/cmd/proxy.go index 6c836ac1..016a38c6 100644 --- a/internal/cmd/proxy.go +++ b/internal/cmd/proxy.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "crypto/tls" "fmt" "log" @@ -12,14 +11,12 @@ import ( "path/filepath" "strings" "syscall" - "time" "github.com/github/gh-aw-mcpg/internal/config" "github.com/github/gh-aw-mcpg/internal/difc" "github.com/github/gh-aw-mcpg/internal/envutil" "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/proxy" - "github.com/github/gh-aw-mcpg/internal/tracing" "github.com/spf13/cobra" ) @@ -121,11 +118,9 @@ Local usage: cmd.Flags().StringVar(&proxyTLSDir, "tls-dir", "", "Directory for TLS certificates (default: /proxy-tls)") cmd.Flags().StringSliceVar(&proxyTrustedBots, "trusted-bots", nil, "Additional trusted bot usernames (comma-separated, extends built-in list)") cmd.Flags().StringSliceVar(&proxyTrustedUsers, "trusted-users", nil, "User logins that receive approved integrity (comma-separated)") - cmd.Flags().StringVar(&proxyOTLPEndpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""), - "OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Tracing is disabled when empty.") - cmd.Flags().StringVar(&proxyOTLPService, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName), - "Service name reported in traces.") - cmd.Flags().Float64Var(&proxyOTLPSampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate, + registerTracingFlags(cmd.Flags(), &proxyOTLPEndpoint, &proxyOTLPService, &proxyOTLPSampleRate, + "OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Tracing is disabled when empty.", + "Service name reported in traces.", "Fraction of traces to sample and export (0.0–1.0).") // Only require --guard-wasm when no baked-in guard is available @@ -165,17 +160,18 @@ func runProxy(cmd *cobra.Command, args []string) error { SampleRate: &proxyOTLPSampleRate, } } - tracingProvider, err := tracing.InitProvider(ctx, tracingCfg) - if err != nil { - log.Printf("Warning: failed to initialize tracing provider: %v", err) - tracingProvider, _ = tracing.InitProvider(ctx, nil) - } + tracingProvider := initTracingProviderWithFallback( + ctx, + tracingCfg, + "failed to initialize tracing provider: %v", + func(format string, args ...any) { + log.Printf("Warning: "+format, args...) + }, + ) defer func() { - shutdownCtx, cancelTracing := context.WithTimeout(context.Background(), 5*time.Second) - defer cancelTracing() - if err := tracingProvider.Shutdown(shutdownCtx); err != nil { - log.Printf("Warning: tracing provider shutdown error: %v", err) - } + shutdownTracingProviderWithTimeout(tracingProvider, func(format string, args ...any) { + log.Printf("Warning: "+format, args...) + }) }() if tracingCfg != nil { log.Printf("OpenTelemetry tracing enabled for proxy: endpoint=%s, service=%s", proxyOTLPEndpoint, proxyOTLPService) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 96557761..d1c95630 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -311,18 +311,18 @@ func run(cmd *cobra.Command, args []string) error { if cfg.Gateway != nil { tracingCfg = cfg.Gateway.Tracing } - tracingProvider, err := tracing.InitProvider(ctx, tracingCfg) - if err != nil { - logger.StartupWarn("Failed to initialize tracing provider: %v", err) - // Non-fatal: continue without tracing - tracingProvider, _ = tracing.InitProvider(ctx, nil) - } + tracingProvider := initTracingProviderWithFallback( + ctx, + tracingCfg, + "Failed to initialize tracing provider: %v", + func(format string, args ...any) { + logger.StartupWarn(format, args...) + }, + ) defer func() { - shutdownCtxTracing, cancelTracing := context.WithTimeout(context.Background(), 5*time.Second) - defer cancelTracing() - if err := tracingProvider.Shutdown(shutdownCtxTracing); err != nil { - log.Printf("Warning: tracing provider shutdown error: %v", err) - } + shutdownTracingProviderWithTimeout(tracingProvider, func(format string, args ...any) { + log.Printf("Warning: "+format, args...) + }) }() // Apply W3C parent context from configured traceId/spanId (spec §4.1.3.6). diff --git a/internal/cmd/tracing_helpers.go b/internal/cmd/tracing_helpers.go new file mode 100644 index 00000000..1c046635 --- /dev/null +++ b/internal/cmd/tracing_helpers.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "time" + + "github.com/github/gh-aw-mcpg/internal/config" + "github.com/github/gh-aw-mcpg/internal/envutil" + "github.com/github/gh-aw-mcpg/internal/tracing" + "github.com/spf13/pflag" +) + +func registerTracingFlags(flags *pflag.FlagSet, endpoint *string, serviceName *string, sampleRate *float64, endpointUsage string, serviceUsage string, sampleUsage string) { + flags.StringVar(endpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""), + endpointUsage) + flags.StringVar(serviceName, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName), + serviceUsage) + flags.Float64Var(sampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate, + sampleUsage) +} + +func initTracingProviderWithFallback( + ctx context.Context, + tracingCfg *config.TracingConfig, + initWarningFormat string, + warnf func(format string, args ...any), +) *tracing.Provider { + tracingProvider, err := tracing.InitProvider(ctx, tracingCfg) + if err != nil { + warnf(initWarningFormat, err) + tracingProvider, _ = tracing.InitProvider(ctx, nil) + } + + return tracingProvider +} + +func shutdownTracingProviderWithTimeout(tracingProvider *tracing.Provider, warnf func(format string, args ...any)) { + shutdownCtxTracing, cancelTracing := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelTracing() + + if err := tracingProvider.Shutdown(shutdownCtxTracing); err != nil { + warnf("tracing provider shutdown error: %v", err) + } +} diff --git a/internal/cmd/tracing_helpers_test.go b/internal/cmd/tracing_helpers_test.go new file mode 100644 index 00000000..da4d9bef --- /dev/null +++ b/internal/cmd/tracing_helpers_test.go @@ -0,0 +1,54 @@ +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/github/gh-aw-mcpg/internal/config" +) + +func TestRegisterTracingFlags_DefaultsFromEnv(t *testing.T) { + t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://collector:4318") + t.Setenv("OTEL_SERVICE_NAME", "test-service") + + cmd := &cobra.Command{Use: "test"} + + var endpoint string + var service string + var sampleRate float64 + + registerTracingFlags( + cmd.Flags(), + &endpoint, + &service, + &sampleRate, + "endpoint help", + "service help", + "sample help", + ) + + actualEndpoint, err := cmd.Flags().GetString("otlp-endpoint") + require.NoError(t, err) + assert.Equal(t, "http://collector:4318", actualEndpoint) + + actualService, err := cmd.Flags().GetString("otlp-service-name") + require.NoError(t, err) + assert.Equal(t, "test-service", actualService) + + actualSampleRate, err := cmd.Flags().GetFloat64("otlp-sample-rate") + require.NoError(t, err) + assert.Equal(t, config.DefaultTracingSampleRate, actualSampleRate) + + err = cmd.ParseFlags([]string{ + "--otlp-endpoint=http://override:4318", + "--otlp-service-name=override-service", + "--otlp-sample-rate=0.25", + }) + require.NoError(t, err) + assert.Equal(t, "http://override:4318", endpoint) + assert.Equal(t, "override-service", service) + assert.Equal(t, 0.25, sampleRate) +} diff --git a/internal/server/http_server.go b/internal/server/http_server.go new file mode 100644 index 00000000..6b558ff5 --- /dev/null +++ b/internal/server/http_server.go @@ -0,0 +1,10 @@ +package server + +import "net/http" + +func newHTTPServer(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + } +} diff --git a/internal/server/http_server_test.go b/internal/server/http_server_test.go new file mode 100644 index 00000000..88ad505e --- /dev/null +++ b/internal/server/http_server_test.go @@ -0,0 +1,18 @@ +package server + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewHTTPServer(t *testing.T) { + handler := http.NewServeMux() + + server := newHTTPServer("127.0.0.1:1234", handler) + require.NotNil(t, server) + assert.Equal(t, "127.0.0.1:1234", server.Addr) + assert.Same(t, handler, server.Handler) +} diff --git a/internal/server/routed.go b/internal/server/routed.go index 7549b5d2..96cc4340 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -174,10 +174,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap log.Printf("Registered route: %s", route) } - return &http.Server{ - Addr: addr, - Handler: mux, - } + return newHTTPServer(addr, mux) } // createFilteredServer creates an MCP server that only exposes tools for a specific backend diff --git a/internal/server/transport.go b/internal/server/transport.go index 608e6d4c..5549b946 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -48,8 +48,5 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st mux.Handle("/mcp/", finalHandler) mux.Handle("/mcp", finalHandler) - return &http.Server{ - Addr: addr, - Handler: mux, - } + return newHTTPServer(addr, mux) }