Skip to content

Commit f40163f

Browse files
committed
Add support for overriding migration directory and improve code formatting
1 parent 82705e6 commit f40163f

File tree

2 files changed

+112
-14
lines changed

2 files changed

+112
-14
lines changed

mig.go

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ type Config struct {
1818
Db *sql.DB
1919

2020
// Fs is the filesystem where the migrations are stored
21-
Fs fs.FS
21+
Fs fs.FS
22+
OverrideDirName string
2223

2324
// If Fs is nil, then this slice of migrations will be used
2425
Migrations []Migration
@@ -139,12 +140,19 @@ func (mig *Mig) runUp() error {
139140
return err
140141
}
141142

142-
_, err = mig.config.Db.Exec(`
143+
_, err = mig.config.Db.Exec(
144+
`
143145
INSERT INTO
144146
migrations (id, filename, raw, hash, up, down)
145147
VALUES
146148
($1, $2, $3, $4, $5, $6)`,
147-
m.Id, m.FileName, m.raw, m.hash, m.Up, m.Down)
149+
m.Id,
150+
m.FileName,
151+
m.raw,
152+
m.hash,
153+
m.Up,
154+
m.Down,
155+
)
148156
if err != nil {
149157
return err
150158
}
@@ -166,7 +174,11 @@ func (mig *Mig) runDown() error {
166174
continue
167175
}
168176
if dbMig.Id != mig.config.Migrations[i].Id {
169-
return fmt.Errorf("mismatched migration id: dbMig.Id=%d, mig.config.Migrations[i].Id=%d", dbMig.Id, mig.config.Migrations[i].Id)
177+
return fmt.Errorf(
178+
"mismatched migration id: dbMig.Id=%d, mig.config.Migrations[i].Id=%d",
179+
dbMig.Id,
180+
mig.config.Migrations[i].Id,
181+
)
170182
}
171183
if dbMig.hash != mig.config.Migrations[i].hash {
172184
return mig.runDownTo(dbMig.Id)
@@ -200,7 +212,10 @@ func (mig *Mig) runDownTo(endId int) error {
200212
}
201213

202214
// remove migration from migrations table
203-
_, err = mig.config.Db.Exec("DELETE FROM migrations WHERE id = $1", dbMigrations[i].Id)
215+
_, err = mig.config.Db.Exec(
216+
"DELETE FROM migrations WHERE id = $1",
217+
dbMigrations[i].Id,
218+
)
204219
if err != nil {
205220
return fmt.Errorf("error deleting migration from migrations table: %w", err)
206221
}
@@ -210,33 +225,67 @@ func (mig *Mig) runDownTo(endId int) error {
210225
}
211226

212227
func (mig *Mig) getMigrationsFromFS() ([]Migration, error) {
213-
result := []Migration{}
214-
215-
entries, err := fs.ReadDir(mig.config.Fs, ".")
216-
if err != nil {
217-
return nil, err
228+
var (
229+
result []Migration
230+
entries []fs.DirEntry
231+
err error
232+
)
233+
234+
if mig.config.OverrideDirName == "" {
235+
entries, err = fs.ReadDir(
236+
mig.config.Fs,
237+
".",
238+
)
239+
if err != nil {
240+
return nil, err
241+
}
242+
} else {
243+
entries, err = fs.ReadDir(
244+
mig.config.Fs,
245+
mig.config.OverrideDirName,
246+
)
247+
if err != nil {
248+
return nil, err
249+
}
218250
}
219251

220252
for _, entry := range entries {
221253
if entry.IsDir() {
222254
continue
223255
}
224256

225-
m := Migration{}
257+
var (
258+
m = Migration{}
259+
contents []byte
260+
err error
261+
)
262+
226263
m.FileName = entry.Name()
227264
m.Id, err = getIntFromFileName(m.FileName)
228265
if err != nil {
229266
return nil, err
230267
}
231268

232-
contents, err := fs.ReadFile(mig.config.Fs, entry.Name())
269+
if mig.config.OverrideDirName != "" {
270+
contents, err = fs.ReadFile(
271+
mig.config.Fs,
272+
fmt.Sprintf("%s/%s", mig.config.OverrideDirName, m.FileName),
273+
)
274+
} else {
275+
contents, err = fs.ReadFile(mig.config.Fs, m.FileName)
276+
}
277+
233278
if err != nil {
234279
return nil, err
235280
}
236281
m.raw = string(contents)
237282
m.hash = hashRaw(m.raw)
238283

239-
m.Up, m.Down, err = splitRaw(m.raw, mig.config.UpDelimiter, mig.config.DownDelimiter)
284+
m.Up, m.Down, err = splitRaw(
285+
m.raw,
286+
mig.config.UpDelimiter,
287+
mig.config.DownDelimiter,
288+
)
240289
if err != nil {
241290
return nil, err
242291
}
@@ -260,7 +309,14 @@ func (mig *Mig) getMigrationsFromDB() ([]Migration, error) {
260309
result := []Migration{}
261310
for rows.Next() {
262311
m := Migration{}
263-
err = rows.Scan(&m.Id, &m.FileName, &m.raw, &m.hash, &m.Up, &m.Down)
312+
err = rows.Scan(
313+
&m.Id,
314+
&m.FileName,
315+
&m.raw,
316+
&m.hash,
317+
&m.Up,
318+
&m.Down,
319+
)
264320
if err != nil {
265321
return nil, err
266322
}

sqlite_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ package mig
22

33
import (
44
"database/sql"
5+
"embed"
56
"os"
67
"testing"
78

89
_ "github.com/mattn/go-sqlite3"
910
"github.com/stretchr/testify/assert"
1011
)
1112

13+
//go:embed test/migrations1
14+
var migrationsFS embed.FS
15+
1216
func TestMigrate(t *testing.T) {
1317
t.Run("fresh migrate runs successfully", func(t *testing.T) {
1418
testDbPath := "./test/test2.db"
@@ -291,6 +295,44 @@ func TestMigrate(t *testing.T) {
291295

292296
os.Remove(testDbPath)
293297
})
298+
299+
t.Run("migrating from embedded FS works", func(t *testing.T) {
300+
testDbPath := "./test/test8.db"
301+
db, err := sql.Open("sqlite3", testDbPath)
302+
assert.Nil(t, err)
303+
defer db.Close()
304+
305+
m, err := New(Config{
306+
Db: db,
307+
308+
Fs: migrationsFS,
309+
OverrideDirName: "test/migrations1",
310+
})
311+
assert.Nil(t, err)
312+
313+
err = m.Migrate()
314+
assert.Nil(t, err)
315+
316+
tableMustExistSqlite(t, db, "migrations")
317+
tableMustExistSqlite(t, db, "test_table_1")
318+
tableMustExistSqlite(t, db, "test_table_2")
319+
tableMustExistSqlite(t, db, "test_table_3")
320+
tableMustNotExistSqlite(t, db, "test_table_4")
321+
322+
m.config.Migrations[2].Up = "CREATE TABLE test_table_5 (id INTEGER PRIMARY KEY, name TEXT);"
323+
m.config.Migrations[2].Down = "DROP TABLE test_table_5;"
324+
err = m.Migrate()
325+
assert.Nil(t, err)
326+
327+
tableMustExistSqlite(t, db, "migrations")
328+
tableMustExistSqlite(t, db, "test_table_1")
329+
tableMustExistSqlite(t, db, "test_table_2")
330+
tableMustNotExistSqlite(t, db, "test_table_3")
331+
tableMustNotExistSqlite(t, db, "test_table_4")
332+
tableMustExistSqlite(t, db, "test_table_5")
333+
334+
os.Remove(testDbPath)
335+
})
294336
}
295337

296338
func tableMustExistSqlite(t *testing.T, db *sql.DB, tableName string) {

0 commit comments

Comments
 (0)