diff --git a/server/.golangci.yaml b/server/.golangci.yaml index 1bf195cb4f..e4fa017021 100644 --- a/server/.golangci.yaml +++ b/server/.golangci.yaml @@ -64,6 +64,21 @@ linters: linters: - exhaustruct - gosec + # The guardian package is the only package that is allowed to use + # http.Client directly since it is used to construct the safe HTTP clients + # that are used throughout the codebase. This is tracked by GG002. + - path: internal/guardian/policy\.go + linters: + - forbidigo + text: GG002 + # The gateway package contains an HTTP reverse proxy that needs finer + # grained control over the HTTP client it uses, so we allow direct use of + # http.Client there as well. It is manually decorating the client with + # SSRF protection from the guardian module and OTel instrumentation. + - path: internal/gateway/proxy\.go + linters: + - forbidigo + text: GG002 settings: gosec: excludes: @@ -82,10 +97,13 @@ linters: forbid: - pattern: ^os\.Getenv$ pkg: ^os$ - msg: "Direct environment variable access is forbidden because it makes testing harder. Access to environment variables is only allowed in the cmd/." + msg: "GG001: Direct environment variable access is forbidden because it makes testing harder. Access to environment variables is only allowed in the cmd/." + - pattern: ^http\.Client$ + pkg: ^net/http$ + msg: "GG002: Use `(*github.com/speakeasy-api/gram/server/internal/guardian.Policy).Client(...)` (or .PooledClient(...)) and the HTTPClient from the same package." - pattern: ^otel\.Tracer$ pkg: ^go\.opentelemetry\.io/otel$ - msg: "GG003: pass a trace.TracerProvider to functions instead of using otel.Tracer directly." + msg: "GG003: Pass a trace.TracerProvider to functions instead of using otel.Tracer directly." sloglint: no-mixed-args: true attr-only: true diff --git a/server/cmd/gram/deps.go b/server/cmd/gram/deps.go index df16880cb2..c0fca1443e 100644 --- a/server/cmd/gram/deps.go +++ b/server/cmd/gram/deps.go @@ -45,6 +45,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/externalmcp" "github.com/speakeasy-api/gram/server/internal/feature" "github.com/speakeasy-api/gram/server/internal/functions" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/inv" "github.com/speakeasy-api/gram/server/internal/must" "github.com/speakeasy-api/gram/server/internal/o11y" @@ -70,6 +71,27 @@ func loadConfigFromFile(c *cli.Context, flags []cli.Flag) error { return cfgLoader(c) } +func newGuardianPolicy(c *cli.Context, tracerProvider trace.TracerProvider) (policy *guardian.Policy, err error) { + // In local development, allow loopback addresses for internal tool-to-tool communication + if c.String("environment") == "local" { + policy, err = guardian.NewUnsafePolicy(tracerProvider, []string{}) // Allow all traffic for local development + if err != nil { + return nil, fmt.Errorf("failed to create unsafe http guardian policy: %w", err) + } + } else { + policy = guardian.NewDefaultPolicy(tracerProvider) + } + if s := c.StringSlice("disallowed-cidr-blocks"); s != nil { + var err error + policy, err = guardian.NewUnsafePolicy(tracerProvider, s) + if err != nil { + return nil, fmt.Errorf("failed to create unsafe http guardian policy: %w", err) + } + } + + return policy, nil +} + func newClickhouseClient(ctx context.Context, logger *slog.Logger, c *cli.Context) (clickhouse.Conn, func(context.Context) error, error) { logger = logger.With(attr.SlogComponent("clickhouse")) nilFunc := func(context.Context) error { return nil } @@ -377,6 +399,7 @@ func newBillingProvider( ctx context.Context, logger *slog.Logger, tracerProvider trace.TracerProvider, + guardianPolicy *guardian.Policy, redisClient *redis.Client, posthogClient *posthog.Posthog, c *cli.Context, @@ -395,7 +418,7 @@ func newBillingProvider( } polarAPIKey := c.String("polar-api-key") polarsdk := polargo.New(polargo.WithSecurity(polarAPIKey), polargo.WithTimeout(30*time.Second)) // Shouldn't take this long, but just in case - pclient := polar.NewClient(polarsdk, polarAPIKey, logger, tracerProvider, redisClient, catalog, c.String("polar-webhook-secret")) + pclient := polar.NewClient(guardianPolicy, polarsdk, polarAPIKey, logger, tracerProvider, redisClient, catalog, c.String("polar-webhook-secret")) tracker := tracking.New(pclient, posthogClient, logger) return pclient, tracker, nil case c.String("environment") == "local": @@ -484,6 +507,7 @@ func newFunctionOrchestrator( c *cli.Context, logger *slog.Logger, tracerProvider trace.TracerProvider, + guardianPolicy *guardian.Policy, db *pgxpool.Pool, assetStore assets.BlobStore, tigrisStore *assets.TigrisStore, @@ -544,6 +568,7 @@ func newFunctionOrchestrator( return functions.NewFlyRunner( logger, tracerProvider, + guardianPolicy, serverURL, db, assetStore, @@ -571,7 +596,7 @@ type mcpRegistryClientOptions struct { cacheImpl cache.Cache } -func newMCPRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, opts mcpRegistryClientOptions) (*externalmcp.RegistryClient, error) { +func newMCPRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, guardianProxy *guardian.Policy, opts mcpRegistryClientOptions) (*externalmcp.RegistryClient, error) { pulseURL, err := url.Parse("https://api.pulsemcp.com") if err != nil { return nil, fmt.Errorf("parse pulse registry url: %w", err) @@ -579,7 +604,7 @@ func newMCPRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvid backend := externalmcp.NewPulseBackend(pulseURL, opts.pulseTenantID, opts.pulseAPIKey) - return externalmcp.NewRegistryClient(logger, tracerProvider, backend, opts.cacheImpl), nil + return externalmcp.NewRegistryClient(logger, tracerProvider, guardianProxy, backend, opts.cacheImpl), nil } func newFeatureChecker(logger *slog.Logger, pf *productfeatures.Client, feat productfeatures.Feature) telemetry.FeatureChecker { diff --git a/server/cmd/gram/start.go b/server/cmd/gram/start.go index 49af450175..5c35890a97 100644 --- a/server/cmd/gram/start.go +++ b/server/cmd/gram/start.go @@ -51,7 +51,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/externalmcp" "github.com/speakeasy-api/gram/server/internal/feature" "github.com/speakeasy-api/gram/server/internal/functions" - "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/hooks" "github.com/speakeasy-api/gram/server/internal/instances" "github.com/speakeasy-api/gram/server/internal/integrations" @@ -401,6 +400,11 @@ func newStartCommand() *cli.Command { } shutdownFuncs = append(shutdownFuncs, shutdown) + guardianPolicy, err := newGuardianPolicy(c, tracerProvider) + if err != nil { + return err + } + db, err := newDBClient(ctx, logger, meterProvider, c.String("database-url"), dbClientOptions{ enableUnsafeLogging: c.Bool("unsafe-db-log"), }) @@ -455,7 +459,7 @@ func newStartCommand() *cli.Command { workosClient := workos.New(logger, c.String("workos-api-key")) - billingRepo, billingTracker, err := newBillingProvider(ctx, logger, tracerProvider, redisClient, posthogClient, c) + billingRepo, billingTracker, err := newBillingProvider(ctx, logger, tracerProvider, guardianPolicy, redisClient, posthogClient, c) if err != nil { return fmt.Errorf("failed to create billing provider: %w", err) } @@ -463,6 +467,7 @@ func newStartCommand() *cli.Command { sessionManager := sessions.NewManager( logger, tracerProvider, + guardianPolicy, db, redisClient, cache.SuffixNone, @@ -514,6 +519,8 @@ func newStartCommand() *cli.Command { } else { openRouter = openrouter.New( logger, + tracerProvider, + guardianPolicy, db, c.String("environment"), c.String("openrouter-provisioning-key"), @@ -558,38 +565,20 @@ func newStartCommand() *cli.Command { return fmt.Errorf("failed to parse site url: %w", err) } - // In local development, allow loopback addresses for internal tool-to-tool communication - var guardianPolicy *guardian.Policy - if c.String("environment") == "local" { - guardianPolicy, err = guardian.NewUnsafePolicy([]string{}) // Allow all traffic for local development - if err != nil { - return fmt.Errorf("failed to create unsafe http guardian policy: %w", err) - } - } else { - guardianPolicy = guardian.NewDefaultPolicy() - } - blockedCIDRs := c.StringSlice("disallowed-cidr-blocks") - if blockedCIDRs != nil { - guardianPolicy, err = guardian.NewUnsafePolicy(blockedCIDRs) - if err != nil { - return fmt.Errorf("failed to create unsafe http guardian policy: %w", err) - } - } - tigrisStore, shutdown, err := newTigrisStore(ctx, c, logger) if err != nil { return fmt.Errorf("failed to create tigris asset store: %w", err) } shutdownFuncs = append(shutdownFuncs, shutdown) - functionsOrchestrator, shutdown, err := newFunctionOrchestrator(c, logger, tracerProvider, db, assetStorage, tigrisStore, encryptionClient) + functionsOrchestrator, shutdown, err := newFunctionOrchestrator(c, logger, tracerProvider, guardianPolicy, db, assetStorage, tigrisStore, encryptionClient) if err != nil { return fmt.Errorf("failed to create functions orchestrator: %w", err) } shutdownFuncs = append(shutdownFuncs, shutdown) runnerVersion := functions.RunnerVersion(conv.Default(strings.TrimPrefix(c.String("functions-runner-version"), "sha-"), GitSHA)) - slackClient := slack_client.NewSlackClient("", "", db, encryptionClient) + slackClient := slack_client.NewSlackClient(guardianPolicy, "", "", db, encryptionClient) logsEnabled := newFeatureChecker(logger, productFeatures, productfeatures.FeatureLogs) toolIOLogsEnabled := newFeatureChecker(logger, productFeatures, productfeatures.FeatureToolIOLogs) @@ -598,6 +587,7 @@ func newStartCommand() *cli.Command { completionsClient := openrouter.NewUnifiedClient( logger, + guardianPolicy, openRouter, chat.NewChatMessageCaptureStrategy(logger, db, assetStorage), chat.NewDefaultUsageTrackingStrategy(db, logger, openRouter, billingTracker, &background.FallbackModelUsageTracker{TemporalEnv: temporalEnv}), @@ -607,7 +597,7 @@ func newStartCommand() *cli.Command { ) ragService := rag.NewToolsetVectorStore(logger, tracerProvider, db, completionsClient) - mcpRegistryClient, err := newMCPRegistryClient(logger, tracerProvider, mcpRegistryClientOptions{ + mcpRegistryClient, err := newMCPRegistryClient(logger, tracerProvider, guardianPolicy, mcpRegistryClientOptions{ pulseTenantID: c.String("pulse-registry-tenant"), pulseAPIKey: conv.NewSecret([]byte(c.String("pulse-registry-api-key"))), cacheImpl: cache.NewRedisCacheAdapter(redisClient), @@ -703,7 +693,7 @@ func newStartCommand() *cli.Command { toolsets.Attach(mux, toolsetsSvc) integrations.Attach(mux, integrations.NewService(logger, tracerProvider, db, sessionManager)) templates.Attach(mux, templates.NewService(logger, tracerProvider, db, sessionManager, toolsetsSvc)) - assets.Attach(mux, assets.NewService(logger, tracerProvider, db, sessionManager, chatSessionsManager, assetStorage, c.String("jwt-signing-key"))) + assets.Attach(mux, assets.NewService(logger, tracerProvider, guardianPolicy, db, sessionManager, chatSessionsManager, assetStorage, c.String("jwt-signing-key"))) deployments.Attach(mux, deployments.NewService(logger, tracerProvider, db, temporalEnv, sessionManager, assetStorage, posthogClient, siteURL, mcpRegistryClient)) keys.Attach(mux, keys.NewService(logger, tracerProvider, db, sessionManager, c.String("environment"))) chatsessionssvc.Attach(mux, chatsessionssvc.NewService(logger, tracerProvider, db, sessionManager, chatSessionsManager)) diff --git a/server/cmd/gram/worker.go b/server/cmd/gram/worker.go index b02ff754bc..429184c934 100644 --- a/server/cmd/gram/worker.go +++ b/server/cmd/gram/worker.go @@ -29,7 +29,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/environments" "github.com/speakeasy-api/gram/server/internal/feature" "github.com/speakeasy-api/gram/server/internal/functions" - "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/k8s" "github.com/speakeasy-api/gram/server/internal/mcp" mcpmetadata_repo "github.com/speakeasy-api/gram/server/internal/mcpmetadata/repo" @@ -327,6 +326,11 @@ func newWorkerCommand() *cli.Command { } shutdownFuncs = append(shutdownFuncs, shutdown) + guardianPolicy, err := newGuardianPolicy(c, tracerProvider) + if err != nil { + return err + } + db, err := newDBClient(ctx, logger, meterProvider, c.String("database-url"), dbClientOptions{ enableUnsafeLogging: c.Bool("unsafe-db-log"), }) @@ -398,7 +402,7 @@ func newWorkerCommand() *cli.Command { productFeatures := productfeatures.NewClient(logger, tracerProvider, db, redisClient) - billingRepo, billingTracker, err := newBillingProvider(ctx, logger, tracerProvider, redisClient, posthogClient, c) + billingRepo, billingTracker, err := newBillingProvider(ctx, logger, tracerProvider, guardianPolicy, redisClient, posthogClient, c) if err != nil { return fmt.Errorf("failed to create billing provider: %w", err) } @@ -407,25 +411,7 @@ func newWorkerCommand() *cli.Command { if c.String("environment") == "local" { openRouter = openrouter.NewDevelopment(c.String("openrouter-dev-key")) } else { - openRouter = openrouter.New(logger, db, c.String("environment"), c.String("openrouter-provisioning-key"), &background.OpenRouterKeyRefresher{TemporalEnv: temporalEnv}, productFeatures, billingTracker) - } - - // In local development, allow loopback addresses for internal tool-to-tool communication - var guardianPolicy *guardian.Policy - if c.String("environment") == "local" { - guardianPolicy, err = guardian.NewUnsafePolicy([]string{}) // Allow all traffic for local development - if err != nil { - return fmt.Errorf("failed to create unsafe http guardian policy: %w", err) - } - } else { - guardianPolicy = guardian.NewDefaultPolicy() - } - if s := c.StringSlice("disallowed-cidr-blocks"); s != nil { - var err error - guardianPolicy, err = guardian.NewUnsafePolicy(s) - if err != nil { - return fmt.Errorf("failed to create unsafe http guardian policy: %w", err) - } + openRouter = openrouter.New(logger, tracerProvider, guardianPolicy, db, c.String("environment"), c.String("openrouter-provisioning-key"), &background.OpenRouterKeyRefresher{TemporalEnv: temporalEnv}, productFeatures, billingTracker) } tigrisStore, shutdown, err := newTigrisStore(ctx, c, logger) @@ -434,7 +420,7 @@ func newWorkerCommand() *cli.Command { } shutdownFuncs = append(shutdownFuncs, shutdown) - functionsOrchestrator, shutdown, err := newFunctionOrchestrator(c, logger, tracerProvider, db, assetStorage, tigrisStore, encryptionClient) + functionsOrchestrator, shutdown, err := newFunctionOrchestrator(c, logger, tracerProvider, guardianPolicy, db, assetStorage, tigrisStore, encryptionClient) if err != nil { return fmt.Errorf("failed to create functions orchestrator: %w", err) } @@ -442,7 +428,7 @@ func newWorkerCommand() *cli.Command { runnerVersion := functions.RunnerVersion(conv.Default(strings.TrimPrefix(c.String("functions-runner-version"), "sha-"), GitSHA)) - slackClient := slack_client.NewSlackClient("", "", db, encryptionClient) + slackClient := slack_client.NewSlackClient(guardianPolicy, "", "", db, encryptionClient) logsEnabled := newFeatureChecker(logger, productFeatures, productfeatures.FeatureLogs) toolIOLogsEnabled := newFeatureChecker(logger, productFeatures, productfeatures.FeatureToolIOLogs) @@ -462,6 +448,7 @@ func newWorkerCommand() *cli.Command { completionsClient := openrouter.NewUnifiedClient( logger, + guardianPolicy, openRouter, chat.NewChatMessageCaptureStrategy(logger, db, assetStorage), chat.NewDefaultUsageTrackingStrategy(db, logger, openRouter, billingTracker, &background.FallbackModelUsageTracker{TemporalEnv: temporalEnv}), @@ -471,7 +458,7 @@ func newWorkerCommand() *cli.Command { ) ragService := rag.NewToolsetVectorStore(logger, tracerProvider, db, completionsClient) - mcpRegistryClient, err := newMCPRegistryClient(logger, tracerProvider, mcpRegistryClientOptions{ + mcpRegistryClient, err := newMCPRegistryClient(logger, tracerProvider, guardianPolicy, mcpRegistryClientOptions{ pulseTenantID: c.String("pulse-registry-tenant"), pulseAPIKey: conv.NewSecret([]byte(c.String("pulse-registry-api-key"))), cacheImpl: cache.NewRedisCacheAdapter(redisClient), @@ -490,7 +477,7 @@ func newWorkerCommand() *cli.Command { return fmt.Errorf("failed to create pylon client: %w", err) } - sessionManager := sessions.NewManager(logger, tracerProvider, db, redisClient, cache.SuffixNone, c.String("speakeasy-server-address"), c.String("speakeasy-secret-key"), pylonClient, posthogClient, billingRepo, nil) + sessionManager := sessions.NewManager(logger, tracerProvider, guardianPolicy, db, redisClient, cache.SuffixNone, c.String("speakeasy-server-address"), c.String("speakeasy-secret-key"), pylonClient, posthogClient, billingRepo, nil) chatSessionsManager := chatsessions.NewManager(logger, redisClient, c.String("jwt-signing-key")) diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index 3babe891ff..164e7aaae2 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -15,6 +15,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/guardian" orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/urn" @@ -55,6 +56,8 @@ func newTestAccessService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -64,7 +67,7 @@ func newTestAccessService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) authCtx, ok := contextvalues.GetAuthContext(ctx) diff --git a/server/internal/agentworkflows/agents/setup_test.go b/server/internal/agentworkflows/agents/setup_test.go index dd6e1f5f94..ee2516bc2a 100644 --- a/server/internal/agentworkflows/agents/setup_test.go +++ b/server/internal/agentworkflows/agents/setup_test.go @@ -25,6 +25,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/deployments" "github.com/speakeasy-api/gram/server/internal/environments" "github.com/speakeasy-api/gram/server/internal/feature" + "github.com/speakeasy-api/gram/server/internal/guardian" mcpmetadata_repo "github.com/speakeasy-api/gram/server/internal/mcpmetadata/repo" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -75,6 +76,8 @@ func newTestAgentsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "agentstest") require.NoError(t, err) @@ -102,7 +105,7 @@ func newTestAgentsService(t *testing.T) (context.Context, *testInstance) { posthogClient := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-test"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") @@ -124,7 +127,7 @@ func newTestAgentsService(t *testing.T) (context.Context, *testInstance) { env, enc, cacheImpl, - nil, // guardian policy + guardianPolicy, funcs, nil, // openrouter provisioner nil, // chat client @@ -133,7 +136,7 @@ func newTestAgentsService(t *testing.T) (context.Context, *testInstance) { // Create supporting services toolsetsSvc := toolsets.NewService(logger, tracerProvider, conn, sessionManager, nil) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporal, sessionManager, assetStorage, posthogClient, testenv.DefaultSiteURL(t), mcpRegistryClient) - assetsSvc := assets.NewService(logger, tracerProvider, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") + assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") return ctx, &testInstance{ agentsService: agentsService, diff --git a/server/internal/agentworkflows/setup_test.go b/server/internal/agentworkflows/setup_test.go index d1bad32ecf..b43303d9f6 100644 --- a/server/internal/agentworkflows/setup_test.go +++ b/server/internal/agentworkflows/setup_test.go @@ -17,6 +17,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/environments" + "github.com/speakeasy-api/gram/server/internal/guardian" mcpmetadata_repo "github.com/speakeasy-api/gram/server/internal/mcpmetadata/repo" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -60,6 +61,8 @@ func newTestAgentsAPIService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "agentsapitest") require.NoError(t, err) @@ -69,7 +72,7 @@ func newTestAgentsAPIService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-test"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) @@ -97,7 +100,7 @@ func newTestAgentsAPIService(t *testing.T) (context.Context, *testInstance) { env, enc, cacheImpl, - nil, // guardian policy - nil is acceptable for testing + guardianPolicy, funcs, nil, // openrouter provisioner - nil is acceptable for testing nil, // chat client - nil is acceptable for testing @@ -123,7 +126,7 @@ func newTestAgentsAPIService(t *testing.T) (context.Context, *testInstance) { env, enc, cacheImpl, - nil, // guardian policy + guardianPolicy, funcs, nil, // openrouter provisioner nil, // chat client diff --git a/server/internal/assets/impl.go b/server/internal/assets/impl.go index 01b2bd68bd..e1f7b691e2 100644 --- a/server/internal/assets/impl.go +++ b/server/internal/assets/impl.go @@ -34,6 +34,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/chatsessions" "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/inv" "github.com/speakeasy-api/gram/server/internal/middleware" "github.com/speakeasy-api/gram/server/internal/o11y" @@ -51,12 +52,13 @@ const ( ) type Service struct { - tracer trace.Tracer - logger *slog.Logger - db *pgxpool.Pool - auth *auth.Auth - storage BlobStore - jwtSecret string + tracer trace.Tracer + logger *slog.Logger + guardianPolicy *guardian.Policy + db *pgxpool.Pool + auth *auth.Auth + storage BlobStore + jwtSecret string chatSessions *chatsessions.Manager projects *projectsRepo.Queries @@ -66,19 +68,20 @@ type Service struct { var _ gen.Service = (*Service)(nil) var _ gen.Auther = (*Service)(nil) -func NewService(logger *slog.Logger, tracerProvider trace.TracerProvider, db *pgxpool.Pool, sessions *sessions.Manager, chatSessions *chatsessions.Manager, storage BlobStore, jwtSecret string) *Service { +func NewService(logger *slog.Logger, tracerProvider trace.TracerProvider, guardianPolicy *guardian.Policy, db *pgxpool.Pool, sessions *sessions.Manager, chatSessions *chatsessions.Manager, storage BlobStore, jwtSecret string) *Service { logger = logger.With(attr.SlogComponent("assets")) return &Service{ - tracer: tracerProvider.Tracer("github.com/speakeasy-api/gram/server/internal/assets"), - logger: logger, - db: db, - auth: auth.New(logger, db, sessions), - storage: storage, - jwtSecret: jwtSecret, - chatSessions: chatSessions, - projects: projectsRepo.New(db), - repo: repo.New(db), + tracer: tracerProvider.Tracer("github.com/speakeasy-api/gram/server/internal/assets"), + logger: logger, + guardianPolicy: guardianPolicy, + db: db, + auth: auth.New(logger, db, sessions), + storage: storage, + jwtSecret: jwtSecret, + chatSessions: chatSessions, + projects: projectsRepo.New(db), + repo: repo.New(db), } } @@ -800,9 +803,8 @@ func (s *Service) FetchOpenAPIv3FromURL(ctx context.Context, payload *gen.FetchO return nil, oops.E(oops.CodeUnexpected, fmt.Errorf("create request: %w", err), "error fetching URL") } - client := &http.Client{ - Timeout: 30 * time.Second, - } + client := s.guardianPolicy.Client() + client.Timeout = 30 * time.Second resp, err := client.Do(req) if err != nil { return nil, oops.E(oops.CodeBadRequest, fmt.Errorf("fetch url: %w", err), "error fetching URL") diff --git a/server/internal/assets/setup_test.go b/server/internal/assets/setup_test.go index 0bb548cde7..f07d3705de 100644 --- a/server/internal/assets/setup_test.go +++ b/server/internal/assets/setup_test.go @@ -16,6 +16,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -58,6 +59,8 @@ func newTestAssetsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -67,7 +70,7 @@ func newTestAssetsService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") @@ -75,7 +78,7 @@ func newTestAssetsService(t *testing.T) (context.Context, *testInstance) { ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) - svc := assets.NewService(logger, tracerProvider, conn, sessionManager, chatSessionsManager, storage, "test-jwt-secret") + svc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, storage, "test-jwt-secret") repository := repo.New(conn) return ctx, &testInstance{ diff --git a/server/internal/audit/setup_test.go b/server/internal/audit/setup_test.go index c292f0a3d6..f71bac1a6e 100644 --- a/server/internal/audit/setup_test.go +++ b/server/internal/audit/setup_test.go @@ -13,6 +13,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -50,6 +51,8 @@ func newTestAuditService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -58,7 +61,7 @@ func newTestAuditService(t *testing.T) (context.Context, *testInstance) { require.NoError(t, err) billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/auth/sessions/sessions.go b/server/internal/auth/sessions/sessions.go index 6b25bb8e97..23b5df39a1 100644 --- a/server/internal/auth/sessions/sessions.go +++ b/server/internal/auth/sessions/sessions.go @@ -4,19 +4,17 @@ import ( "context" "fmt" "log/slog" - "net/http" "time" - "github.com/hashicorp/go-cleanhttp" "github.com/jackc/pgx/v5/pgxpool" "github.com/redis/go-redis/v9" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/trace" "github.com/speakeasy-api/gram/server/internal/attr" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/mv" "github.com/speakeasy-api/gram/server/internal/oops" orgRepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" @@ -33,7 +31,7 @@ type Manager struct { userInfoCache cache.TypedCacheObject[CachedUserInfo] speakeasyServerAddress string speakeasySecretKey string - speakeasyClient *http.Client + speakeasyClient *guardian.HTTPClient orgRepo *orgRepo.Queries userRepo *userRepo.Queries pylon *pylon.Pylon @@ -45,6 +43,7 @@ type Manager struct { func NewManager( logger *slog.Logger, tracerProvider trace.TracerProvider, + guardianPolicy *guardian.Policy, db *pgxpool.Pool, redisClient *redis.Client, suffix cache.Suffix, @@ -56,13 +55,8 @@ func NewManager( workos *workos.WorkOS, ) *Manager { logger = logger.With(attr.SlogComponent("sessions")) - speakeasyClient := &http.Client{ - Timeout: 10 * time.Second, - Transport: otelhttp.NewTransport( - cleanhttp.DefaultPooledTransport(), - otelhttp.WithTracerProvider(tracerProvider), - ), - } + speakeasyClient := guardianPolicy.PooledClient() + speakeasyClient.Timeout = 10 * time.Second return &Manager{ logger: logger.With(attr.SlogComponent("sessions")), diff --git a/server/internal/auth/setup_test.go b/server/internal/auth/setup_test.go index 01b20141d0..5a2bada9af 100644 --- a/server/internal/auth/setup_test.go +++ b/server/internal/auth/setup_test.go @@ -18,6 +18,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/guardian" orgRepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/posthog" @@ -277,6 +278,8 @@ func newTestAuthService(t *testing.T, userInfo *MockUserInfo) (context.Context, ctx := t.Context() logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "authtest") require.NoError(t, err) @@ -295,7 +298,7 @@ func newTestAuthService(t *testing.T, userInfo *MockUserInfo) (context.Context, billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := sessions.NewManager(logger, testenv.NewTracerProvider(t), conn, redisClient, cache.Suffix("gram-test"), mockServer.URL, "test-secret-key", pylon, posthog, billingClient, nil) + sessionManager := sessions.NewManager(logger, testenv.NewTracerProvider(t), guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), mockServer.URL, "test-secret-key", pylon, posthog, billingClient, nil) authConfigs := auth.AuthConfigurations{ SpeakeasyServerAddress: mockServer.URL, diff --git a/server/internal/chat/client.go b/server/internal/chat/client.go index 0da3e84974..caee067a29 100644 --- a/server/internal/chat/client.go +++ b/server/internal/chat/client.go @@ -6,12 +6,15 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/speakeasy-api/gram/server/internal/assets" "github.com/speakeasy-api/gram/server/internal/billing" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/telemetry" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/thirdparty/openrouter" ) -func NewBaseChatClient(logger *slog.Logger, +func NewBaseChatClient( + logger *slog.Logger, + guardianPolicy *guardian.Policy, db *pgxpool.Pool, openRouter openrouter.Provisioner, temporalEnv *temporal.Environment, @@ -42,6 +45,7 @@ func NewBaseChatClient(logger *slog.Logger, // Create UnifiedClient with strategies (after telemSvc is available) return openrouter.NewUnifiedClient( logger, + guardianPolicy, openRouter, messageCaptureStrategy, usageTrackingStrategy, diff --git a/server/internal/deployments/setup_test.go b/server/internal/deployments/setup_test.go index 618854697e..8243e66b7e 100644 --- a/server/internal/deployments/setup_test.go +++ b/server/internal/deployments/setup_test.go @@ -17,6 +17,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/deployments" "github.com/speakeasy-api/gram/server/internal/feature" + "github.com/speakeasy-api/gram/server/internal/guardian" packages "github.com/speakeasy-api/gram/server/internal/packages" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -64,6 +65,8 @@ func newTestDeploymentService(t *testing.T, assetStorage assets.BlobStore) (cont logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -86,7 +89,7 @@ func newTestDeploymentService(t *testing.T, assetStorage assets.BlobStore) (cont billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") @@ -95,7 +98,7 @@ func newTestDeploymentService(t *testing.T, assetStorage assets.BlobStore) (cont posthog := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") svc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient) - assetsSvc := assets.NewService(logger, tracerProvider, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") + assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") packagesSvc := packages.NewService(logger, tracerProvider, conn, sessionManager) return ctx, &testInstance{ diff --git a/server/internal/environments/setup_test.go b/server/internal/environments/setup_test.go index e132e8dbc9..a6fd5e1fce 100644 --- a/server/internal/environments/setup_test.go +++ b/server/internal/environments/setup_test.go @@ -13,6 +13,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/environments" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -52,6 +53,8 @@ func newTestEnvironmentService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -61,7 +64,7 @@ func newTestEnvironmentService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/externalmcp/registryclient.go b/server/internal/externalmcp/registryclient.go index 4729221e0a..e1cbecbf1c 100644 --- a/server/internal/externalmcp/registryclient.go +++ b/server/internal/externalmcp/registryclient.go @@ -12,14 +12,13 @@ import ( "strings" "github.com/google/uuid" - "github.com/hashicorp/go-retryablehttp" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/trace" "github.com/speakeasy-api/gram/server/gen/types" "github.com/speakeasy-api/gram/server/internal/attr" "github.com/speakeasy-api/gram/server/internal/cache" externalmcptypes "github.com/speakeasy-api/gram/server/internal/externalmcp/repo/types" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" ) @@ -31,7 +30,7 @@ type RegistryBackend interface { // RegistryClient handles communication with external MCP registries. type RegistryClient struct { - httpClient *http.Client + httpClient *guardian.HTTPClient logger *slog.Logger backend RegistryBackend listCache *cache.TypedCacheObject[CachedListServersResponse] @@ -40,14 +39,9 @@ type RegistryClient struct { // NewRegistryClient creates a new registry client. The cacheImpl parameter is // optional — pass nil to disable caching. -func NewRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, backend RegistryBackend, cacheImpl cache.Cache) *RegistryClient { +func NewRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, guardianPolicy *guardian.Policy, backend RegistryBackend, cacheImpl cache.Cache) *RegistryClient { rc := &RegistryClient{ - httpClient: &http.Client{ - Transport: otelhttp.NewTransport( - retryablehttp.NewClient().StandardClient().Transport, - otelhttp.WithTracerProvider(tracerProvider), - ), - }, + httpClient: guardianPolicy.PooledClient(), logger: logger.With(attr.SlogComponent("mcp_registry_client")), backend: backend, listCache: nil, diff --git a/server/internal/externalmcp/registryclient_test.go b/server/internal/externalmcp/registryclient_test.go index cff427fd4a..925d198fdb 100644 --- a/server/internal/externalmcp/registryclient_test.go +++ b/server/internal/externalmcp/registryclient_test.go @@ -9,11 +9,12 @@ import ( "testing" "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" tracernoop "go.opentelemetry.io/otel/trace/noop" "github.com/speakeasy-api/gram/server/internal/externalmcp/repo/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/speakeasy-api/gram/server/internal/guardian" ) type PassthroughBackend struct{} @@ -32,6 +33,9 @@ func TestListServers_FiltersDeletedServers(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.New(slog.DiscardHandler) + tracerProvider := tracernoop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response := listResponse{ @@ -82,7 +86,7 @@ func TestListServers_FiltersDeletedServers(t *testing.T) { })) defer server.Close() - client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil) + client := NewRegistryClient(logger, tracerProvider, guardianPolicy, &PassthroughBackend{}, nil) client.httpClient = server.Client() registry := Registry{ ID: uuid.New(), @@ -101,6 +105,9 @@ func TestGetServerDetails_OnlyStreamableHTTP(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.New(slog.DiscardHandler) + tracerProvider := tracernoop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodGet, r.Method) @@ -124,7 +131,7 @@ func TestGetServerDetails_OnlyStreamableHTTP(t *testing.T) { })) defer server.Close() - client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil) + client := NewRegistryClient(logger, tracerProvider, guardianPolicy, &PassthroughBackend{}, nil) client.httpClient = server.Client() registry := Registry{ ID: uuid.New(), @@ -146,6 +153,9 @@ func TestGetServerDetails_OnlySSE(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.New(slog.DiscardHandler) + tracerProvider := tracernoop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodGet, r.Method) @@ -169,7 +179,7 @@ func TestGetServerDetails_OnlySSE(t *testing.T) { })) defer server.Close() - client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil) + client := NewRegistryClient(logger, tracerProvider, guardianPolicy, &PassthroughBackend{}, nil) client.httpClient = server.Client() registry := Registry{ ID: uuid.New(), @@ -191,6 +201,9 @@ func TestGetServerDetails_PrefersStreamableHTTPOverSSE(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.New(slog.DiscardHandler) + tracerProvider := tracernoop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodGet, r.Method) @@ -215,7 +228,7 @@ func TestGetServerDetails_PrefersStreamableHTTPOverSSE(t *testing.T) { })) defer server.Close() - client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil) + client := NewRegistryClient(logger, tracerProvider, guardianPolicy, &PassthroughBackend{}, nil) client.httpClient = server.Client() registry := Registry{ ID: uuid.New(), @@ -237,6 +250,9 @@ func TestGetServerDetails_SelectedRemotesFiltersToSSE(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.New(slog.DiscardHandler) + tracerProvider := tracernoop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodGet, r.Method) @@ -262,7 +278,7 @@ func TestGetServerDetails_SelectedRemotesFiltersToSSE(t *testing.T) { })) defer server.Close() - client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil) + client := NewRegistryClient(logger, tracerProvider, guardianPolicy, &PassthroughBackend{}, nil) client.httpClient = server.Client() registry := Registry{ ID: uuid.New(), @@ -285,6 +301,9 @@ func TestGetServerDetails_SelectedRemotesStillPrefersStreamableHTTP(t *testing.T t.Parallel() ctx := context.Background() logger := slog.New(slog.DiscardHandler) + tracerProvider := tracernoop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodGet, r.Method) @@ -310,7 +329,7 @@ func TestGetServerDetails_SelectedRemotesStillPrefersStreamableHTTP(t *testing.T })) defer server.Close() - client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil) + client := NewRegistryClient(logger, tracerProvider, guardianPolicy, &PassthroughBackend{}, nil) client.httpClient = server.Client() registry := Registry{ ID: uuid.New(), diff --git a/server/internal/functions/deploy_fly.go b/server/internal/functions/deploy_fly.go index eab69a4d6f..7e60d6ba60 100644 --- a/server/internal/functions/deploy_fly.go +++ b/server/internal/functions/deploy_fly.go @@ -18,14 +18,12 @@ import ( backoff "github.com/cenkalti/backoff/v5" "github.com/google/uuid" - "github.com/hashicorp/go-retryablehttp" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" slogmulti "github.com/samber/slog-multi" "github.com/superfly/fly-go" "github.com/superfly/fly-go/flaps" "github.com/superfly/fly-go/tokens" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/trace" "github.com/speakeasy-api/gram/server/internal/assets" @@ -34,6 +32,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/deployments/events" "github.com/speakeasy-api/gram/server/internal/encryption" "github.com/speakeasy-api/gram/server/internal/functions/repo" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/inv" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" @@ -68,7 +67,7 @@ type FlyRunner struct { client *fly.Client tokens *tokens.Tokens machinesAPIBase string - machinesClient *http.Client + machinesClient *guardian.HTTPClient defaultOrg string defaultRegion string imgSelector ImageSelector @@ -84,6 +83,7 @@ var _ interface { func NewFlyRunner( logger *slog.Logger, tracerProvider trace.TracerProvider, + guardianPolicy *guardian.Policy, serverURL *url.URL, db *pgxpool.Pool, assetStorage assets.BlobStore, @@ -96,12 +96,9 @@ func NewFlyRunner( flyAPIBase := conv.Default(o.FlyAPIURL, defaultFlyBaseURL) machinesAPIBase := conv.Default(o.FlyMachinesBaseURL, defaultFlyMachinesURL) - machinesClient := &http.Client{ - Transport: otelhttp.NewTransport( - retryablehttp.NewClient().StandardClient().Transport, - otelhttp.WithTracerProvider(tracerProvider), - ), - } + machinesClient := guardianPolicy.PooledClient( + guardian.WithRetryConfig(guardian.DefaultRetryConfig()), + ) c := fly.NewClientFromOptions(fly.ClientOptions{ BaseURL: flyAPIBase, @@ -109,10 +106,9 @@ func NewFlyRunner( Name: o.ServiceName, Version: o.ServiceVersion, Transport: &fly.Transport{ - UnderlyingTransport: otelhttp.NewTransport( - retryablehttp.NewClient().StandardClient().Transport, - otelhttp.WithTracerProvider(tracerProvider), - ), + UnderlyingTransport: guardianPolicy.PooledClient( + guardian.WithRetryConfig(guardian.DefaultRetryConfig()), + ).Transport, UserAgent: ua, Tokens: o.FlyTokens, }, diff --git a/server/internal/functions/deploy_local.go b/server/internal/functions/deploy_local.go index 3673b8737b..6566cfb3a3 100644 --- a/server/internal/functions/deploy_local.go +++ b/server/internal/functions/deploy_local.go @@ -7,23 +7,19 @@ import ( "net/http" "os" - "github.com/hashicorp/go-retryablehttp" - "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/urn" ) type LocalRunner struct { - codeRoot *os.Root - toolcallClient *http.Client + codeRoot *os.Root } var _ Orchestrator = (*LocalRunner)(nil) func NewLocalRunner(codeRoot *os.Root) *LocalRunner { return &LocalRunner{ - codeRoot: codeRoot, - toolcallClient: retryablehttp.NewClient().StandardClient(), + codeRoot: codeRoot, } } diff --git a/server/internal/functions/setup_test.go b/server/internal/functions/setup_test.go index a139f77d04..ceed0930e0 100644 --- a/server/internal/functions/setup_test.go +++ b/server/internal/functions/setup_test.go @@ -23,6 +23,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/deployments" "github.com/speakeasy-api/gram/server/internal/feature" "github.com/speakeasy-api/gram/server/internal/functions" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/posthog" @@ -70,6 +71,8 @@ func newTestFunctionsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -95,7 +98,7 @@ func newTestFunctionsService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") @@ -105,7 +108,7 @@ func newTestFunctionsService(t *testing.T) (context.Context, *testInstance) { svc := functions.NewService(logger, tracerProvider, conn, enc, tigrisStore) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, ph, testenv.DefaultSiteURL(t), mcpRegistryClient) - assetsSvc := assets.NewService(logger, tracerProvider, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") + assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") return ctx, &testInstance{ service: svc, diff --git a/server/internal/gateway/proxy.go b/server/internal/gateway/proxy.go index e9f44111f8..65f6e940f9 100644 --- a/server/internal/gateway/proxy.go +++ b/server/internal/gateway/proxy.go @@ -542,7 +542,7 @@ func (tp *ToolProxy) doHTTP( } } - shouldContinue := processSecurity(ctx, logger, req, w, &responseStatusCode, descriptor, plan, tp.cache, env, serverURL, attrRecorder) + shouldContinue := processSecurity(ctx, logger, tp.policy, req, w, &responseStatusCode, descriptor, plan, tp.cache, env, serverURL, attrRecorder) if !shouldContinue { return nil } diff --git a/server/internal/gateway/proxy_test.go b/server/internal/gateway/proxy_test.go index 43db4b35ca..b27e24b49a 100644 --- a/server/internal/gateway/proxy_test.go +++ b/server/internal/gateway/proxy_test.go @@ -189,7 +189,7 @@ func TestToolProxy_Do_PathParams(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -318,7 +318,7 @@ func TestToolProxy_Do_HeaderParams(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -678,7 +678,7 @@ func TestToolProxy_Do_QueryParams(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -887,7 +887,7 @@ func TestToolProxy_Do_Body(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) // Create tool configuration @@ -1253,7 +1253,7 @@ func TestToolProxy_Do_StringifiedJSONBody(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -1336,7 +1336,7 @@ func TestResourceProxy_ReadResource(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) // Create resource descriptor @@ -1476,7 +1476,7 @@ func TestToolProxy_Do_FunctionMetricsTrailers(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -1570,7 +1570,7 @@ func TestToolProxy_Do_HTTPTool_UserConfigVariablesSent(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -1663,7 +1663,7 @@ func TestToolProxy_Do_HTTPTool_UserConfigNotInPlanNotSent(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -1767,7 +1767,7 @@ func TestToolProxy_Do_FunctionTool_UserConfigNotInPlanNotSent(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -1862,7 +1862,7 @@ func TestToolProxy_Do_HTTPTool_SystemEnvSentWhenInPlan(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -1956,7 +1956,7 @@ func TestToolProxy_Do_HTTPTool_SystemEnvKeysConvertedToHTTPHeaders(t *testing.T) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2056,7 +2056,7 @@ func TestToolProxy_Do_FunctionTool_SystemEnvSentWhenInPlan(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2153,7 +2153,7 @@ func TestToolProxy_Do_HTTPTool_UserConfigPrefersOverSystemEnv(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2266,7 +2266,7 @@ func TestToolProxy_Do_FunctionTool_UserConfigPrefersOverSystemEnv(t *testing.T) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2380,7 +2380,7 @@ func TestToolProxy_Do_FunctionTool_AuthInputSentWhenInUserConfig(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2491,7 +2491,7 @@ func TestToolProxy_Do_FunctionTool_AuthInputNotSentWhenNotInUserConfig(t *testin tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2602,7 +2602,7 @@ func TestToolProxy_Do_FunctionTool_AuthInputPrefersUserConfigOverSystemEnv(t *te tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2716,7 +2716,7 @@ func TestToolProxy_Do_FunctionTool_AuthInputSentWithRegularVariables(t *testing. tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() @@ -2834,7 +2834,7 @@ func TestToolProxy_Do_FunctionTool_AuthInputNilNotSent(t *testing.T) { tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) enc := testenv.NewEncryptionClient(t) - policy, err := guardian.NewUnsafePolicy([]string{}) + policy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) require.NoError(t, err) tool := newTestToolDescriptor() diff --git a/server/internal/gateway/security.go b/server/internal/gateway/security.go index a8b7767090..57b6798a6e 100644 --- a/server/internal/gateway/security.go +++ b/server/internal/gateway/security.go @@ -17,6 +17,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/attr" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" tm "github.com/speakeasy-api/gram/server/internal/telemetry" "github.com/speakeasy-api/gram/server/internal/toolconfig" ) @@ -24,6 +25,7 @@ import ( func processSecurity( ctx context.Context, logger *slog.Logger, + guardianPolicy *guardian.Policy, req *http.Request, w http.ResponseWriter, responseStatusCodeCapture *int, @@ -145,7 +147,7 @@ func processSecurity( } } case "client_credentials": - token, err := processClientCredentials(ctx, logger, req, cacheImpl, tool, plan.SecurityScopes, security, mergedEnv, serverURL) + token, err := processClientCredentials(ctx, logger, guardianPolicy, req, cacheImpl, tool, plan.SecurityScopes, security, mergedEnv, serverURL) if err != nil { logger.ErrorContext(ctx, "could not process client credentials", attr.SlogError(err)) if strings.Contains(err.Error(), "failed to make client credentials token request") { @@ -245,7 +247,7 @@ type clientCredentialsTokenResponseCamelCase struct { ExpiresIn int `json:"expiresIn"` } -func processClientCredentials(ctx context.Context, logger *slog.Logger, req *http.Request, cacheImpl cache.Cache, tool *ToolDescriptor, planScopes map[string][]string, security *HTTPToolSecurity, mergedEnv *toolconfig.CaseInsensitiveEnv, serverURL string) (string, error) { +func processClientCredentials(ctx context.Context, logger *slog.Logger, guardianPolicy *guardian.Policy, req *http.Request, cacheImpl cache.Cache, tool *ToolDescriptor, planScopes map[string][]string, security *HTTPToolSecurity, mergedEnv *toolconfig.CaseInsensitiveEnv, serverURL string) (string, error) { // To discuss, currently we are taking the approach of exact scope match for reused tokens // We could look into enabling a prefix match feature for caches where we return multiple entries matching the projectID, clientID, tokenURL and then check scopes against all returned values // We would want to make sure any underlying cache implementation supports this feature @@ -327,13 +329,8 @@ func processClientCredentials(ctx context.Context, logger *slog.Logger, req *htt tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Make the token request - client := &http.Client{ - Timeout: 10 * time.Second, - Transport: otelhttp.NewTransport( - http.DefaultTransport, - otelhttp.WithPropagators(propagation.TraceContext{}), - ), - } + client := guardianPolicy.Client(guardian.WithOTelHTTPOptions(otelhttp.WithPropagators(propagation.TraceContext{}))) + client.Timeout = 10 * time.Second resp, err := client.Do(tokenReq) if err != nil { return "", fmt.Errorf("failed to make client credentials token request: %w", err) @@ -429,7 +426,7 @@ func parseClientCredentialsTokenResponse(body []byte) (string, int, error) { return accessToken, expiresIn, nil } -func retryTokenRequestWithBasicAuth(ctx context.Context, client *http.Client, tokenURL, clientID, clientSecret string, requestedScopes []string) (*http.Response, error) { +func retryTokenRequestWithBasicAuth(ctx context.Context, client *guardian.HTTPClient, tokenURL, clientID, clientSecret string, requestedScopes []string) (*http.Response, error) { values := url.Values{} values.Set("grant_type", "client_credentials") if len(requestedScopes) > 0 { diff --git a/server/internal/guardian/policy.go b/server/internal/guardian/policy.go index da32847f0f..b9a2fcfd42 100644 --- a/server/internal/guardian/policy.go +++ b/server/internal/guardian/policy.go @@ -9,8 +9,13 @@ import ( "time" "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-retryablehttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel/trace" ) +type HTTPClient = http.Client + var ( ErrBadHost = fmt.Errorf("bad host") ErrBlockedIP = fmt.Errorf("blocked ip") @@ -52,14 +57,57 @@ var defaultBlockedCIDRBlocks = []*net.IPNet{ mustParseCIDR("2001:20::/28"), /* ORCHIDv2 - RFC7343 */ } +type RetryConfig struct { + RetryWaitMin time.Duration + RetryWaitMax time.Duration + RetryMax int + CheckRetry retryablehttp.CheckRetry + Backoff retryablehttp.Backoff + ErrorHandler retryablehttp.ErrorHandler + PrepareRetry retryablehttp.PrepareRetry +} + +var defaultRetryClient = retryablehttp.NewClient() + +func DefaultRetryConfig() *RetryConfig { + return &RetryConfig{ + RetryWaitMin: defaultRetryClient.RetryWaitMin, + RetryWaitMax: defaultRetryClient.RetryWaitMax, + RetryMax: defaultRetryClient.RetryMax, + CheckRetry: defaultRetryClient.CheckRetry, + Backoff: defaultRetryClient.Backoff, + ErrorHandler: defaultRetryClient.ErrorHandler, + PrepareRetry: defaultRetryClient.PrepareRetry, + } +} + +type htttpClientOptions struct { + otelHTTPOptions []otelhttp.Option + retryConfig *RetryConfig +} + +func WithOTelHTTPOptions(options ...otelhttp.Option) func(*htttpClientOptions) { + return func(o *htttpClientOptions) { + o.otelHTTPOptions = options + } +} + +func WithRetryConfig(config *RetryConfig) func(*htttpClientOptions) { + return func(o *htttpClientOptions) { + o.retryConfig = config + } +} + type Policy struct { + tracerProvider trace.TracerProvider blockedCIDRBlocks []*net.IPNet } // NewDefaultPolicy creates a new Policy that blocks common private and reserved // IP ranges. -func NewDefaultPolicy() *Policy { +func NewDefaultPolicy(tracerProvider trace.TracerProvider) *Policy { return &Policy{ + tracerProvider: tracerProvider, blockedCIDRBlocks: defaultBlockedCIDRBlocks, } } @@ -68,7 +116,7 @@ func NewDefaultPolicy() *Policy { // It returns an error if any of the CIDR blocks cannot be parsed. // Use NewDefaultPolicy for a safe default that blocks common private and // reserved IP ranges. -func NewUnsafePolicy(disallowedCIDRBlocks []string) (*Policy, error) { +func NewUnsafePolicy(tracerProvider trace.TracerProvider, disallowedCIDRBlocks []string) (*Policy, error) { var disallowedBlocks []*net.IPNet for _, cidr := range disallowedCIDRBlocks { block, err := parseCIDR(cidr) @@ -78,21 +126,50 @@ func NewUnsafePolicy(disallowedCIDRBlocks []string) (*Policy, error) { disallowedBlocks = append(disallowedBlocks, block) } - return &Policy{blockedCIDRBlocks: disallowedBlocks}, nil + return &Policy{ + tracerProvider: tracerProvider, + blockedCIDRBlocks: disallowedBlocks, + }, nil } -func (p *Policy) PooledClient() *http.Client { - t := cleanhttp.DefaultPooledTransport() - t.DialContext = p.Dialer().DialContext +func (p *Policy) PooledClient(options ...func(*htttpClientOptions)) *HTTPClient { + return p.clientWithBaseTransport(cleanhttp.DefaultPooledTransport(), options...) +} - return &http.Client{Transport: t} +func (p *Policy) Client(options ...func(*htttpClientOptions)) *HTTPClient { + return p.clientWithBaseTransport(cleanhttp.DefaultTransport(), options...) } -func (p *Policy) Client() *http.Client { - t := cleanhttp.DefaultTransport() - t.DialContext = p.Dialer().DialContext +func (p *Policy) clientWithBaseTransport(transport *http.Transport, options ...func(*htttpClientOptions)) *HTTPClient { + var opts htttpClientOptions + for _, option := range options { + option(&opts) + } + + transport.DialContext = p.Dialer().DialContext + + otelOpts := []otelhttp.Option{otelhttp.WithTracerProvider(p.tracerProvider)} + otelOpts = append(otelOpts, opts.otelHTTPOptions...) + otelTransport := otelhttp.NewTransport(transport, otelOpts...) + + if opts.retryConfig == nil { + return &http.Client{Transport: otelTransport} + } + + retryClient := retryablehttp.NewClient() + retryClient.HTTPClient = &http.Client{ + Transport: otelTransport, + } + + retryClient.RetryWaitMin = opts.retryConfig.RetryWaitMin + retryClient.RetryWaitMax = opts.retryConfig.RetryWaitMax + retryClient.RetryMax = opts.retryConfig.RetryMax + retryClient.CheckRetry = opts.retryConfig.CheckRetry + retryClient.Backoff = opts.retryConfig.Backoff + retryClient.ErrorHandler = opts.retryConfig.ErrorHandler + retryClient.PrepareRetry = opts.retryConfig.PrepareRetry - return &http.Client{Transport: t} + return retryClient.StandardClient() } func (p *Policy) Dialer() *net.Dialer { diff --git a/server/internal/guardian/policy_test.go b/server/internal/guardian/policy_test.go index 8f85f3a5ac..9b5ad79b0b 100644 --- a/server/internal/guardian/policy_test.go +++ b/server/internal/guardian/policy_test.go @@ -10,7 +10,9 @@ import ( "time" "github.com/speakeasy-api/gram/server/internal/guardian" + "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/stretchr/testify/require" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) func TestNewUnsafePolicy(t *testing.T) { @@ -50,7 +52,7 @@ func TestNewUnsafePolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - policy, err := guardian.NewUnsafePolicy(tt.cidrBlocks) + policy, err := guardian.NewUnsafePolicy(testenv.NewTracerProvider(t), tt.cidrBlocks) if tt.expectError { require.Error(t, err) require.Nil(t, policy) @@ -64,7 +66,7 @@ func TestNewUnsafePolicy(t *testing.T) { func TestPolicy_Dialer(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() require.NotNil(t, dialer) @@ -75,7 +77,7 @@ func TestPolicy_Dialer(t *testing.T) { func TestPolicy_DialerControlContext(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx := t.Context() @@ -128,7 +130,7 @@ func TestPolicy_DialerControlContext(t *testing.T) { func TestPolicy_DialerControlContext_CustomPolicy(t *testing.T) { t.Parallel() - policy, err := guardian.NewUnsafePolicy([]string{"8.8.8.0/24"}) + policy, err := guardian.NewUnsafePolicy(testenv.NewTracerProvider(t), []string{"8.8.8.0/24"}) require.NoError(t, err) dialer := policy.Dialer() @@ -142,35 +144,33 @@ func TestPolicy_DialerControlContext_CustomPolicy(t *testing.T) { require.NoError(t, err) } -func TestPolicy_Client(t *testing.T) { +func TestPolicy_ClientWrapsTransportWithOtel(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) client := policy.Client() require.NotNil(t, client) require.NotNil(t, client.Transport) - transport, ok := client.Transport.(*http.Transport) + _, ok := client.Transport.(*otelhttp.Transport) require.True(t, ok) - require.NotNil(t, transport.DialContext) } -func TestPolicy_PooledClient(t *testing.T) { +func TestPolicy_PooledClientWrapsTransportWithOtel(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) client := policy.PooledClient() require.NotNil(t, client) require.NotNil(t, client.Transport) - transport, ok := client.Transport.(*http.Transport) + _, ok := client.Transport.(*otelhttp.Transport) require.True(t, ok) - require.NotNil(t, transport.DialContext) } func TestDefaultPolicyBlocksPrivateIPs(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx := t.Context() @@ -192,7 +192,7 @@ func TestDefaultPolicyBlocksPrivateIPs(t *testing.T) { func TestPolicy_DialerIPBlocking(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx := t.Context() @@ -216,7 +216,7 @@ func TestPolicy_DialerIPBlocking(t *testing.T) { func TestPolicy_DialerEdgeCases(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx := t.Context() @@ -262,7 +262,7 @@ func TestPolicy_DialerEdgeCases(t *testing.T) { func TestPolicy_DialerContext(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond) @@ -290,7 +290,7 @@ func TestPolicy_HTTPClientWithCustomPolicy(t *testing.T) { require.NoError(t, err) // Create a custom policy that blocks the test server's IP - customPolicy, err := guardian.NewUnsafePolicy([]string{host + "/32"}) + customPolicy, err := guardian.NewUnsafePolicy(testenv.NewTracerProvider(t), []string{host + "/32"}) require.NoError(t, err) // Test that the custom policy blocks the server @@ -322,7 +322,7 @@ func TestPolicy_PooledHTTPClientWithFakeNetwork(t *testing.T) { require.NoError(t, err) // Create a custom policy that allows the test server but blocks private IPs - policy, err := guardian.NewUnsafePolicy([]string{ + policy, err := guardian.NewUnsafePolicy(testenv.NewTracerProvider(t), []string{ "192.168.0.0/16", // Block private IPs "10.0.0.0/8", // Block private IPs "172.16.0.0/12", // Block private IPs @@ -358,7 +358,7 @@ func TestPolicy_PooledHTTPClientWithFakeNetwork(t *testing.T) { func TestPolicy_IPv4MappedIPv6Addresses(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx := t.Context() @@ -440,7 +440,7 @@ func TestPolicy_IPv4MappedIPv6Addresses(t *testing.T) { func TestPolicy_IPv6VariationsBlocking(t *testing.T) { t.Parallel() - policy := guardian.NewDefaultPolicy() + policy := guardian.NewDefaultPolicy(testenv.NewTracerProvider(t)) dialer := policy.Dialer() ctx := t.Context() diff --git a/server/internal/hooks/setup_test.go b/server/internal/hooks/setup_test.go index 23223a41fe..fa33c6a932 100644 --- a/server/internal/hooks/setup_test.go +++ b/server/internal/hooks/setup_test.go @@ -14,6 +14,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -54,6 +55,8 @@ func newTestHooksService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := noop.NewTracerProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -63,7 +66,7 @@ func newTestHooksService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/keys/setup_test.go b/server/internal/keys/setup_test.go index 4d5021fc0f..a0965396ec 100644 --- a/server/internal/keys/setup_test.go +++ b/server/internal/keys/setup_test.go @@ -13,6 +13,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/keys" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -54,6 +55,8 @@ func newTestKeysService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -63,7 +66,7 @@ func newTestKeysService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/mcp/setup_test.go b/server/internal/mcp/setup_test.go index fee2287f75..c822d265d3 100644 --- a/server/internal/mcp/setup_test.go +++ b/server/internal/mcp/setup_test.go @@ -81,6 +81,8 @@ func newTestMCPService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := noop.NewMeterProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "mcptest") require.NoError(t, err) @@ -90,7 +92,7 @@ func newTestMCPService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-test"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) @@ -105,11 +107,10 @@ func newTestMCPService(t *testing.T) (context.Context, *testInstance) { env := environments.NewEnvironmentEntries(logger, conn, enc, mcpMetadataRepo) posthog := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") cacheAdapter := cache.NewRedisCacheAdapter(redisClient) - guardianPolicy := guardian.NewDefaultPolicy() oauthService := oauth.NewService(logger, tracerProvider, meterProvider, conn, serverURL, cacheAdapter, enc, env, sessionManager) billingStub := billing.NewStubClient(logger, tracerProvider) devProvisioner := openrouter.NewDevelopment("test-openrouter-key") - chatClient := openrouter.NewUnifiedClient(logger, devProvisioner, nil, nil, nil, nil, nil) + chatClient := openrouter.NewUnifiedClient(logger, guardianPolicy, devProvisioner, nil, nil, nil, nil, nil) vectorToolStore := rag.NewToolsetVectorStore(logger, tracerProvider, conn, chatClient) chatSessions := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") featClient := productfeatures.NewClient(logger, tracerProvider, conn, redisClient) @@ -178,6 +179,8 @@ func newTestMCPServiceWithOAuth(t *testing.T, oauthSvc mcp.OAuthService) (contex logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := noop.NewMeterProvider() + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "mcptest") require.NoError(t, err) @@ -187,7 +190,7 @@ func newTestMCPServiceWithOAuth(t *testing.T, oauthSvc mcp.OAuthService) (contex billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-test"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) @@ -202,10 +205,9 @@ func newTestMCPServiceWithOAuth(t *testing.T, oauthSvc mcp.OAuthService) (contex env := environments.NewEnvironmentEntries(logger, conn, enc, mcpMetadataRepo) posthog := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") cacheAdapter := cache.NewRedisCacheAdapter(redisClient) - guardianPolicy := guardian.NewDefaultPolicy() billingStub := billing.NewStubClient(logger, tracerProvider) devProvisioner := openrouter.NewDevelopment("test-openrouter-key") - chatClient := openrouter.NewUnifiedClient(logger, devProvisioner, nil, nil, nil, nil, nil) + chatClient := openrouter.NewUnifiedClient(logger, guardianPolicy, devProvisioner, nil, nil, nil, nil, nil) vectorToolStore := rag.NewToolsetVectorStore(logger, tracerProvider, conn, chatClient) featClient := productfeatures.NewClient(logger, tracerProvider, conn, redisClient) diff --git a/server/internal/mcpmetadata/setup_test.go b/server/internal/mcpmetadata/setup_test.go index abc23a9e9a..0bb63ff41d 100644 --- a/server/internal/mcpmetadata/setup_test.go +++ b/server/internal/mcpmetadata/setup_test.go @@ -14,6 +14,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/mcpmetadata" "github.com/speakeasy-api/gram/server/internal/testenv" toolsets_repo "github.com/speakeasy-api/gram/server/internal/toolsets/repo" @@ -60,6 +61,8 @@ func newTestMCPMetadataService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "mcpmetadatatest") require.NoError(t, err) @@ -70,7 +73,7 @@ func newTestMCPMetadataService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-test"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/oauth/external_oauth.go b/server/internal/oauth/external_oauth.go index 63e3087a61..7e6c0eb141 100644 --- a/server/internal/oauth/external_oauth.go +++ b/server/internal/oauth/external_oauth.go @@ -36,6 +36,7 @@ import ( deployments_repo "github.com/speakeasy-api/gram/server/internal/deployments/repo" "github.com/speakeasy-api/gram/server/internal/encryption" externalmcp_repo "github.com/speakeasy-api/gram/server/internal/externalmcp/repo" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oauth/repo" "github.com/speakeasy-api/gram/server/internal/oauth/wellknown" @@ -118,7 +119,7 @@ type ExternalOAuthService struct { allowedRedirectHosts []string auth *auth.Auth enc *encryption.Client - httpClient *http.Client + httpClient *guardian.HTTPClient successPageTmpl *template.Template successScriptHash string successScriptData []byte diff --git a/server/internal/productfeatures/setup_test.go b/server/internal/productfeatures/setup_test.go index 8565e0d40a..ffa384d720 100644 --- a/server/internal/productfeatures/setup_test.go +++ b/server/internal/productfeatures/setup_test.go @@ -12,6 +12,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/productfeatures" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -52,6 +53,8 @@ func newTestProductFeaturesService(t *testing.T) (context.Context, *testInstance logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -61,7 +64,7 @@ func newTestProductFeaturesService(t *testing.T) (context.Context, *testInstance billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/projects/setup_test.go b/server/internal/projects/setup_test.go index 594aa5492b..bde05ff914 100644 --- a/server/internal/projects/setup_test.go +++ b/server/internal/projects/setup_test.go @@ -14,6 +14,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/projects" "github.com/speakeasy-api/gram/server/internal/testenv" ) @@ -55,6 +56,8 @@ func newTestProjectsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -64,7 +67,7 @@ func newTestProjectsService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/telemetry/setup_test.go b/server/internal/telemetry/setup_test.go index 01c9059d50..9c94ecd98f 100644 --- a/server/internal/telemetry/setup_test.go +++ b/server/internal/telemetry/setup_test.go @@ -17,6 +17,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/telemetry" "github.com/speakeasy-api/gram/server/internal/telemetry/repo" @@ -64,6 +65,8 @@ func newTestLogsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -73,7 +76,7 @@ func newTestLogsService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-test"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-test"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") diff --git a/server/internal/templates/setup_test.go b/server/internal/templates/setup_test.go index 6a5ea076b3..9a13099be8 100644 --- a/server/internal/templates/setup_test.go +++ b/server/internal/templates/setup_test.go @@ -13,6 +13,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/templates" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/urn" @@ -54,6 +55,8 @@ func newTestTemplateService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -63,7 +66,7 @@ func newTestTemplateService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) diff --git a/server/internal/testenv/auth.go b/server/internal/testenv/auth.go index e8436df662..90164588cb 100644 --- a/server/internal/testenv/auth.go +++ b/server/internal/testenv/auth.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" mockidp "github.com/speakeasy-api/gram/mock-speakeasy-idp" "github.com/speakeasy-api/gram/server/internal/auth/sessions" @@ -19,6 +20,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/guardian" orgRepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" projectsRepo "github.com/speakeasy-api/gram/server/internal/projects/repo" "github.com/speakeasy-api/gram/server/internal/thirdparty/posthog" @@ -28,7 +30,7 @@ import ( // NewTestManager creates a sessions.Manager backed by a mock IDP httptest.Server. // This replaces the old NewUnsafeManager pattern - tests now exercise the real // auth code path (Speakeasy IDP HTTP calls) against a local mock. -func NewTestManager(t *testing.T, logger *slog.Logger, db *pgxpool.Pool, redisClient *redis.Client, suffix cache.Suffix, billingRepo billing.Repository) *sessions.Manager { +func NewTestManager(t *testing.T, logger *slog.Logger, tracerProvider trace.TracerProvider, guardianPolicy *guardian.Policy, db *pgxpool.Pool, redisClient *redis.Client, suffix cache.Suffix, billingRepo billing.Repository) *sessions.Manager { t.Helper() cfg := mockidp.NewConfig() @@ -42,7 +44,8 @@ func NewTestManager(t *testing.T, logger *slog.Logger, db *pgxpool.Pool, redisCl return sessions.NewManager( logger, - NewTracerProvider(t), + tracerProvider, + guardianPolicy, db, redisClient, suffix, diff --git a/server/internal/testenv/testing.go b/server/internal/testenv/testing.go index 7d2556a0fc..f91b60cbb4 100644 --- a/server/internal/testenv/testing.go +++ b/server/internal/testenv/testing.go @@ -17,6 +17,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/encryption" "github.com/speakeasy-api/gram/server/internal/externalmcp" "github.com/speakeasy-api/gram/server/internal/functions" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" ) @@ -87,9 +88,13 @@ func NewMCPRegistryClient(t *testing.T, logger *slog.Logger, tracerProvider trac pulseURL, err := url.Parse("https://api.pulsemcp.com") require.NoError(t, err, "expected pulse URL to parse") + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err, "expected guardian policy to initialize without error") + client := externalmcp.NewRegistryClient( NewLogger(t), tracerProvider, + guardianPolicy, externalmcp.NewPulseBackend(pulseURL, "test-tenant-id", conv.NewSecret([]byte("test-api-key"))), nil, ) diff --git a/server/internal/thirdparty/openrouter/openrouter.go b/server/internal/thirdparty/openrouter/openrouter.go index 22964ac031..4b683290af 100644 --- a/server/internal/thirdparty/openrouter/openrouter.go +++ b/server/internal/thirdparty/openrouter/openrouter.go @@ -13,11 +13,12 @@ import ( "slices" "time" - "github.com/hashicorp/go-retryablehttp" "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/trace" "github.com/speakeasy-api/gram/server/internal/attr" "github.com/speakeasy-api/gram/server/internal/billing" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/inv" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" @@ -93,19 +94,21 @@ type OpenRouter struct { logger *slog.Logger repo *repo.Queries orgRepo *orgRepo.Queries - orClient *http.Client + orClient *guardian.HTTPClient refresher KeyRefresher featureClient *productfeatures.Client } -func New(logger *slog.Logger, db *pgxpool.Pool, env string, provisioningKey string, refresher KeyRefresher, featureClient *productfeatures.Client, tracking billing.Tracker) *OpenRouter { +func New(logger *slog.Logger, tracerProvider trace.TracerProvider, guardianPolicy *guardian.Policy, db *pgxpool.Pool, env string, provisioningKey string, refresher KeyRefresher, featureClient *productfeatures.Client, tracking billing.Tracker) *OpenRouter { + orClient := guardianPolicy.PooledClient(guardian.WithRetryConfig(guardian.DefaultRetryConfig())) + return &OpenRouter{ provisioningKey: provisioningKey, env: env, logger: logger.With(attr.SlogComponent("openrouter")), repo: repo.New(db), orgRepo: orgRepo.New(db), - orClient: retryablehttp.NewClient().StandardClient(), + orClient: orClient, refresher: refresher, featureClient: featureClient, } diff --git a/server/internal/thirdparty/openrouter/unified_client.go b/server/internal/thirdparty/openrouter/unified_client.go index 0072ab97c6..4e808799d6 100644 --- a/server/internal/thirdparty/openrouter/unified_client.go +++ b/server/internal/thirdparty/openrouter/unified_client.go @@ -12,13 +12,13 @@ import ( "time" "github.com/google/uuid" - "github.com/hashicorp/go-cleanhttp" "go.opentelemetry.io/otel/trace" or_base "github.com/OpenRouterTeam/go-sdk" or "github.com/OpenRouterTeam/go-sdk/models/components" or_operations "github.com/OpenRouterTeam/go-sdk/models/operations" "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/telemetry" ) @@ -46,7 +46,7 @@ const ( // It applies pluggable strategies for message capture and usage tracking. type ChatClient struct { logger *slog.Logger - httpClient *http.Client + httpClient *guardian.HTTPClient provisioner Provisioner messageCaptureStrategy MessageCaptureStrategy usageTrackingStrategy UsageTrackingStrategy @@ -58,6 +58,7 @@ type ChatClient struct { // NewUnifiedClient creates a new UnifiedClient with the given strategies. func NewUnifiedClient( logger *slog.Logger, + guardianPolicy *guardian.Policy, provisioner Provisioner, captureStrategy MessageCaptureStrategy, trackingStrategy UsageTrackingStrategy, @@ -67,7 +68,7 @@ func NewUnifiedClient( ) *ChatClient { return &ChatClient{ logger: logger.With(attr.SlogComponent("openrouter_completions")), - httpClient: cleanhttp.DefaultPooledClient(), + httpClient: guardianPolicy.PooledClient(), provisioner: provisioner, messageCaptureStrategy: captureStrategy, usageTrackingStrategy: trackingStrategy, diff --git a/server/internal/thirdparty/openrouter/unified_client_test.go b/server/internal/thirdparty/openrouter/unified_client_test.go index 5e541b267f..b9db00ec64 100644 --- a/server/internal/thirdparty/openrouter/unified_client_test.go +++ b/server/internal/thirdparty/openrouter/unified_client_test.go @@ -15,7 +15,9 @@ import ( or "github.com/OpenRouterTeam/go-sdk/models/components" "github.com/google/uuid" "github.com/speakeasy-api/gram/server/internal/billing" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/telemetry" + "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -200,9 +202,14 @@ func TestChatClient_GetCompletion(t *testing.T) { resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, @@ -319,9 +326,14 @@ func TestChatClient_GetCompletionStream(t *testing.T) { resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, @@ -430,9 +442,14 @@ func TestChatClient_GetCompletion_WithToolCalls(t *testing.T) { resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, @@ -528,9 +545,14 @@ func TestChatClient_ErrorHandling(t *testing.T) { resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, @@ -552,7 +574,7 @@ func TestChatClient_ErrorHandling(t *testing.T) { } // Call GetCompletion - _, err := client.GetCompletion(context.Background(), req) + _, err = client.GetCompletion(context.Background(), req) require.Error(t, err) assert.Contains(t, err.Error(), tt.expectedError) }) @@ -595,9 +617,14 @@ func TestChatClient_MultipleCompletions_TitleAndResolutionScheduling(t *testing. resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, @@ -758,9 +785,14 @@ func TestChatClient_NilChatID_ShouldNotScheduleTitleGeneration(t *testing.T) { server := newMockServer(t) defer server.Close() + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + titleGenerator := &trackingTitleGenerator{} client := NewUnifiedClient( slog.Default(), + guardianPolicy, &mockProvisioner{apiKey: "test-api-key"}, &mockMessageCaptureStrategy{}, &mockUsageTrackingStrategy{}, @@ -780,7 +812,7 @@ func TestChatClient_NilChatID_ShouldNotScheduleTitleGeneration(t *testing.T) { APIKeyID: "", } - _, err := client.GetCompletion(context.Background(), req) + _, err = client.GetCompletion(context.Background(), req) require.NoError(t, err) time.Sleep(100 * time.Millisecond) @@ -799,10 +831,15 @@ func TestChatClient_TitleGeneration_ScheduledPerCompletionWithValidChatID(t *tes server := newMockServer(t) defer server.Close() + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + titleGenerator := &trackingTitleGenerator{} tracker := newTrackingCaptureStrategy() client := NewUnifiedClient( slog.Default(), + guardianPolicy, &mockProvisioner{apiKey: "test-api-key"}, tracker, &mockUsageTrackingStrategy{}, @@ -829,7 +866,7 @@ func TestChatClient_TitleGeneration_ScheduledPerCompletionWithValidChatID(t *tes } // One completion with nil ChatID (simulating title-gen activity's internal call) - _, err := client.GetCompletion(context.Background(), CompletionRequest{ + _, err = client.GetCompletion(context.Background(), CompletionRequest{ OrgID: "test-org", ProjectID: projectID.String(), Messages: []or.Message{CreateMessageUser("Generate title")}, @@ -861,9 +898,14 @@ func TestChatClient_ReloadChat_NoDuplicateMessages(t *testing.T) { server := newMockServer(t) defer server.Close() + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + tracker := newTrackingCaptureStrategy() client := NewUnifiedClient( slog.Default(), + guardianPolicy, &mockProvisioner{apiKey: "test-api-key"}, tracker, &mockUsageTrackingStrategy{}, @@ -885,7 +927,7 @@ func TestChatClient_ReloadChat_NoDuplicateMessages(t *testing.T) { UsageSource: billing.ModelUsageSourcePlayground, APIKeyID: "key-1", } - _, err := client.GetCompletion(context.Background(), req1) + _, err = client.GetCompletion(context.Background(), req1) require.NoError(t, err) // After round 1: DB should have [user(StartOrResumeChat), assistant(CaptureMessage)] @@ -986,9 +1028,14 @@ func TestChatClient_GetCompletion_WithJSONSchema(t *testing.T) { resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, @@ -1096,9 +1143,14 @@ func TestChatClient_GetCompletion_WithoutJSONSchema(t *testing.T) { resolutionAnalyzer := &mockChatResolutionAnalyzer{} telemetryService := &mockTelemetryService{} + tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) + // Create client client := NewUnifiedClient( slog.Default(), + guardianPolicy, provisioner, captureStrategy, trackingStrategy, diff --git a/server/internal/thirdparty/polar/client.go b/server/internal/thirdparty/polar/client.go index 4ac399704c..690e499d35 100644 --- a/server/internal/thirdparty/polar/client.go +++ b/server/internal/thirdparty/polar/client.go @@ -27,6 +27,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/guardian" ) type Catalog struct { @@ -65,7 +66,7 @@ type Client struct { logger *slog.Logger tracer trace.Tracer polar *polargo.Polar - httpClient *http.Client + httpClient *guardian.HTTPClient bearerToken string catalog *Catalog customerStateCache cache.TypedCacheObject[PolarCustomerState] @@ -77,12 +78,15 @@ type Client struct { var _ billing.Tracker = (*Client)(nil) var _ billing.Repository = (*Client)(nil) -func NewClient(polarClient *polargo.Polar, bearerToken string, logger *slog.Logger, tracerProvider trace.TracerProvider, redisClient *redis.Client, catalog *Catalog, webhookSecret string) *Client { +func NewClient(guardianPolicy *guardian.Policy, polarClient *polargo.Polar, bearerToken string, logger *slog.Logger, tracerProvider trace.TracerProvider, redisClient *redis.Client, catalog *Catalog, webhookSecret string) *Client { + client := guardianPolicy.PooledClient() + client.Timeout = 30 * time.Second + return &Client{ logger: logger.With(attr.SlogComponent("polar_usage")), tracer: tracerProvider.Tracer("github.com/speakeasy-api/gram/server/internal/thirdparty/polar"), polar: polarClient, - httpClient: &http.Client{Timeout: 30 * time.Second}, + httpClient: client, bearerToken: bearerToken, catalog: catalog, customerStateCache: cache.NewTypedObjectCache[PolarCustomerState](logger.With(attr.SlogCacheNamespace("polar-customer-state")), cache.NewRedisCacheAdapter(redisClient), cache.SuffixNone), diff --git a/server/internal/thirdparty/slack/client/client.go b/server/internal/thirdparty/slack/client/client.go index 7e487cdeea..7768f00101 100644 --- a/server/internal/thirdparty/slack/client/client.go +++ b/server/internal/thirdparty/slack/client/client.go @@ -10,10 +10,10 @@ import ( "strings" "github.com/google/uuid" - "github.com/hashicorp/go-cleanhttp" "github.com/jackc/pgx/v5/pgxpool" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/encryption" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/thirdparty/slack/repo" ) @@ -22,18 +22,18 @@ const slackServer = "https://slack.com/api" type SlackClient struct { clientID string clientSecret string - client *http.Client + client *guardian.HTTPClient enc *encryption.Client repo *repo.Queries enabled bool } -func NewSlackClient(clientID, clientSecret string, db *pgxpool.Pool, enc *encryption.Client) *SlackClient { +func NewSlackClient(guardianPolicy *guardian.Policy, clientID, clientSecret string, db *pgxpool.Pool, enc *encryption.Client) *SlackClient { enabled := clientID != "" && clientSecret != "" return &SlackClient{ clientID: clientID, clientSecret: clientSecret, - client: cleanhttp.DefaultPooledClient(), + client: guardianPolicy.PooledClient(), enc: enc, repo: repo.New(db), enabled: enabled, diff --git a/server/internal/thirdparty/workos/roles.go b/server/internal/thirdparty/workos/roles.go index 4688c9d960..f9b6453ab9 100644 --- a/server/internal/thirdparty/workos/roles.go +++ b/server/internal/thirdparty/workos/roles.go @@ -15,6 +15,7 @@ import ( "github.com/workos/workos-go/v6/pkg/organizations" "github.com/workos/workos-go/v6/pkg/usermanagement" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" ) @@ -68,7 +69,7 @@ func (e *APIError) Error() string { type RoleClient struct { apiKey string endpoint string // base URL for raw HTTP calls; defaults to workosBaseURL - httpClient *http.Client + httpClient *guardian.HTTPClient orgs *organizations.Client um *usermanagement.Client } @@ -79,7 +80,7 @@ type RoleClientOpts struct { // Endpoint overrides the WorkOS base URL for both raw HTTP and SDK calls. Endpoint string // HTTPClient overrides the default retryable HTTP client. - HTTPClient *http.Client + HTTPClient *guardian.HTTPClient } func NewRoleClient(apiKey string, opts ...RoleClientOpts) *RoleClient { diff --git a/server/internal/tools/setup_test.go b/server/internal/tools/setup_test.go index 5938580155..734b8079b1 100644 --- a/server/internal/tools/setup_test.go +++ b/server/internal/tools/setup_test.go @@ -23,6 +23,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/deployments" "github.com/speakeasy-api/gram/server/internal/feature" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" packages "github.com/speakeasy-api/gram/server/internal/packages" "github.com/speakeasy-api/gram/server/internal/templates" @@ -74,6 +75,8 @@ func newTestToolsService(t *testing.T, assetStorage assets.BlobStore) (context.C logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -84,7 +87,7 @@ func newTestToolsService(t *testing.T, assetStorage assets.BlobStore) (context.C billingClient := billing.NewStubClient(logger, tracerProvider) posthog := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") @@ -105,7 +108,7 @@ func newTestToolsService(t *testing.T, assetStorage assets.BlobStore) (context.C toolsSvc := tools.NewService(logger, tracerProvider, conn, sessionManager) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient) - assetsSvc := assets.NewService(logger, tracerProvider, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") + assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") packagesSvc := packages.NewService(logger, tracerProvider, conn, sessionManager) toolsetsSvc := toolsets.NewService(logger, tracerProvider, conn, sessionManager, cache.NewRedisCacheAdapter(redisClient)) templatesSvc := templates.NewService(logger, tracerProvider, conn, sessionManager, toolsetsSvc) diff --git a/server/internal/toolsets/setup_test.go b/server/internal/toolsets/setup_test.go index c932140991..c2505cffb4 100644 --- a/server/internal/toolsets/setup_test.go +++ b/server/internal/toolsets/setup_test.go @@ -28,6 +28,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/deployments" "github.com/speakeasy-api/gram/server/internal/feature" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/o11y" packages "github.com/speakeasy-api/gram/server/internal/packages" "github.com/speakeasy-api/gram/server/internal/temporal" @@ -79,6 +80,8 @@ func newTestToolsetsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) meterProvider := testenv.NewMeterProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -106,7 +109,7 @@ func newTestToolsetsService(t *testing.T) (context.Context, *testInstance) { posthog := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) chatSessionsManager := chatsessions.NewManager(logger, redisClient, "test-jwt-secret") @@ -114,7 +117,7 @@ func newTestToolsetsService(t *testing.T) (context.Context, *testInstance) { svc := toolsets.NewService(logger, tracerProvider, conn, sessionManager, nil) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient) - assetsSvc := assets.NewService(logger, tracerProvider, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") + assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret") packagesSvc := packages.NewService(logger, tracerProvider, conn, sessionManager) return ctx, &testInstance{ diff --git a/server/internal/variations/setup_test.go b/server/internal/variations/setup_test.go index ad2b692b22..f9639d3f57 100644 --- a/server/internal/variations/setup_test.go +++ b/server/internal/variations/setup_test.go @@ -10,6 +10,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/auth/sessions" "github.com/speakeasy-api/gram/server/internal/billing" "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/variations" "github.com/stretchr/testify/require" @@ -51,6 +52,8 @@ func newTestVariationsService(t *testing.T) (context.Context, *testInstance) { logger := testenv.NewLogger(t) tracerProvider := testenv.NewTracerProvider(t) + guardianPolicy, err := guardian.NewUnsafePolicy(tracerProvider, []string{}) + require.NoError(t, err) conn, err := infra.CloneTestDatabase(t, "testdb") require.NoError(t, err) @@ -60,7 +63,7 @@ func newTestVariationsService(t *testing.T) (context.Context, *testInstance) { billingClient := billing.NewStubClient(logger, tracerProvider) - sessionManager := testenv.NewTestManager(t, logger, conn, redisClient, cache.Suffix("gram-local"), billingClient) + sessionManager := testenv.NewTestManager(t, logger, tracerProvider, guardianPolicy, conn, redisClient, cache.Suffix("gram-local"), billingClient) ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager)