diff --git a/cmd/internal/config_test.go b/cmd/internal/config_test.go index d9bbeaa53e67..a7a4a36b5a9b 100644 --- a/cmd/internal/config_test.go +++ b/cmd/internal/config_test.go @@ -2288,3 +2288,83 @@ tools: }) } } + +func TestParseToolFileCheckAtStartup(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + testCases := []struct { + name string + in []byte + want bool + }{ + { + name: "explicit false", + in: testutils.FormatYaml(` + kind: sources + name: http-source + type: http + baseUrl: https://example.com + checkAtStartup: false + `), + want: false, + }, + { + name: "explicit true", + in: testutils.FormatYaml(` + kind: sources + name: http-source + type: http + baseUrl: https://example.com + checkAtStartup: true + `), + want: true, + }, + { + name: "default true", + in: testutils.FormatYaml(` + kind: sources + name: http-source-default + type: http + baseUrl: https://example.com + `), + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, tc.in) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + cfg := toolsFile.Sources["http-source"] + if cfg == nil { + cfg = toolsFile.Sources["http-source-default"] + } + if cfg == nil { + t.Fatalf("expected source to be parsed") + } + if sources.CheckAtStartup(cfg) != tc.want { + t.Fatalf("expected checkAtStartup=%v", tc.want) + } + }) + } + + invalidInput := testutils.FormatYaml(` + kind: sources + name: http-source-invalid + type: http + baseUrl: https://example.com + checkAtStartup: "nope" + `) + _, err = parseToolsFile(ctx, invalidInput) + if err == nil { + t.Fatalf("expected error for invalid checkAtStartup type") + } + if !strings.Contains(err.Error(), "checkAtStartup must be a boolean") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/cmd/internal/flags.go b/cmd/internal/flags.go index 0ccc14ae28e6..bf41b2cd6849 100644 --- a/cmd/internal/flags.go +++ b/cmd/internal/flags.go @@ -35,6 +35,7 @@ func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) { persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')") persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.") + persistentFlags.BoolVar(&opts.Cfg.AllowPartialSources, "allow-partial-sources", false, "Skip startup connectivity checks for sources to allow partial availability.") } // ConfigFileFlags defines flags related to the configuration file. diff --git a/cmd/root.go b/cmd/root.go index ea225117b4cd..d70cea011253 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -131,13 +131,13 @@ func NewCommand(opts *internal.ToolboxOptions) *cobra.Command { return cmd } -func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *server.Server) error { +func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *server.Server, allowPartialSources bool) error { logger, err := util.LoggerFromContext(ctx) if err != nil { panic(err) } - sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile, allowPartialSources) if err != nil { errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err) logger.WarnContext(ctx, errMsg.Error()) @@ -151,7 +151,7 @@ func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *serv // validateReloadEdits checks that the reloaded config configs can initialized without failing func validateReloadEdits( - ctx context.Context, toolsFile internal.Config, + ctx context.Context, toolsFile internal.Config, allowPartialSources bool, ) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error, ) { logger, err := util.LoggerFromContext(ctx) @@ -177,6 +177,7 @@ func validateReloadEdits( ToolConfigs: toolsFile.Tools, ToolsetConfigs: toolsFile.Toolsets, PromptConfigs: toolsFile.Prompts, + AllowPartialSources: allowPartialSources, } sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig) @@ -232,7 +233,7 @@ func scanWatchedFiles(watchingFolder bool, folderToWatch string, watchedFiles ma } // watchChanges checks for changes in the provided yaml config(s) or folder. -func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server, pollTickerSecond int) { +func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server, pollTickerSecond int, allowPartialSources bool) { logger, err := util.LoggerFromContext(ctx) if err != nil { panic(err) @@ -378,7 +379,7 @@ func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles m continue } - err = handleDynamicReload(ctx, reloadedConfig, s) + err = handleDynamicReload(ctx, reloadedConfig, s, allowPartialSources) if err != nil { errMsg := fmt.Errorf("unable to parse reloaded config at %q: %w", reloadedConfig, err) logger.WarnContext(ctx, errMsg.Error()) @@ -509,7 +510,7 @@ func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error { if isCustomConfigured && !opts.Cfg.DisableReload { watchDirs, watchedFiles := resolveWatcherInputs(opts.Config, opts.Configs, opts.ConfigFolder) // start watching the file(s) or folder for changes to trigger dynamic reloading - go watchChanges(ctx, watchDirs, watchedFiles, s, opts.Cfg.PollInterval) + go watchChanges(ctx, watchDirs, watchedFiles, s, opts.Cfg.PollInterval, opts.Cfg.AllowPartialSources) } // wait for either the server to error out or the command's context to be canceled diff --git a/docs/en/documentation/configuration/sources/_index.md b/docs/en/documentation/configuration/sources/_index.md index 4d015a8890eb..3f995d74448c 100644 --- a/docs/en/documentation/configuration/sources/_index.md +++ b/docs/en/documentation/configuration/sources/_index.md @@ -26,11 +26,16 @@ instance: my-instance-name database: my_db user: ${USER_NAME} password: ${PASSWORD} +# Optional: skip startup connectivity checks for this source +checkAtStartup: false ``` In implementation, each source is a different connection pool or client that used to connect to the database and execute the tool. +By default, toolbox validates source connectivity at startup. Set `checkAtStartup: false` +to allow a specific source to be offline without blocking startup. + ## Available Sources To see all supported sources and the specific tools they unlock, explore the full list of our [Integrations](../../../integrations/_index.md). diff --git a/docs/en/reference/cli.md b/docs/en/reference/cli.md index 4a761bc634b7..6279960b04a0 100644 --- a/docs/en/reference/cli.md +++ b/docs/en/reference/cli.md @@ -29,6 +29,7 @@ description: > | | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` | | | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` | | | `--user-agent-metadata` | Appends additional metadata to the User-Agent. | | +| | `--allow-partial-sources` | Skip startup connectivity checks for sources to allow partial availability. | | | | `--poll-interval` | Specifies the polling frequency (seconds) for configuration file updates. | `0` | | `-v` | `--version` | version for toolbox | | diff --git a/internal/server/config.go b/internal/server/config.go index efa947cefeac..206af0d09905 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -42,6 +42,8 @@ type ServerConfig struct { Port int // SourceConfigs defines what sources of data are available for tools. SourceConfigs SourceConfigs + // AllowPartialSources allows startup when some sources are unavailable by skipping connectivity checks. + AllowPartialSources bool // AuthServiceConfigs defines what sources of authentication are available for tools. AuthServiceConfigs AuthServiceConfigs // EmbeddingModelConfigs defines a models used to embed parameters. @@ -244,6 +246,17 @@ func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]an if !ok { return nil, fmt.Errorf("missing 'type' field or it is not a string") } + + checkAtStartup := true + if raw, ok := r["checkAtStartup"]; ok { + val, ok := raw.(bool) + if !ok { + return nil, fmt.Errorf("checkAtStartup must be a boolean for source %q", name) + } + checkAtStartup = val + delete(r, "checkAtStartup") + } + dec, err := util.NewStrictDecoder(r) if err != nil { return nil, fmt.Errorf("error creating decoder: %w", err) @@ -252,6 +265,9 @@ func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]an if err != nil { return nil, err } + if !checkAtStartup { + sourceConfig = sources.WrapSourceConfigWithStartupCheck(sourceConfig, checkAtStartup) + } return sourceConfig, nil } diff --git a/internal/server/server.go b/internal/server/server.go index f0ccad606515..4bf104e63e77 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -88,7 +88,14 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( // initialize and validate the sources from configs sourcesMap := make(map[string]sources.Source) + if cfg.AllowPartialSources { + l.WarnContext(ctx, "Skipping startup connectivity checks for all sources (allow-partial-sources enabled).") + } for name, sc := range cfg.SourceConfigs { + checkAtStartup := sources.CheckAtStartup(sc) + if cfg.AllowPartialSources { + checkAtStartup = false + } s, err := func() (sources.Source, error) { childCtx, span := instrumentation.Tracer.Start( ctx, @@ -97,6 +104,10 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( trace.WithAttributes(attribute.String("source_name", name)), ) defer span.End() + if !checkAtStartup && !cfg.AllowPartialSources { + l.InfoContext(childCtx, fmt.Sprintf("Skipping startup connectivity check for source %q", name)) + } + childCtx = sources.WithStartupCheck(childCtx, checkAtStartup) s, err := sc.Initialize(childCtx, instrumentation.Tracer) if err != nil { return nil, fmt.Errorf("unable to initialize source %q: %w", name, err) diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index c7617ae824d7..9febc4671ded 100644 --- a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -71,8 +71,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.Ping(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.Ping(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/clickhouse/clickhouse.go b/internal/sources/clickhouse/clickhouse.go index 7389c4814ab9..0e04c2e5cf69 100644 --- a/internal/sources/clickhouse/clickhouse.go +++ b/internal/sources/clickhouse/clickhouse.go @@ -69,8 +69,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go index 8a3605265acf..ed44e4b47113 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go @@ -74,8 +74,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So } // Verify db connection - err = db.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return db.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index cce65db37497..19bd8d93fed9 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -71,8 +71,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/cockroachdb/cockroachdb.go b/internal/sources/cockroachdb/cockroachdb.go index 0838b13e7401..3a324207dd7e 100644 --- a/internal/sources/cockroachdb/cockroachdb.go +++ b/internal/sources/cockroachdb/cockroachdb.go @@ -396,7 +396,9 @@ func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tr for attempt := 0; attempt <= maxRetries; attempt++ { pool, err = pgxpool.New(ctx, connURL.String()) if err == nil { - err = pool.Ping(ctx) + err = sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.Ping(ctx) + }) } if err == nil { diff --git a/internal/sources/firebird/firebird.go b/internal/sources/firebird/firebird.go index 9429e3d0ce01..c763f25c1441 100644 --- a/internal/sources/firebird/firebird.go +++ b/internal/sources/firebird/firebird.go @@ -65,8 +65,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/mindsdb/mindsdb.go b/internal/sources/mindsdb/mindsdb.go index f0abefb972ce..d82a29819227 100644 --- a/internal/sources/mindsdb/mindsdb.go +++ b/internal/sources/mindsdb/mindsdb.go @@ -67,8 +67,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/mongodb/mongodb.go b/internal/sources/mongodb/mongodb.go index 485718e081af..07bae23f6b90 100644 --- a/internal/sources/mongodb/mongodb.go +++ b/internal/sources/mongodb/mongodb.go @@ -65,8 +65,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So } // Verify the connection - err = client.Ping(ctx, nil) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return client.Ping(ctx, nil) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/mssql/mssql.go b/internal/sources/mssql/mssql.go index 06c33af691fb..bc1d725d874c 100644 --- a/internal/sources/mssql/mssql.go +++ b/internal/sources/mssql/mssql.go @@ -72,8 +72,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So } // Verify db connection - err = db.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return db.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/mysql/mysql.go b/internal/sources/mysql/mysql.go index 477c9a982826..6b5c07f87e69 100644 --- a/internal/sources/mysql/mysql.go +++ b/internal/sources/mysql/mysql.go @@ -70,8 +70,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/oceanbase/oceanbase.go b/internal/sources/oceanbase/oceanbase.go index fbeb03d42524..46bc110c6a92 100644 --- a/internal/sources/oceanbase/oceanbase.go +++ b/internal/sources/oceanbase/oceanbase.go @@ -67,8 +67,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/oracle/oracle.go b/internal/sources/oracle/oracle.go index 30d1332a7fb5..eb8534d4765a 100644 --- a/internal/sources/oracle/oracle.go +++ b/internal/sources/oracle/oracle.go @@ -106,8 +106,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create Oracle connection: %w", err) } - err = db.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return db.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect to Oracle successfully: %w", err) } diff --git a/internal/sources/postgres/postgres.go b/internal/sources/postgres/postgres.go index 86c7235ac53f..10306ecac9cb 100644 --- a/internal/sources/postgres/postgres.go +++ b/internal/sources/postgres/postgres.go @@ -70,8 +70,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.Ping(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.Ping(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/redis/redis.go b/internal/sources/redis/redis.go index 2ab63fb2906b..45340901fe8d 100644 --- a/internal/sources/redis/redis.go +++ b/internal/sources/redis/redis.go @@ -106,7 +106,6 @@ func initRedisClient(ctx context.Context, r Config) (RedisClient, error) { } var client RedisClient - var err error if r.ClusterEnabled { // Create a new Redis Cluster client clusterClient := redis.NewClusterClient(&redis.ClusterOptions{ @@ -120,11 +119,12 @@ func initRedisClient(ctx context.Context, r Config) (RedisClient, error) { Password: r.Password, TLSConfig: tlsConfig, }) - err = clusterClient.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { - return shard.Ping(ctx).Err() - }) - if err != nil { - return nil, fmt.Errorf("unable to connect to redis cluster: %s", err) + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return clusterClient.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { + return shard.Ping(ctx).Err() + }) + }); err != nil { + return nil, fmt.Errorf("unable to connect to redis cluster: %w", err) } client = clusterClient return client, nil @@ -142,9 +142,11 @@ func initRedisClient(ctx context.Context, r Config) (RedisClient, error) { Password: r.Password, TLSConfig: tlsConfig, }) - _, err = standaloneClient.Ping(ctx).Result() - if err != nil { - return nil, fmt.Errorf("unable to connect to redis: %s", err) + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + _, err := standaloneClient.Ping(ctx).Result() + return err + }); err != nil { + return nil, fmt.Errorf("unable to connect to redis: %w", err) } client = standaloneClient return client, nil diff --git a/internal/sources/singlestore/singlestore.go b/internal/sources/singlestore/singlestore.go index 971808d908d6..9395826412ef 100644 --- a/internal/sources/singlestore/singlestore.go +++ b/internal/sources/singlestore/singlestore.go @@ -73,8 +73,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/snowflake/snowflake.go b/internal/sources/snowflake/snowflake.go index 408bbc1a3759..15f9212f1b3b 100644 --- a/internal/sources/snowflake/snowflake.go +++ b/internal/sources/snowflake/snowflake.go @@ -66,8 +66,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create connection: %w", err) } - err = db.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return db.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/sources.go b/internal/sources/sources.go index 93ae0233b127..8ca705798665 100644 --- a/internal/sources/sources.go +++ b/internal/sources/sources.go @@ -59,12 +59,68 @@ type SourceConfig interface { Initialize(ctx context.Context, tracer trace.Tracer) (Source, error) } +// StartupCheckConfig is an optional interface for controlling connectivity checks on startup. +type StartupCheckConfig interface { + CheckAtStartup() bool +} + +type startupCheckWrapper struct { + SourceConfig + checkAtStartup bool +} + +func (w startupCheckWrapper) CheckAtStartup() bool { + return w.checkAtStartup +} + +// WrapSourceConfigWithStartupCheck wraps a SourceConfig to provide a startup check override. +func WrapSourceConfigWithStartupCheck(cfg SourceConfig, checkAtStartup bool) SourceConfig { + if checkAtStartup { + return cfg + } + return startupCheckWrapper{SourceConfig: cfg, checkAtStartup: checkAtStartup} +} + +// CheckAtStartup returns whether the config should perform connectivity checks at startup. +func CheckAtStartup(cfg SourceConfig) bool { + if cfgWithCheck, ok := cfg.(StartupCheckConfig); ok { + return cfgWithCheck.CheckAtStartup() + } + return true +} + // Source is the interface for the source itself. type Source interface { SourceType() string ToConfig() SourceConfig } +type startupCheckKey struct{} + +// WithStartupCheck returns a context that controls whether startup checks run for a source. +func WithStartupCheck(ctx context.Context, checkAtStartup bool) context.Context { + return context.WithValue(ctx, startupCheckKey{}, checkAtStartup) +} + +// ShouldCheckAtStartup indicates whether connectivity checks should run for a source init. +func ShouldCheckAtStartup(ctx context.Context) bool { + if ctx == nil { + return true + } + if val, ok := ctx.Value(startupCheckKey{}).(bool); ok { + return val + } + return true +} + +// CheckConnectivity runs the provided check function only when startup checks are enabled. +func CheckConnectivity(ctx context.Context, checkFn func(context.Context) error) error { + if !ShouldCheckAtStartup(ctx) { + return nil + } + return checkFn(ctx) +} + // InitConnectionSpan adds a span for database pool connection initialization func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceType, sourceName string) (context.Context, trace.Span) { ctx, span := tracer.Start( diff --git a/internal/sources/sqlite/sqlite.go b/internal/sources/sqlite/sqlite.go index 8c15ceed45b8..c7019a871dd8 100644 --- a/internal/sources/sqlite/sqlite.go +++ b/internal/sources/sqlite/sqlite.go @@ -62,8 +62,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create db connection: %w", err) } - err = db.PingContext(context.Background()) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return db.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/tidb/tidb.go b/internal/sources/tidb/tidb.go index 3ab2b5cc2dc4..62ffde48c415 100644 --- a/internal/sources/tidb/tidb.go +++ b/internal/sources/tidb/tidb.go @@ -74,8 +74,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/trino/trino.go b/internal/sources/trino/trino.go index 722bc2d214cf..788b20783986 100644 --- a/internal/sources/trino/trino.go +++ b/internal/sources/trino/trino.go @@ -77,8 +77,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.PingContext(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.PingContext(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) } diff --git a/internal/sources/valkey/valkey.go b/internal/sources/valkey/valkey.go index 9b80568914ea..80a94d255879 100644 --- a/internal/sources/valkey/valkey.go +++ b/internal/sources/valkey/valkey.go @@ -99,10 +99,12 @@ func initValkeyClient(ctx context.Context, r Config) (valkey.Client, error) { } // Ping the server to check connectivity - pingCmd := client.B().Ping().Build() - _, err = client.Do(ctx, pingCmd).ToString() - if err != nil { - log.Fatalf("Failed to execute PING command: %v", err) + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + pingCmd := client.B().Ping().Build() + _, err := client.Do(ctx, pingCmd).ToString() + return err + }); err != nil { + return nil, fmt.Errorf("failed to execute PING command: %w", err) } return client, nil } diff --git a/internal/sources/yugabytedb/yugabytedb.go b/internal/sources/yugabytedb/yugabytedb.go index 10d1471723a8..09af0bc63520 100644 --- a/internal/sources/yugabytedb/yugabytedb.go +++ b/internal/sources/yugabytedb/yugabytedb.go @@ -68,8 +68,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("unable to create pool: %w", err) } - err = pool.Ping(ctx) - if err != nil { + if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error { + return pool.Ping(ctx) + }); err != nil { return nil, fmt.Errorf("unable to connect successfully: %w", err) }