diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go index 4b73a5c0ae20..424f376ee1fd 100644 --- a/dashboard/app/aidb/crud.go +++ b/dashboard/app/aidb/crud.go @@ -8,6 +8,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" "cloud.google.com/go/spanner" @@ -67,7 +68,6 @@ func UpdateWorkflows(ctx context.Context, active []dashapi.AIWorkflow) error { if err != nil { return err } - defer client.Close() _, err = client.Apply(ctx, mutations) return err } @@ -109,7 +109,6 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) { if err != nil { return nil, err } - defer client.Close() var job *Job _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { { @@ -171,7 +170,6 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa if err != nil { return err } - defer client.Close() ent := TrajectorySpan{ JobID: jobID, Seq: int64(span.Seq), @@ -214,7 +212,6 @@ func selectAll[T any](ctx context.Context, stmt spanner.Statement) ([]*T, error) if err != nil { return nil, err } - defer client.Close() iter := client.Single().Query(ctx, stmt) defer iter.Stop() var items []*T @@ -236,17 +233,37 @@ func selectOne[T any](ctx context.Context, stmt spanner.Statement) (*T, error) { return all[0], nil } +var clients sync.Map // map[string]*spanner.Client + func dbClient(ctx context.Context) (*spanner.Client, error) { + appID := appengine.AppID(ctx) + if v, ok := clients.Load(appID); ok { + return v.(*spanner.Client), nil + } path := fmt.Sprintf("projects/%v/instances/%v/databases/%v", - appengine.AppID(ctx), Instance, Database) - // TODO(dvyukov): create a persistent client with a pool of connections for prod, - // but keep transient/per-test clients for tests. - return spanner.NewClientWithConfig(ctx, path, spanner.ClientConfig{ + appID, Instance, Database) + // We use background context for the client, so that it survives the request. + client, err := spanner.NewClientWithConfig(context.Background(), path, spanner.ClientConfig{ SessionPoolConfig: spanner.SessionPoolConfig{ MinOpened: 1, - MaxOpened: 1, + MaxOpened: 20, }, }) + if err != nil { + return nil, err + } + if actual, loaded := clients.LoadOrStore(appID, client); loaded { + client.Close() + return actual.(*spanner.Client), nil + } + return client, nil +} + +func CloseClient(ctx context.Context) { + appID := appengine.AppID(ctx) + if v, ok := clients.LoadAndDelete(appID); ok { + v.(*spanner.Client).Close() + } } var TimeNow = func(ctx context.Context) time.Time { diff --git a/dashboard/app/util_test.go b/dashboard/app/util_test.go index ef6abf5f9949..d2862bd220d3 100644 --- a/dashboard/app/util_test.go +++ b/dashboard/app/util_test.go @@ -301,6 +301,7 @@ func (ctx *Ctx) Close() { } } } + aidb.CloseClient(ctx.ctx) unregisterContext(ctx) validateGlobalConfig() }