Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions dashboard/app/aidb/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"

"cloud.google.com/go/spanner"
Expand Down Expand Up @@ -67,7 +68,6 @@ func UpdateWorkflows(ctx context.Context, active []dashapi.AIWorkflow) error {
if err != nil {
return err
}
defer client.Close()
_, err = client.Apply(ctx, mutations)
return err
}
Expand Down Expand Up @@ -109,7 +109,6 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) {
if err != nil {
return nil, err
}
defer client.Close()
var job *Job
_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
{
Expand Down Expand Up @@ -171,7 +170,6 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa
if err != nil {
return err
}
defer client.Close()
ent := TrajectorySpan{
JobID: jobID,
Seq: int64(span.Seq),
Expand Down Expand Up @@ -214,7 +212,6 @@ func selectAll[T any](ctx context.Context, stmt spanner.Statement) ([]*T, error)
if err != nil {
return nil, err
}
defer client.Close()
iter := client.Single().Query(ctx, stmt)
defer iter.Stop()
var items []*T
Expand All @@ -236,17 +233,37 @@ func selectOne[T any](ctx context.Context, stmt spanner.Statement) (*T, error) {
return all[0], nil
}

var clients sync.Map // map[string]*spanner.Client

func dbClient(ctx context.Context) (*spanner.Client, error) {
appID := appengine.AppID(ctx)
if v, ok := clients.Load(appID); ok {
return v.(*spanner.Client), nil
}
path := fmt.Sprintf("projects/%v/instances/%v/databases/%v",
appengine.AppID(ctx), Instance, Database)
// TODO(dvyukov): create a persistent client with a pool of connections for prod,
// but keep transient/per-test clients for tests.
return spanner.NewClientWithConfig(ctx, path, spanner.ClientConfig{
appID, Instance, Database)
// We use background context for the client, so that it survives the request.
client, err := spanner.NewClientWithConfig(context.Background(), path, spanner.ClientConfig{
SessionPoolConfig: spanner.SessionPoolConfig{
MinOpened: 1,
MaxOpened: 1,
MaxOpened: 20,
},
})
if err != nil {
return nil, err
}
if actual, loaded := clients.LoadOrStore(appID, client); loaded {
client.Close()
return actual.(*spanner.Client), nil
}
return client, nil
}

func CloseClient(ctx context.Context) {
appID := appengine.AppID(ctx)
if v, ok := clients.LoadAndDelete(appID); ok {
v.(*spanner.Client).Close()
}
}

var TimeNow = func(ctx context.Context) time.Time {
Expand Down
1 change: 1 addition & 0 deletions dashboard/app/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ func (ctx *Ctx) Close() {
}
}
}
aidb.CloseClient(ctx.ctx)
unregisterContext(ctx)
validateGlobalConfig()
}
Expand Down
Loading