88 "fmt"
99 "reflect"
1010 "strings"
11+ "sync"
1112 "time"
1213
1314 "cloud.google.com/go/spanner"
@@ -67,7 +68,6 @@ func UpdateWorkflows(ctx context.Context, active []dashapi.AIWorkflow) error {
6768 if err != nil {
6869 return err
6970 }
70- defer client .Close ()
7171 _ , err = client .Apply (ctx , mutations )
7272 return err
7373}
@@ -109,7 +109,6 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) {
109109 if err != nil {
110110 return nil , err
111111 }
112- defer client .Close ()
113112 var job * Job
114113 _ , err = client .ReadWriteTransaction (ctx , func (ctx context.Context , txn * spanner.ReadWriteTransaction ) error {
115114 {
@@ -171,7 +170,6 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa
171170 if err != nil {
172171 return err
173172 }
174- defer client .Close ()
175173 ent := TrajectorySpan {
176174 JobID : jobID ,
177175 Seq : int64 (span .Seq ),
@@ -214,7 +212,6 @@ func selectAll[T any](ctx context.Context, stmt spanner.Statement) ([]*T, error)
214212 if err != nil {
215213 return nil , err
216214 }
217- defer client .Close ()
218215 iter := client .Single ().Query (ctx , stmt )
219216 defer iter .Stop ()
220217 var items []* T
@@ -236,17 +233,37 @@ func selectOne[T any](ctx context.Context, stmt spanner.Statement) (*T, error) {
236233 return all [0 ], nil
237234}
238235
236+ var clients sync.Map // map[string]*spanner.Client
237+
239238func dbClient (ctx context.Context ) (* spanner.Client , error ) {
239+ appID := appengine .AppID (ctx )
240+ if v , ok := clients .Load (appID ); ok {
241+ return v .(* spanner.Client ), nil
242+ }
240243 path := fmt .Sprintf ("projects/%v/instances/%v/databases/%v" ,
241- appengine .AppID (ctx ), Instance , Database )
242- // TODO(dvyukov): create a persistent client with a pool of connections for prod,
243- // but keep transient/per-test clients for tests.
244- return spanner .NewClientWithConfig (ctx , path , spanner.ClientConfig {
244+ appID , Instance , Database )
245+ // We use background context for the client, so that it survives the request.
246+ client , err := spanner .NewClientWithConfig (context .Background (), path , spanner.ClientConfig {
245247 SessionPoolConfig : spanner.SessionPoolConfig {
246248 MinOpened : 1 ,
247- MaxOpened : 1 ,
249+ MaxOpened : 20 ,
248250 },
249251 })
252+ if err != nil {
253+ return nil , err
254+ }
255+ if actual , loaded := clients .LoadOrStore (appID , client ); loaded {
256+ client .Close ()
257+ return actual .(* spanner.Client ), nil
258+ }
259+ return client , nil
260+ }
261+
262+ func CloseClient (ctx context.Context ) {
263+ appID := appengine .AppID (ctx )
264+ if v , ok := clients .LoadAndDelete (appID ); ok {
265+ v .(* spanner.Client ).Close ()
266+ }
250267}
251268
252269var TimeNow = func (ctx context.Context ) time.Time {
0 commit comments