Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions cmd/internal/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2288,3 +2288,83 @@ tools:
})
}
}

func TestParseToolFileCheckAtStartup(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

testCases := []struct {
name string
in []byte
want bool
}{
{
name: "explicit false",
in: testutils.FormatYaml(`
kind: sources
name: http-source
type: http
baseUrl: https://example.com
checkAtStartup: false
`),
want: false,
},
{
name: "explicit true",
in: testutils.FormatYaml(`
kind: sources
name: http-source
type: http
baseUrl: https://example.com
checkAtStartup: true
`),
want: true,
},
{
name: "default true",
in: testutils.FormatYaml(`
kind: sources
name: http-source-default
type: http
baseUrl: https://example.com
`),
want: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
toolsFile, err := parseToolsFile(ctx, tc.in)
if err != nil {
t.Fatalf("failed to parse input: %v", err)
}
cfg := toolsFile.Sources["http-source"]
if cfg == nil {
cfg = toolsFile.Sources["http-source-default"]
}
if cfg == nil {
t.Fatalf("expected source to be parsed")
}
if sources.CheckAtStartup(cfg) != tc.want {
t.Fatalf("expected checkAtStartup=%v", tc.want)
}
})
}

invalidInput := testutils.FormatYaml(`
kind: sources
name: http-source-invalid
type: http
baseUrl: https://example.com
checkAtStartup: "nope"
`)
_, err = parseToolsFile(ctx, invalidInput)
if err == nil {
t.Fatalf("expected error for invalid checkAtStartup type")
}
if !strings.Contains(err.Error(), "checkAtStartup must be a boolean") {
t.Fatalf("unexpected error: %v", err)
}
}
1 change: 1 addition & 0 deletions cmd/internal/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) {
persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')")
persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
persistentFlags.BoolVar(&opts.Cfg.AllowPartialSources, "allow-partial-sources", false, "Skip startup connectivity checks for sources to allow partial availability.")
}

// ConfigFileFlags defines flags related to the configuration file.
Expand Down
13 changes: 7 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
return cmd
}

func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *server.Server) error {
func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *server.Server, allowPartialSources bool) error {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
panic(err)
}

sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile, allowPartialSources)
if err != nil {
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
logger.WarnContext(ctx, errMsg.Error())
Expand All @@ -151,7 +151,7 @@ func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *serv

// validateReloadEdits checks that the reloaded config configs can initialized without failing
func validateReloadEdits(
ctx context.Context, toolsFile internal.Config,
ctx context.Context, toolsFile internal.Config, allowPartialSources bool,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
) {
logger, err := util.LoggerFromContext(ctx)
Expand All @@ -177,6 +177,7 @@ func validateReloadEdits(
ToolConfigs: toolsFile.Tools,
ToolsetConfigs: toolsFile.Toolsets,
PromptConfigs: toolsFile.Prompts,
AllowPartialSources: allowPartialSources,
}

sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
Expand Down Expand Up @@ -232,7 +233,7 @@ func scanWatchedFiles(watchingFolder bool, folderToWatch string, watchedFiles ma
}

// watchChanges checks for changes in the provided yaml config(s) or folder.
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server, pollTickerSecond int) {
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server, pollTickerSecond int, allowPartialSources bool) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
panic(err)
Expand Down Expand Up @@ -378,7 +379,7 @@ func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles m
continue
}

err = handleDynamicReload(ctx, reloadedConfig, s)
err = handleDynamicReload(ctx, reloadedConfig, s, allowPartialSources)
if err != nil {
errMsg := fmt.Errorf("unable to parse reloaded config at %q: %w", reloadedConfig, err)
logger.WarnContext(ctx, errMsg.Error())
Expand Down Expand Up @@ -509,7 +510,7 @@ func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error {
if isCustomConfigured && !opts.Cfg.DisableReload {
watchDirs, watchedFiles := resolveWatcherInputs(opts.Config, opts.Configs, opts.ConfigFolder)
// start watching the file(s) or folder for changes to trigger dynamic reloading
go watchChanges(ctx, watchDirs, watchedFiles, s, opts.Cfg.PollInterval)
go watchChanges(ctx, watchDirs, watchedFiles, s, opts.Cfg.PollInterval, opts.Cfg.AllowPartialSources)
}

// wait for either the server to error out or the command's context to be canceled
Expand Down
5 changes: 5 additions & 0 deletions docs/en/documentation/configuration/sources/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ instance: my-instance-name
database: my_db
user: ${USER_NAME}
password: ${PASSWORD}
# Optional: skip startup connectivity checks for this source
checkAtStartup: false
```

In implementation, each source is a different connection pool or client that used
to connect to the database and execute the tool.

By default, toolbox validates source connectivity at startup. Set `checkAtStartup: false`
to allow a specific source to be offline without blocking startup.

## Available Sources

To see all supported sources and the specific tools they unlock, explore the full list of our [Integrations](../../../integrations/_index.md).
1 change: 1 addition & 0 deletions docs/en/reference/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ description: >
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
| | `--user-agent-metadata` | Appends additional metadata to the User-Agent. | |
| | `--allow-partial-sources` | Skip startup connectivity checks for sources to allow partial availability. | |
| | `--poll-interval` | Specifies the polling frequency (seconds) for configuration file updates. | `0` |
| `-v` | `--version` | version for toolbox | |

Expand Down
16 changes: 16 additions & 0 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type ServerConfig struct {
Port int
// SourceConfigs defines what sources of data are available for tools.
SourceConfigs SourceConfigs
// AllowPartialSources allows startup when some sources are unavailable by skipping connectivity checks.
AllowPartialSources bool
// AuthServiceConfigs defines what sources of authentication are available for tools.
AuthServiceConfigs AuthServiceConfigs
// EmbeddingModelConfigs defines a models used to embed parameters.
Expand Down Expand Up @@ -244,6 +246,17 @@ func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]an
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}

checkAtStartup := true
if raw, ok := r["checkAtStartup"]; ok {
val, ok := raw.(bool)
if !ok {
return nil, fmt.Errorf("checkAtStartup must be a boolean for source %q", name)
}
checkAtStartup = val
delete(r, "checkAtStartup")
}

dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %w", err)
Expand All @@ -252,6 +265,9 @@ func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]an
if err != nil {
return nil, err
}
if !checkAtStartup {
sourceConfig = sources.WrapSourceConfigWithStartupCheck(sourceConfig, checkAtStartup)
}
return sourceConfig, nil
}

Expand Down
11 changes: 11 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (

// initialize and validate the sources from configs
sourcesMap := make(map[string]sources.Source)
if cfg.AllowPartialSources {
l.WarnContext(ctx, "Skipping startup connectivity checks for all sources (allow-partial-sources enabled).")
}
for name, sc := range cfg.SourceConfigs {
checkAtStartup := sources.CheckAtStartup(sc)
if cfg.AllowPartialSources {
checkAtStartup = false
}
s, err := func() (sources.Source, error) {
childCtx, span := instrumentation.Tracer.Start(
ctx,
Expand All @@ -97,6 +104,10 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
trace.WithAttributes(attribute.String("source_name", name)),
)
defer span.End()
if !checkAtStartup && !cfg.AllowPartialSources {
l.InfoContext(childCtx, fmt.Sprintf("Skipping startup connectivity check for source %q", name))
}
childCtx = sources.WithStartupCheck(childCtx, checkAtStartup)
s, err := sc.Initialize(childCtx, instrumentation.Tracer)
if err != nil {
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
Expand Down
5 changes: 3 additions & 2 deletions internal/sources/alloydbpg/alloydb_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.Ping(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.Ping(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/cloudsqlmssql/cloud_sql_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
}

// Verify db connection
err = db.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return db.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/cloudsqlmysql/cloud_sql_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
4 changes: 3 additions & 1 deletion internal/sources/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,9 @@ func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tr
for attempt := 0; attempt <= maxRetries; attempt++ {
pool, err = pgxpool.New(ctx, connURL.String())
if err == nil {
err = pool.Ping(ctx)
err = sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.Ping(ctx)
})
}

if err == nil {
Expand Down
5 changes: 3 additions & 2 deletions internal/sources/firebird/firebird.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/mindsdb/mindsdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
}

// Verify the connection
err = client.Ping(ctx, nil)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return client.Ping(ctx, nil)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
}

// Verify db connection
err = db.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return db.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/oceanbase/oceanbase.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create Oracle connection: %w", err)
}

err = db.PingContext(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return db.PingContext(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect to Oracle successfully: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/sources/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("unable to create pool: %w", err)
}

err = pool.Ping(ctx)
if err != nil {
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
return pool.Ping(ctx)
}); err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}

Expand Down
Loading