Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
require (
github.com/itchyny/gojq v0.12.19
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
github.com/spf13/pflag v1.0.9
github.com/stretchr/testify v1.11.1
github.com/tetratelabs/wazero v1.11.0
go.opentelemetry.io/otel v1.43.0
Expand All @@ -34,7 +35,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/segmentio/asm v1.1.3 // indirect
github.com/segmentio/encoding v0.5.4 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
Expand Down
10 changes: 3 additions & 7 deletions internal/cmd/flags_tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package cmd
// Tracing-related flags for OpenTelemetry OTLP trace export.

import (
"github.com/github/gh-aw-mcpg/internal/config"
"github.com/github/gh-aw-mcpg/internal/envutil"
"github.com/spf13/cobra"
)

Expand All @@ -17,11 +15,9 @@ var (

func init() {
RegisterFlag(func(cmd *cobra.Command) {
cmd.Flags().StringVar(&otlpEndpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
"OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Defaults from OTEL_EXPORTER_OTLP_ENDPOINT when set. Tracing is disabled when empty.")
cmd.Flags().StringVar(&otlpServiceName, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName),
"Service name reported in traces. Defaults from OTEL_SERVICE_NAME when set.")
cmd.Flags().Float64Var(&otlpSampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate,
registerTracingFlags(cmd.Flags(), &otlpEndpoint, &otlpServiceName, &otlpSampleRate,
"OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Defaults from OTEL_EXPORTER_OTLP_ENDPOINT when set. Tracing is disabled when empty.",
"Service name reported in traces. Defaults from OTEL_SERVICE_NAME when set.",
"Fraction of traces to sample and export (0.0–1.0). Default 1.0 samples everything.")
})
}
27 changes: 9 additions & 18 deletions internal/cmd/proxy.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"context"
"crypto/tls"
"fmt"
"log"
Expand All @@ -12,14 +11,12 @@ import (
"path/filepath"
"strings"
"syscall"
"time"

"github.com/github/gh-aw-mcpg/internal/config"
"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/envutil"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/proxy"
"github.com/github/gh-aw-mcpg/internal/tracing"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -121,11 +118,9 @@ Local usage:
cmd.Flags().StringVar(&proxyTLSDir, "tls-dir", "", "Directory for TLS certificates (default: <log-dir>/proxy-tls)")
cmd.Flags().StringSliceVar(&proxyTrustedBots, "trusted-bots", nil, "Additional trusted bot usernames (comma-separated, extends built-in list)")
cmd.Flags().StringSliceVar(&proxyTrustedUsers, "trusted-users", nil, "User logins that receive approved integrity (comma-separated)")
cmd.Flags().StringVar(&proxyOTLPEndpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
"OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Tracing is disabled when empty.")
cmd.Flags().StringVar(&proxyOTLPService, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName),
"Service name reported in traces.")
cmd.Flags().Float64Var(&proxyOTLPSampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate,
registerTracingFlags(cmd.Flags(), &proxyOTLPEndpoint, &proxyOTLPService, &proxyOTLPSampleRate,
"OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Tracing is disabled when empty.",
"Service name reported in traces.",
"Fraction of traces to sample and export (0.0–1.0).")

// Only require --guard-wasm when no baked-in guard is available
Expand Down Expand Up @@ -165,17 +160,13 @@ func runProxy(cmd *cobra.Command, args []string) error {
SampleRate: &proxyOTLPSampleRate,
}
}
tracingProvider, err := tracing.InitProvider(ctx, tracingCfg)
if err != nil {
log.Printf("Warning: failed to initialize tracing provider: %v", err)
tracingProvider, _ = tracing.InitProvider(ctx, nil)
}
tracingProvider := initTracingProviderWithFallback(ctx, tracingCfg, func(format string, args ...any) {
log.Printf("Warning: "+format, args...)
})
defer func() {
shutdownCtx, cancelTracing := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelTracing()
if err := tracingProvider.Shutdown(shutdownCtx); err != nil {
log.Printf("Warning: tracing provider shutdown error: %v", err)
}
shutdownTracingProviderWithTimeout(tracingProvider, func(format string, args ...any) {
log.Printf("Warning: "+format, args...)
})
}()
if tracingCfg != nil {
log.Printf("OpenTelemetry tracing enabled for proxy: endpoint=%s, service=%s", proxyOTLPEndpoint, proxyOTLPService)
Expand Down
17 changes: 6 additions & 11 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,18 +311,13 @@ func run(cmd *cobra.Command, args []string) error {
if cfg.Gateway != nil {
tracingCfg = cfg.Gateway.Tracing
}
tracingProvider, err := tracing.InitProvider(ctx, tracingCfg)
if err != nil {
logger.StartupWarn("Failed to initialize tracing provider: %v", err)
// Non-fatal: continue without tracing
tracingProvider, _ = tracing.InitProvider(ctx, nil)
}
tracingProvider := initTracingProviderWithFallback(ctx, tracingCfg, func(format string, args ...any) {
logger.StartupWarn(format, args...)
})
defer func() {
shutdownCtxTracing, cancelTracing := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelTracing()
if err := tracingProvider.Shutdown(shutdownCtxTracing); err != nil {
log.Printf("Warning: tracing provider shutdown error: %v", err)
}
shutdownTracingProviderWithTimeout(tracingProvider, func(format string, args ...any) {
log.Printf("Warning: "+format, args...)
})
}()

// Apply W3C parent context from configured traceId/spanId (spec §4.1.3.6).
Expand Down
39 changes: 39 additions & 0 deletions internal/cmd/tracing_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package cmd

import (
"context"
"time"

"github.com/github/gh-aw-mcpg/internal/config"
"github.com/github/gh-aw-mcpg/internal/envutil"
"github.com/github/gh-aw-mcpg/internal/tracing"
"github.com/spf13/pflag"
)

func registerTracingFlags(flags *pflag.FlagSet, endpoint *string, serviceName *string, sampleRate *float64, endpointUsage string, serviceUsage string, sampleUsage string) {
flags.StringVar(endpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
endpointUsage)
flags.StringVar(serviceName, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName),
serviceUsage)
flags.Float64Var(sampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate,
sampleUsage)
}

func initTracingProviderWithFallback(ctx context.Context, tracingCfg *config.TracingConfig, warnf func(format string, args ...any)) *tracing.Provider {
tracingProvider, err := tracing.InitProvider(ctx, tracingCfg)
if err != nil {
warnf("failed to initialize tracing provider: %v", err)
tracingProvider, _ = tracing.InitProvider(ctx, nil)
}

return tracingProvider
}

func shutdownTracingProviderWithTimeout(tracingProvider *tracing.Provider, warnf func(format string, args ...any)) {
shutdownCtxTracing, cancelTracing := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelTracing()

if err := tracingProvider.Shutdown(shutdownCtxTracing); err != nil {
warnf("tracing provider shutdown error: %v", err)
}
}
54 changes: 54 additions & 0 deletions internal/cmd/tracing_helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package cmd

import (
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/github/gh-aw-mcpg/internal/config"
)

func TestRegisterTracingFlags_DefaultsFromEnv(t *testing.T) {
t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://collector:4318")
t.Setenv("OTEL_SERVICE_NAME", "test-service")

cmd := &cobra.Command{Use: "test"}

var endpoint string
var service string
var sampleRate float64

registerTracingFlags(
cmd.Flags(),
&endpoint,
&service,
&sampleRate,
"endpoint help",
"service help",
"sample help",
)

actualEndpoint, err := cmd.Flags().GetString("otlp-endpoint")
require.NoError(t, err)
assert.Equal(t, "http://collector:4318", actualEndpoint)

actualService, err := cmd.Flags().GetString("otlp-service-name")
require.NoError(t, err)
assert.Equal(t, "test-service", actualService)

actualSampleRate, err := cmd.Flags().GetFloat64("otlp-sample-rate")
require.NoError(t, err)
assert.Equal(t, config.DefaultTracingSampleRate, actualSampleRate)

err = cmd.ParseFlags([]string{
"--otlp-endpoint=http://override:4318",
"--otlp-service-name=override-service",
"--otlp-sample-rate=0.25",
})
require.NoError(t, err)
assert.Equal(t, "http://override:4318", endpoint)
assert.Equal(t, "override-service", service)
assert.Equal(t, 0.25, sampleRate)
}
10 changes: 10 additions & 0 deletions internal/server/http_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package server

import "net/http"

func newHTTPServer(addr string, handler http.Handler) *http.Server {
return &http.Server{
Addr: addr,
Handler: handler,
}
}
18 changes: 18 additions & 0 deletions internal/server/http_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package server

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewHTTPServer(t *testing.T) {
handler := http.NewServeMux()

server := newHTTPServer("127.0.0.1:1234", handler)
require.NotNil(t, server)
assert.Equal(t, "127.0.0.1:1234", server.Addr)
assert.Same(t, handler, server.Handler)
}
5 changes: 1 addition & 4 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
log.Printf("Registered route: %s", route)
}

return &http.Server{
Addr: addr,
Handler: mux,
}
return newHTTPServer(addr, mux)
}

// createFilteredServer creates an MCP server that only exposes tools for a specific backend
Expand Down
5 changes: 1 addition & 4 deletions internal/server/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,5 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st
mux.Handle("/mcp/", finalHandler)
mux.Handle("/mcp", finalHandler)

return &http.Server{
Addr: addr,
Handler: mux,
}
return newHTTPServer(addr, mux)
}
Loading