@@ -55,20 +55,15 @@ func New(ctx context.Context, beadsDir, database, branch string) (*EmbeddedDoltS
5555 return s , nil
5656}
5757
58- // withConn opens a short-lived database connection, begins an explicit SQL
59- // transaction, and passes it to fn. If commit is true and fn returns nil, the
60- // transaction is committed; otherwise it is rolled back. The connection is
61- // closed before withConn returns regardless of outcome.
62- func (s * EmbeddedDoltStore ) withConn (ctx context.Context , commit bool , fn func (tx * sql.Tx ) error ) (err error ) {
58+ // withRootConn opens a short-lived database connection without selecting any
59+ // database or branch, begins an explicit SQL transaction, and passes it to fn.
60+ // This is used during initialization when the database may not yet exist.
61+ func (s * EmbeddedDoltStore ) withRootConn (ctx context.Context , commit bool , fn func (tx * sql.Tx ) error ) (err error ) {
6362 if s .closed .Load () {
6463 err = errClosed
6564 return
6665 }
6766
68- if s .database != "" && ! validIdentifier .MatchString (s .database ) {
69- return fmt .Errorf ("embeddeddolt: invalid database name: %q" , s .database )
70- }
71-
7267 var db * sql.DB
7368 var cleanup func () error
7469 db , cleanup , err = OpenSQL (ctx , s .dataDir , "" , "" )
@@ -80,20 +75,51 @@ func (s *EmbeddedDoltStore) withConn(ctx context.Context, commit bool, fn func(t
8075 err = errors .Join (err , cleanup ())
8176 }()
8277
83- if s .database != "" {
84- if _ , err = db .ExecContext (ctx , "CREATE DATABASE IF NOT EXISTS `" + s .database + "`" ); err != nil {
85- return fmt .Errorf ("embeddeddolt: creating database: %w" , err )
86- }
87- if _ , err = db .ExecContext (ctx , "USE `" + s .database + "`" ); err != nil {
88- return fmt .Errorf ("embeddeddolt: switching to database: %w" , err )
89- }
90- if s .branch != "" {
91- if _ , err = db .ExecContext (ctx , fmt .Sprintf ("SET @@%s_head_ref = %s" , s .database , sqlStringLiteral (s .branch ))); err != nil {
92- return fmt .Errorf ("embeddeddolt: setting branch: %w" , err )
93- }
94- }
78+ var tx * sql.Tx
79+ tx , err = db .BeginTx (ctx , nil )
80+ if err != nil {
81+ err = fmt .Errorf ("embeddeddolt: begin tx: %w" , err )
82+ return
83+ }
84+
85+ err = fn (tx )
86+ if err != nil {
87+ err = errors .Join (err , tx .Rollback ())
88+ return
89+ }
90+
91+ if ! commit {
92+ return tx .Rollback ()
93+ }
94+
95+ err = tx .Commit ()
96+ return
97+ }
98+
99+ // withConn opens a short-lived database connection configured for the store's
100+ // database and branch, begins an explicit SQL transaction, and passes it to
101+ // fn. If commit is true and fn returns nil, the transaction is committed;
102+ // otherwise it is rolled back. The connection is closed before withConn
103+ // returns regardless of outcome.
104+ //
105+ // The database must already exist (created during initSchema).
106+ func (s * EmbeddedDoltStore ) withConn (ctx context.Context , commit bool , fn func (tx * sql.Tx ) error ) (err error ) {
107+ if s .closed .Load () {
108+ err = errClosed
109+ return
110+ }
111+
112+ var db * sql.DB
113+ var cleanup func () error
114+ db , cleanup , err = OpenSQL (ctx , s .dataDir , s .database , s .branch )
115+ if err != nil {
116+ return
95117 }
96118
119+ defer func () {
120+ err = errors .Join (err , cleanup ())
121+ }()
122+
97123 var tx * sql.Tx
98124 tx , err = db .BeginTx (ctx , nil )
99125 if err != nil {
@@ -115,9 +141,29 @@ func (s *EmbeddedDoltStore) withConn(ctx context.Context, commit bool, fn func(t
115141 return
116142}
117143
118- // initSchema runs all pending migrations and commits them to Dolt history.
144+ // initSchema creates the database (if needed) and runs all pending migrations,
145+ // committing them to Dolt history. Uses withRootConn so the database can be
146+ // created before USE; this avoids running CREATE DATABASE inside withConn,
147+ // which is not safe for concurrent use in the embedded Dolt engine.
119148func (s * EmbeddedDoltStore ) initSchema (ctx context.Context ) error {
120- return s .withConn (ctx , true , func (tx * sql.Tx ) error {
149+ return s .withRootConn (ctx , true , func (tx * sql.Tx ) error {
150+ if s .database != "" {
151+ if ! validIdentifier .MatchString (s .database ) {
152+ return fmt .Errorf ("embeddeddolt: invalid database name: %q" , s .database )
153+ }
154+ if _ , err := tx .ExecContext (ctx , "CREATE DATABASE IF NOT EXISTS `" + s .database + "`" ); err != nil {
155+ return fmt .Errorf ("embeddeddolt: creating database: %w" , err )
156+ }
157+ if _ , err := tx .ExecContext (ctx , "USE `" + s .database + "`" ); err != nil {
158+ return fmt .Errorf ("embeddeddolt: switching to database: %w" , err )
159+ }
160+ if s .branch != "" {
161+ if _ , err := tx .ExecContext (ctx , fmt .Sprintf ("SET @@%s_head_ref = %s" , s .database , sqlStringLiteral (s .branch ))); err != nil {
162+ return fmt .Errorf ("embeddeddolt: setting branch: %w" , err )
163+ }
164+ }
165+ }
166+
121167 applied , err := migrateUp (ctx , tx )
122168 if err != nil {
123169 return err
0 commit comments