Skip to content

Commit 11ba682

Browse files
committed
support recursive collection of source files for provider using fs.FS
1 parent 76946cc commit 11ba682

File tree

4 files changed

+141
-62
lines changed

4 files changed

+141
-62
lines changed

provider.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func newProvider(
107107
// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
108108
// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
109109
// we should make it optional.
110-
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
110+
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions, cfg.recursive)
111111
if err != nil {
112112
return nil, err
113113
}

provider_collect.go

+97-47
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,54 @@ type fileSources struct {
1515
goSources []Source
1616
}
1717

18+
func checkFile(fullpath string, strict bool, excludePaths map[string]bool, excludeVersions map[int64]bool, versionToBaseLookup map[int64]string) (Source, bool, error) {
19+
base := filepath.Base(fullpath)
20+
if strings.HasSuffix(base, "_test.go") {
21+
return Source{}, false, nil
22+
}
23+
if excludePaths[base] {
24+
// TODO(mf): log this?
25+
return Source{}, false, nil
26+
}
27+
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
28+
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
29+
// filenames, but still have versioned migrations within the same directory. For
30+
// example, a user could have a helpers.go file which contains unexported helper
31+
// functions for migrations.
32+
version, err := NumericComponent(base)
33+
if err != nil {
34+
if strict {
35+
return Source{}, false, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
36+
}
37+
return Source{}, false, nil
38+
}
39+
if excludeVersions[version] {
40+
// TODO: log this?
41+
return Source{}, false, nil
42+
}
43+
// Ensure there are no duplicate versions.
44+
if existing, ok := versionToBaseLookup[version]; ok {
45+
return Source{}, false, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
46+
version,
47+
existing,
48+
base,
49+
)
50+
}
51+
source := Source{Path: fullpath, Version: version}
52+
switch filepath.Ext(base) {
53+
case ".sql":
54+
source.Type = TypeSQL
55+
case ".go":
56+
source.Type = TypeGo
57+
default:
58+
// Should never happen since we already filtered out all other file types.
59+
return Source{}, false, fmt.Errorf("invalid file extension: %q", base)
60+
}
61+
// Add the version to the lookup map.
62+
versionToBaseLookup[version] = base
63+
return source, true, nil
64+
}
65+
1866
// collectFilesystemSources scans the file system for migration files that have a numeric prefix
1967
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
2068
// be nil, in which case an empty fileSources is returned.
@@ -29,6 +77,7 @@ func collectFilesystemSources(
2977
strict bool,
3078
excludePaths map[string]bool,
3179
excludeVersions map[int64]bool,
80+
recursive bool,
3281
) (*fileSources, error) {
3382
if fsys == nil {
3483
return new(fileSources), nil
@@ -39,65 +88,66 @@ func collectFilesystemSources(
3988
"*.sql",
4089
"*.go",
4190
} {
42-
files, err := fs.Glob(fsys, pattern)
91+
files, err := func() ([]string, error) {
92+
if recursive {
93+
var files []string
94+
err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
95+
if err != nil {
96+
return err
97+
}
98+
if d.IsDir() {
99+
subFs, err := fs.Sub(fsys, path)
100+
if err != nil {
101+
return err
102+
}
103+
dirFiles, err := fs.Glob(subFs, pattern)
104+
for _, file := range dirFiles {
105+
files = append(files, filepath.Join(path, file))
106+
}
107+
}
108+
return nil
109+
})
110+
if err != nil {
111+
return nil, err
112+
}
113+
return files, nil
114+
} else {
115+
files, err := fs.Glob(fsys, pattern)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
118+
}
119+
return files, nil
120+
}
121+
}()
43122
if err != nil {
44-
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
123+
return nil, err
45124
}
46125
for _, fullpath := range files {
47-
base := filepath.Base(fullpath)
48-
if strings.HasSuffix(base, "_test.go") {
49-
continue
50-
}
51-
if excludePaths[base] {
52-
// TODO(mf): log this?
53-
continue
54-
}
55-
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
56-
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
57-
// filenames, but still have versioned migrations within the same directory. For
58-
// example, a user could have a helpers.go file which contains unexported helper
59-
// functions for migrations.
60-
version, err := NumericComponent(base)
126+
source, isValid, err := checkFile(
127+
fullpath,
128+
strict,
129+
excludePaths,
130+
excludeVersions,
131+
versionToBaseLookup,
132+
)
61133
if err != nil {
62-
if strict {
63-
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
64-
}
65-
continue
134+
return nil, err
66135
}
67-
if excludeVersions[version] {
68-
// TODO: log this?
136+
if !isValid {
69137
continue
70138
}
71-
// Ensure there are no duplicate versions.
72-
if existing, ok := versionToBaseLookup[version]; ok {
73-
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
74-
version,
75-
existing,
76-
base,
77-
)
78-
}
79-
switch filepath.Ext(base) {
80-
case ".sql":
81-
sources.sqlSources = append(sources.sqlSources, Source{
82-
Type: TypeSQL,
83-
Path: fullpath,
84-
Version: version,
85-
})
86-
case ".go":
87-
sources.goSources = append(sources.goSources, Source{
88-
Type: TypeGo,
89-
Path: fullpath,
90-
Version: version,
91-
})
139+
switch source.Type {
140+
case TypeSQL:
141+
sources.sqlSources = append(sources.sqlSources, source)
142+
case TypeGo:
143+
sources.goSources = append(sources.goSources, source)
92144
default:
93-
// Should never happen since we already filtered out all other file types.
94-
return nil, fmt.Errorf("invalid file extension: %q", base)
145+
return nil, errors.New("unreachable")
95146
}
96-
// Add the version to the lookup map.
97-
versionToBaseLookup[version] = base
98147
}
99148
}
100149
return sources, nil
150+
101151
}
102152

103153
func newSQLMigration(source Source) *Migration {

provider_collect_test.go

+35-14
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ import (
1212
func TestCollectFileSources(t *testing.T) {
1313
t.Parallel()
1414
t.Run("nil_fsys", func(t *testing.T) {
15-
sources, err := collectFilesystemSources(nil, false, nil, nil)
15+
sources, err := collectFilesystemSources(nil, false, nil, nil, false)
1616
check.NoError(t, err)
1717
check.Bool(t, sources != nil, true)
1818
check.Number(t, len(sources.goSources), 0)
1919
check.Number(t, len(sources.sqlSources), 0)
2020
})
2121
t.Run("noop_fsys", func(t *testing.T) {
22-
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
22+
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil, false)
2323
check.NoError(t, err)
2424
check.Bool(t, sources != nil, true)
2525
check.Number(t, len(sources.goSources), 0)
2626
check.Number(t, len(sources.sqlSources), 0)
2727
})
2828
t.Run("empty_fsys", func(t *testing.T) {
29-
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
29+
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil, false)
3030
check.NoError(t, err)
3131
check.Number(t, len(sources.goSources), 0)
3232
check.Number(t, len(sources.sqlSources), 0)
@@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) {
3737
"00000_foo.sql": sqlMapFile,
3838
}
3939
// strict disable - should not error
40-
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
40+
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
4141
check.NoError(t, err)
4242
check.Number(t, len(sources.goSources), 0)
4343
check.Number(t, len(sources.sqlSources), 0)
4444
// strict enabled - should error
45-
_, err = collectFilesystemSources(mapFS, true, nil, nil)
45+
_, err = collectFilesystemSources(mapFS, true, nil, nil, false)
4646
check.HasError(t, err)
4747
check.Contains(t, err.Error(), "migration version must be greater than zero")
4848
})
4949
t.Run("collect", func(t *testing.T) {
5050
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
5151
check.NoError(t, err)
52-
sources, err := collectFilesystemSources(fsys, false, nil, nil)
52+
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
5353
check.NoError(t, err)
5454
check.Number(t, len(sources.sqlSources), 4)
5555
check.Number(t, len(sources.goSources), 0)
@@ -77,6 +77,7 @@ func TestCollectFileSources(t *testing.T) {
7777
"00110_qux.sql": true,
7878
},
7979
nil,
80+
false,
8081
)
8182
check.NoError(t, err)
8283
check.Number(t, len(sources.sqlSources), 2)
@@ -97,7 +98,7 @@ func TestCollectFileSources(t *testing.T) {
9798
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
9899
fsys, err := fs.Sub(mapFS, "migrations")
99100
check.NoError(t, err)
100-
_, err = collectFilesystemSources(fsys, true, nil, nil)
101+
_, err = collectFilesystemSources(fsys, true, nil, nil, false)
101102
check.HasError(t, err)
102103
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
103104
})
@@ -109,7 +110,7 @@ func TestCollectFileSources(t *testing.T) {
109110
"4_qux.sql": sqlMapFile,
110111
"5_foo_test.go": {Data: []byte(`package goose_test`)},
111112
}
112-
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
113+
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
113114
check.NoError(t, err)
114115
check.Number(t, len(sources.sqlSources), 4)
115116
check.Number(t, len(sources.goSources), 0)
@@ -124,7 +125,7 @@ func TestCollectFileSources(t *testing.T) {
124125
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
125126
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
126127
}
127-
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
128+
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
128129
check.NoError(t, err)
129130
check.Number(t, len(sources.sqlSources), 2)
130131
check.Number(t, len(sources.goSources), 1)
@@ -143,7 +144,8 @@ func TestCollectFileSources(t *testing.T) {
143144
"001_foo.sql": sqlMapFile,
144145
"01_bar.sql": sqlMapFile,
145146
}
146-
_, err := collectFilesystemSources(mapFS, false, nil, nil)
147+
148+
_, err := collectFilesystemSources(mapFS, false, nil, nil, false)
147149
check.HasError(t, err)
148150
check.Contains(t, err.Error(), "found duplicate migration version 1")
149151
})
@@ -159,7 +161,7 @@ func TestCollectFileSources(t *testing.T) {
159161
t.Helper()
160162
f, err := fs.Sub(mapFS, dirpath)
161163
check.NoError(t, err)
162-
got, err := collectFilesystemSources(f, false, nil, nil)
164+
got, err := collectFilesystemSources(f, false, nil, nil, false)
163165
check.NoError(t, err)
164166
check.Number(t, len(got.sqlSources), len(sqlSources))
165167
check.Number(t, len(got.goSources), 0)
@@ -180,6 +182,25 @@ func TestCollectFileSources(t *testing.T) {
180182
})
181183
assertDirpath("dir3", nil)
182184
})
185+
t.Run("recursive", func(t *testing.T) {
186+
mapFS := fstest.MapFS{
187+
"876_a.sql": sqlMapFile,
188+
"dir1/101_a.sql": sqlMapFile,
189+
"dir1/102_b.sql": sqlMapFile,
190+
"dir1/103_c.sql": sqlMapFile,
191+
"dir2/201_a.sql": sqlMapFile,
192+
"dir2/dir3/301_a.sql": sqlMapFile,
193+
}
194+
sources, err := collectFilesystemSources(mapFS, false, nil, nil, true)
195+
check.NoError(t, err)
196+
check.Equal(t, len(sources.sqlSources), 6)
197+
check.Equal(t, sources.sqlSources[0].Path, "876_a.sql")
198+
check.Equal(t, sources.sqlSources[1].Path, "dir1/101_a.sql")
199+
check.Equal(t, sources.sqlSources[2].Path, "dir1/102_b.sql")
200+
check.Equal(t, sources.sqlSources[3].Path, "dir1/103_c.sql")
201+
check.Equal(t, sources.sqlSources[4].Path, "dir2/201_a.sql")
202+
check.Equal(t, sources.sqlSources[5].Path, "dir2/dir3/301_a.sql")
203+
})
183204
}
184205

185206
func TestMerge(t *testing.T) {
@@ -195,7 +216,7 @@ func TestMerge(t *testing.T) {
195216
}
196217
fsys, err := fs.Sub(mapFS, "migrations")
197218
check.NoError(t, err)
198-
sources, err := collectFilesystemSources(fsys, false, nil, nil)
219+
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
199220
check.NoError(t, err)
200221
check.Equal(t, len(sources.sqlSources), 1)
201222
check.Equal(t, len(sources.goSources), 2)
@@ -243,7 +264,7 @@ func TestMerge(t *testing.T) {
243264
}
244265
fsys, err := fs.Sub(mapFS, "migrations")
245266
check.NoError(t, err)
246-
sources, err := collectFilesystemSources(fsys, false, nil, nil)
267+
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
247268
check.NoError(t, err)
248269
t.Run("unregistered_all", func(t *testing.T) {
249270
migrations, err := merge(sources, map[int64]*Migration{
@@ -267,7 +288,7 @@ func TestMerge(t *testing.T) {
267288
}
268289
fsys, err := fs.Sub(mapFS, "migrations")
269290
check.NoError(t, err)
270-
sources, err := collectFilesystemSources(fsys, false, nil, nil)
291+
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
271292
check.NoError(t, err)
272293
t.Run("unregistered_all", func(t *testing.T) {
273294
migrations, err := merge(sources, map[int64]*Migration{

provider_options.go

+8
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ func WithDisableVersioning(b bool) ProviderOption {
165165
})
166166
}
167167

168+
func WithRecursive(b bool) ProviderOption {
169+
return configFunc(func(c *config) error {
170+
c.recursive = b
171+
return nil
172+
})
173+
}
174+
168175
type config struct {
169176
store database.Store
170177

@@ -184,6 +191,7 @@ type config struct {
184191
disableVersioning bool
185192
allowMissing bool
186193
disableGlobalRegistry bool
194+
recursive bool
187195

188196
// Let's not expose the Logger just yet. Ideally we consolidate on the std lib slog package
189197
// added in go1.21 and then expose that (if that's even necessary). For now, just use the std

0 commit comments

Comments
 (0)