Skip to content

Commit 6541422

Browse files
committed
feat(server): allow partial source startup checks
1 parent 80ba0a3 commit 6541422

File tree

30 files changed

+216
-58
lines changed

30 files changed

+216
-58
lines changed

cmd/internal/persistent_flags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) {
4343
)
4444
persistentFlags.StringSliceVar(&opts.PrebuiltConfigs, "prebuilt", []string{}, prebuiltHelp)
4545
persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
46+
persistentFlags.BoolVar(&opts.Cfg.AllowPartialSources, "allow-partial-sources", false, "Skip startup connectivity checks for sources to allow partial availability.")
4647
}

cmd/internal/tools_file_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/googleapis/genai-toolbox/internal/prompts"
2727
"github.com/googleapis/genai-toolbox/internal/prompts/custom"
2828
"github.com/googleapis/genai-toolbox/internal/server"
29+
"github.com/googleapis/genai-toolbox/internal/sources"
2930
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
3031
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
3132
"github.com/googleapis/genai-toolbox/internal/testutils"
@@ -2164,3 +2165,44 @@ tools:
21642165
})
21652166
}
21662167
}
2168+
2169+
func TestParseToolFileCheckAtStartup(t *testing.T) {
2170+
ctx, err := testutils.ContextWithNewLogger()
2171+
if err != nil {
2172+
t.Fatalf("unexpected error: %s", err)
2173+
}
2174+
input := testutils.FormatYaml(`
2175+
kind: sources
2176+
name: http-source
2177+
type: http
2178+
baseUrl: https://example.com
2179+
checkAtStartup: false
2180+
`)
2181+
2182+
toolsFile, err := parseToolsFile(ctx, input)
2183+
if err != nil {
2184+
t.Fatalf("failed to parse input: %v", err)
2185+
}
2186+
cfg, ok := toolsFile.Sources["http-source"]
2187+
if !ok {
2188+
t.Fatalf("expected source to be parsed")
2189+
}
2190+
if sources.CheckAtStartup(cfg) {
2191+
t.Fatalf("expected checkAtStartup to be false")
2192+
}
2193+
2194+
defaultInput := testutils.FormatYaml(`
2195+
kind: sources
2196+
name: http-source-default
2197+
type: http
2198+
baseUrl: https://example.com
2199+
`)
2200+
toolsFile, err = parseToolsFile(ctx, defaultInput)
2201+
if err != nil {
2202+
t.Fatalf("failed to parse input: %v", err)
2203+
}
2204+
cfg = toolsFile.Sources["http-source-default"]
2205+
if !sources.CheckAtStartup(cfg) {
2206+
t.Fatalf("expected checkAtStartup default to be true")
2207+
}
2208+
}

cmd/root.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
138138
return cmd
139139
}
140140

141-
func handleDynamicReload(ctx context.Context, toolsFile internal.ToolsFile, s *server.Server) error {
141+
func handleDynamicReload(ctx context.Context, toolsFile internal.ToolsFile, s *server.Server, allowPartialSources bool) error {
142142
logger, err := util.LoggerFromContext(ctx)
143143
if err != nil {
144144
panic(err)
145145
}
146146

147-
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
147+
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile, allowPartialSources)
148148
if err != nil {
149149
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
150150
logger.WarnContext(ctx, errMsg.Error())
@@ -158,7 +158,7 @@ func handleDynamicReload(ctx context.Context, toolsFile internal.ToolsFile, s *s
158158

159159
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
160160
func validateReloadEdits(
161-
ctx context.Context, toolsFile internal.ToolsFile,
161+
ctx context.Context, toolsFile internal.ToolsFile, allowPartialSources bool,
162162
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
163163
) {
164164
logger, err := util.LoggerFromContext(ctx)
@@ -184,6 +184,7 @@ func validateReloadEdits(
184184
ToolConfigs: toolsFile.Tools,
185185
ToolsetConfigs: toolsFile.Toolsets,
186186
PromptConfigs: toolsFile.Prompts,
187+
AllowPartialSources: allowPartialSources,
187188
}
188189

189190
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
@@ -239,7 +240,7 @@ func scanWatchedFiles(watchingFolder bool, folderToWatch string, watchedFiles ma
239240
}
240241

241242
// watchChanges checks for changes in the provided yaml tools file(s) or folder.
242-
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server, pollTickerSecond int) {
243+
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server, pollTickerSecond int, allowPartialSources bool) {
243244
logger, err := util.LoggerFromContext(ctx)
244245
if err != nil {
245246
panic(err)
@@ -384,7 +385,7 @@ func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles m
384385
}
385386
}
386387

387-
err = handleDynamicReload(ctx, reloadedToolsFile, s)
388+
err = handleDynamicReload(ctx, reloadedToolsFile, s, allowPartialSources)
388389
if err != nil {
389390
errMsg := fmt.Errorf("unable to parse reloaded tools file at %q: %w", reloadedToolsFile, err)
390391
logger.WarnContext(ctx, errMsg.Error())
@@ -500,7 +501,7 @@ func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error {
500501
if isCustomConfigured && !opts.Cfg.DisableReload {
501502
watchDirs, watchedFiles := resolveWatcherInputs(opts.ToolsFile, opts.ToolsFiles, opts.ToolsFolder)
502503
// start watching the file(s) or folder for changes to trigger dynamic reloading
503-
go watchChanges(ctx, watchDirs, watchedFiles, s, opts.Cfg.PollInterval)
504+
go watchChanges(ctx, watchDirs, watchedFiles, s, opts.Cfg.PollInterval, opts.Cfg.AllowPartialSources)
504505
}
505506

506507
// wait for either the server to error out or the command's context to be canceled

docs/en/reference/cli.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ description: >
2828
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
2929
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
3030
| | `--user-agent-metadata` | Appends additional metadata to the User-Agent. | |
31+
| | `--allow-partial-sources` | Skip startup connectivity checks for sources to allow partial availability. | |
3132
| | `--poll-interval` | Specifies the polling frequency (seconds) for configuration file updates. | `0` |
3233
| `-v` | `--version` | version for toolbox | |
3334

docs/en/resources/sources/_index.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@ instance: my-instance-name
2626
database: my_db
2727
user: ${USER_NAME}
2828
password: ${PASSWORD}
29+
# Optional: skip startup connectivity checks for this source
30+
checkAtStartup: false
2931
```
3032
3133
In implementation, each source is a different connection pool or client that used
3234
to connect to the database and execute the tool.
3335
36+
By default, toolbox validates source connectivity at startup. Set `checkAtStartup: false`
37+
to allow a specific source to be offline without blocking startup.
38+
3439
## Available Sources

internal/server/config.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ type ServerConfig struct {
4141
Port int
4242
// SourceConfigs defines what sources of data are available for tools.
4343
SourceConfigs SourceConfigs
44+
// AllowPartialSources allows startup when some sources are unavailable by skipping connectivity checks.
45+
AllowPartialSources bool
4446
// AuthServiceConfigs defines what sources of authentication are available for tools.
4547
AuthServiceConfigs AuthServiceConfigs
4648
// EmbeddingModelConfigs defines a models used to embed parameters.
@@ -237,6 +239,17 @@ func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]an
237239
if !ok {
238240
return nil, fmt.Errorf("missing 'type' field or it is not a string")
239241
}
242+
243+
checkAtStartup := true
244+
if raw, ok := r["checkAtStartup"]; ok {
245+
val, ok := raw.(bool)
246+
if !ok {
247+
return nil, fmt.Errorf("checkAtStartup must be a boolean for source %q", name)
248+
}
249+
checkAtStartup = val
250+
delete(r, "checkAtStartup")
251+
}
252+
240253
dec, err := util.NewStrictDecoder(r)
241254
if err != nil {
242255
return nil, fmt.Errorf("error creating decoder: %w", err)
@@ -245,6 +258,9 @@ func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]an
245258
if err != nil {
246259
return nil, err
247260
}
261+
if !checkAtStartup {
262+
sourceConfig = sources.WrapSourceConfigWithStartupCheck(sourceConfig, checkAtStartup)
263+
}
248264
return sourceConfig, nil
249265
}
250266

internal/server/server.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,14 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
8181

8282
// initialize and validate the sources from configs
8383
sourcesMap := make(map[string]sources.Source)
84+
if cfg.AllowPartialSources {
85+
l.WarnContext(ctx, "Skipping startup connectivity checks for all sources (allow-partial-sources enabled).")
86+
}
8487
for name, sc := range cfg.SourceConfigs {
88+
checkAtStartup := sources.CheckAtStartup(sc)
89+
if cfg.AllowPartialSources {
90+
checkAtStartup = false
91+
}
8592
s, err := func() (sources.Source, error) {
8693
childCtx, span := instrumentation.Tracer.Start(
8794
ctx,
@@ -90,6 +97,10 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
9097
trace.WithAttributes(attribute.String("source_name", name)),
9198
)
9299
defer span.End()
100+
if !checkAtStartup && !cfg.AllowPartialSources {
101+
l.InfoContext(childCtx, fmt.Sprintf("Skipping startup connectivity check for source %q", name))
102+
}
103+
childCtx = sources.WithStartupCheck(childCtx, checkAtStartup)
93104
s, err := sc.Initialize(childCtx, instrumentation.Tracer)
94105
if err != nil {
95106
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)

internal/sources/alloydbpg/alloydb_pg.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
7171
return nil, fmt.Errorf("unable to create pool: %w", err)
7272
}
7373

74-
err = pool.Ping(ctx)
75-
if err != nil {
74+
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
75+
return pool.Ping(ctx)
76+
}); err != nil {
7677
return nil, fmt.Errorf("unable to connect successfully: %w", err)
7778
}
7879

internal/sources/clickhouse/clickhouse.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
6969
return nil, fmt.Errorf("unable to create pool: %w", err)
7070
}
7171

72-
err = pool.PingContext(ctx)
73-
if err != nil {
72+
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
73+
return pool.PingContext(ctx)
74+
}); err != nil {
7475
return nil, fmt.Errorf("unable to connect successfully: %w", err)
7576
}
7677

internal/sources/cloudsqlmssql/cloud_sql_mssql.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
7575
}
7676

7777
// Verify db connection
78-
err = db.PingContext(ctx)
79-
if err != nil {
78+
if err := sources.CheckConnectivity(ctx, func(ctx context.Context) error {
79+
return db.PingContext(ctx)
80+
}); err != nil {
8081
return nil, fmt.Errorf("unable to connect successfully: %w", err)
8182
}
8283

0 commit comments

Comments
 (0)