@@ -287,13 +287,7 @@ func Migrate(ctx context.Context, conn *sqlite.Conn, schema Schema) error {
287
287
func migrateDB (ctx context.Context , conn * sqlite.Conn , schema Schema , onStart SignalFunc ) error {
288
288
defer conn .SetInterrupt (conn .SetInterrupt (ctx .Done ()))
289
289
290
- userVersionStmt , _ , err := conn .PrepareTransient ("PRAGMA user_version;" )
291
- if err != nil {
292
- return fmt .Errorf ("migrate database: %w" , err )
293
- }
294
- defer userVersionStmt .Finalize ()
295
-
296
- schemaVersion , err := ensureAppID (conn , schema .AppID , userVersionStmt )
290
+ schemaVersion , err := ensureAppID (conn , schema .AppID )
297
291
if err != nil {
298
292
return fmt .Errorf ("migrate database: %w" , err )
299
293
}
@@ -303,26 +297,13 @@ func migrateDB(ctx context.Context, conn *sqlite.Conn, schema Schema, onStart Si
303
297
var foreignKeysEnabled bool
304
298
err = sqlitex .ExecuteTransient (conn , "PRAGMA foreign_keys;" , & sqlitex.ExecOptions {
305
299
ResultFunc : func (stmt * sqlite.Stmt ) error {
306
- foreignKeysEnabled = stmt .ColumnInt (0 ) != 0
300
+ foreignKeysEnabled = stmt .ColumnBool (0 )
307
301
return nil
308
302
},
309
303
})
310
304
if err != nil {
311
305
return fmt .Errorf ("migrate database: %w" , err )
312
306
}
313
- var fkOnStmt , fkOffStmt * sqlite.Stmt
314
- if foreignKeysEnabled {
315
- fkOnStmt , _ , err = conn .PrepareTransient ("PRAGMA foreign_keys = on;" )
316
- if err != nil {
317
- return fmt .Errorf ("migrate database: %w" , err )
318
- }
319
- defer fkOnStmt .Finalize ()
320
- fkOffStmt , _ , err = conn .PrepareTransient ("PRAGMA foreign_keys = off;" )
321
- if err != nil {
322
- return fmt .Errorf ("migrate database: %w" , err )
323
- }
324
- defer fkOffStmt .Finalize ()
325
- }
326
307
327
308
beginStmt , _ , err := conn .PrepareTransient ("BEGIN IMMEDIATE;" )
328
309
if err != nil {
@@ -334,27 +315,24 @@ func migrateDB(ctx context.Context, conn *sqlite.Conn, schema Schema, onStart Si
334
315
return fmt .Errorf ("migrate database: %w" , err )
335
316
}
336
317
defer commitStmt .Finalize ()
337
- for ; schemaVersion < len (schema .Migrations ); schemaVersion ++ {
318
+ for ; int ( schemaVersion ) < len (schema .Migrations ); schemaVersion ++ {
338
319
migration := schema .Migrations [schemaVersion ]
339
320
disableFKs := foreignKeysEnabled &&
340
- schemaVersion < len (schema .MigrationOptions ) &&
321
+ int ( schemaVersion ) < len (schema .MigrationOptions ) &&
341
322
schema .MigrationOptions [schemaVersion ] != nil &&
342
323
schema .MigrationOptions [schemaVersion ].DisableForeignKeys
343
324
if disableFKs {
344
- if err := stepAndReset (fkOffStmt ); err != nil {
325
+ // Do not try to optimize by preparing this PRAGMA statement ahead of time.
326
+ if err := sqlitex .ExecuteTransient (conn , "PRAGMA foreign_keys = off;" , nil ); err != nil {
345
327
return fmt .Errorf ("migrate database: disable foreign keys: %w" , err )
346
328
}
347
329
}
348
330
349
331
if err := stepAndReset (beginStmt ); err != nil {
350
332
return fmt .Errorf ("migrate database: apply migrations[%d]: %w" , schemaVersion , err )
351
333
}
352
- if _ , err := userVersionStmt .Step (); err != nil {
353
- rollback (conn )
354
- return fmt .Errorf ("migrate database: %w" , err )
355
- }
356
- actualSchemaVersion := userVersionStmt .ColumnInt (0 )
357
- if err := userVersionStmt .Reset (); err != nil {
334
+ actualSchemaVersion , err := userVersion (conn )
335
+ if err != nil {
358
336
rollback (conn )
359
337
return fmt .Errorf ("migrate database: %w" , err )
360
338
}
@@ -365,12 +343,12 @@ func migrateDB(ctx context.Context, conn *sqlite.Conn, schema Schema, onStart Si
365
343
continue
366
344
}
367
345
368
- err : = sqlitex .ExecScript (conn , fmt .Sprintf ("%s;\n PRAGMA user_version = %d;\n " , migration , schemaVersion + 1 ))
346
+ err = sqlitex .ExecScript (conn , fmt .Sprintf ("%s;\n PRAGMA user_version = %d;\n " , migration , schemaVersion + 1 ))
369
347
if err != nil {
370
348
rollback (conn )
371
349
return fmt .Errorf ("migrate database: apply migrations[%d]: %w" , schemaVersion , err )
372
350
}
373
- if schemaVersion == len (schema .Migrations )- 1 && schema .RepeatableMigration != "" {
351
+ if int ( schemaVersion ) == len (schema .Migrations )- 1 && schema .RepeatableMigration != "" {
374
352
if err := sqlitex .ExecScript (conn , schema .RepeatableMigration ); err != nil {
375
353
rollback (conn )
376
354
return fmt .Errorf ("migrate database: apply repeatable migration: %w" , err )
@@ -382,22 +360,36 @@ func migrateDB(ctx context.Context, conn *sqlite.Conn, schema Schema, onStart Si
382
360
return fmt .Errorf ("migrate database: apply migrations[%d]: %w" , schemaVersion , err )
383
361
}
384
362
if disableFKs {
385
- if err := stepAndReset ( fkOnStmt ); err != nil {
363
+ if err := sqlitex . ExecuteTransient ( conn , "PRAGMA foreign_keys = on;" , nil ); err != nil {
386
364
return fmt .Errorf ("migrate database: reenable foreign keys: %w" , err )
387
365
}
388
366
}
389
367
}
390
368
return nil
391
369
}
392
370
371
+ func userVersion (conn * sqlite.Conn ) (int32 , error ) {
372
+ var version int32
373
+ err := sqlitex .ExecuteTransient (conn , "PRAGMA user_version;" , & sqlitex.ExecOptions {
374
+ ResultFunc : func (stmt * sqlite.Stmt ) error {
375
+ version = stmt .ColumnInt32 (0 )
376
+ return nil
377
+ },
378
+ })
379
+ if err != nil {
380
+ return 0 , fmt .Errorf ("get database user_version: %w" , err )
381
+ }
382
+ return version , nil
383
+ }
384
+
393
385
func rollback (conn * sqlite.Conn ) {
394
386
if conn .AutocommitEnabled () {
395
387
return
396
388
}
397
389
sqlitex .ExecuteTransient (conn , "ROLLBACK;" , nil )
398
390
}
399
391
400
- func ensureAppID (conn * sqlite.Conn , wantAppID int32 , userVersionStmt * sqlite. Stmt ) (schemaVersion int , err error ) {
392
+ func ensureAppID (conn * sqlite.Conn , wantAppID int32 ) (schemaVersion int32 , err error ) {
401
393
defer sqlitex .Save (conn )(& err )
402
394
403
395
var hasSchema bool
@@ -423,11 +415,8 @@ func ensureAppID(conn *sqlite.Conn, wantAppID int32, userVersionStmt *sqlite.Stm
423
415
if dbAppID != wantAppID && ! (dbAppID == 0 && ! hasSchema ) {
424
416
return 0 , fmt .Errorf ("database application_id = %#x (expected %#x)" , dbAppID , wantAppID )
425
417
}
426
- if _ , err := userVersionStmt .Step (); err != nil {
427
- return 0 , err
428
- }
429
- schemaVersion = userVersionStmt .ColumnInt (0 )
430
- if err := userVersionStmt .Reset (); err != nil {
418
+ schemaVersion , err = userVersion (conn )
419
+ if err != nil {
431
420
return 0 , err
432
421
}
433
422
// Using Sprintf because PRAGMAs don't permit arbitrary expressions, and thus
0 commit comments