Skip to content
Open
67 changes: 58 additions & 9 deletions internal/db/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strings"

"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
Expand All @@ -25,31 +27,43 @@ const (
DISABLE_PGTAP = "drop extension if exists pgtap"
)

var irPattern = regexp.MustCompile(`(?im)^\s*\\ir\s+['"]?([^'"\s]+)['"]?`)

func Run(ctx context.Context, testFiles []string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
// Build test command
if len(testFiles) == 0 {
absTestsDir, err := filepath.Abs(utils.DbTestsDir)
if err != nil {
return errors.Errorf("failed to resolve tests dir: %w", err)
}
testFiles = append(testFiles, absTestsDir)
testFiles = append(testFiles, utils.DbTestsDir)
}
allFiles, err := traverseImports(testFiles, fsys)
if err != nil {
return err
}
binds := make([]string, len(testFiles))
testFileSet := make(map[string]struct{}, len(testFiles))
for _, tf := range testFiles {
testFileSet[tf] = struct{}{}
}
binds := make([]string, len(allFiles))
cmd := []string{"pg_prove", "--ext", ".pg", "--ext", ".sql", "-r"}
var workingDir string
for i, fp := range testFiles {
for i, fp := range allFiles {
if !filepath.IsAbs(fp) {
fp = filepath.Join(utils.CurrentDirAbs, fp)
}
dockerPath := utils.ToDockerPath(fp)
cmd = append(cmd, dockerPath)
binds[i] = fmt.Sprintf("%s:%s:ro", fp, dockerPath)
if workingDir == "" {
workingDir = dockerPath
if path.Ext(dockerPath) != "" {
workingDir = path.Dir(dockerPath)
}
}
if _, isTestFile := testFileSet[allFiles[i]]; isTestFile {
relPath := dockerPath
if path.Ext(dockerPath) != "" && path.Dir(dockerPath) == workingDir {
relPath = path.Base(dockerPath)
}
cmd = append(cmd, relPath)
}
binds[i] = fmt.Sprintf("%s:%s:ro", fp, dockerPath)
}
if viper.GetBool("DEBUG") {
cmd = append(cmd, "--verbose")
Expand Down Expand Up @@ -107,3 +121,38 @@ func Run(ctx context.Context, testFiles []string, config pgconn.Config, fsys afe
os.Stderr,
)
}

func traverseImports(testFiles []string, fsys afero.Fs) ([]string, error) {
seen := map[string]struct{}{}
q := append([]string{}, testFiles...)
result := []string{}
for len(q) > 0 {
curr := q[len(q)-1]
q = q[:len(q)-1]
if _, ok := seen[curr]; ok {
continue
}
seen[curr] = struct{}{}
result = append(result, curr)
info, err := fsys.Stat(curr)
if err != nil {
return nil, errors.Errorf("failed to stat %s: %w", curr, err)
}
if info.IsDir() {
continue
}
data, err := afero.ReadFile(fsys, curr)
if err != nil {
return nil, errors.Errorf("failed to read %s: %w", curr, err)
}
for _, m := range irPattern.FindAllStringSubmatch(string(data), -1) {
if len(m) < 2 {
continue
}
importPath := strings.TrimSpace(m[1])
resolved := filepath.Join(filepath.Dir(curr), importPath)
q = append(q, resolved)
}
}
return result, nil
}
40 changes: 40 additions & 0 deletions internal/db/test/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, afero.WriteFile(fsys, "nested", []byte("SELECT 1;"), 0644))
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
Expand All @@ -53,6 +54,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755))
// Run test
err := Run(context.Background(), nil, dbConfig, fsys)
// Check error
Expand All @@ -63,6 +65,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755))
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
Expand All @@ -79,6 +82,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755))
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
Expand All @@ -99,3 +103,39 @@ func TestRunCommand(t *testing.T) {
assert.Empty(t, apitest.ListUnmatchedRequests())
})
}

func TestTraverseImports(t *testing.T) {
t.Run("handles file with \\ir import", func(t *testing.T) {
fsys := afero.NewMemMapFs()
require.NoError(t, afero.WriteFile(fsys, "main.sql", []byte("\\ir helper.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "helper.sql", []byte("SELECT 1;"), 0644))

result, err := traverseImports([]string{"main.sql"}, fsys)

assert.NoError(t, err)
assert.Len(t, result, 2)
})

t.Run("handles nested \\ir imports", func(t *testing.T) {
fsys := afero.NewMemMapFs()
require.NoError(t, afero.WriteFile(fsys, "main.sql", []byte("\\ir level1.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "level1.sql", []byte("\\ir level2.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "level2.sql", []byte("SELECT 1;"), 0644))

result, err := traverseImports([]string{"main.sql"}, fsys)

assert.NoError(t, err)
assert.Len(t, result, 3)
})

t.Run("handles circular imports", func(t *testing.T) {
fsys := afero.NewMemMapFs()
require.NoError(t, afero.WriteFile(fsys, "a.sql", []byte("\\ir b.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "b.sql", []byte("\\ir a.sql"), 0644))

result, err := traverseImports([]string{"a.sql"}, fsys)

assert.NoError(t, err)
assert.Len(t, result, 2)
})
}
Loading