Skip to content

Commit 0f032ae

Browse files
committed
store db migrates in the database if it supports it
1 parent 67ef559 commit 0f032ae

File tree

4 files changed

+513
-0
lines changed

4 files changed

+513
-0
lines changed

database/postgres/storage.go

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
package postgres
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"io"
8+
9+
"github.com/golang-migrate/migrate/v4/database"
10+
"github.com/golang-migrate/migrate/v4/source"
11+
"github.com/lib/pq"
12+
)
13+
14+
// Ensure Postgres implements MigrationStorageDriver
15+
var _ database.MigrationStorageDriver = &Postgres{}
16+
17+
// ensureEnhancedVersionTable checks if the enhanced versions table exists and creates/updates it.
18+
// This version includes columns for storing migration scripts.
19+
func (p *Postgres) ensureEnhancedVersionTable() (err error) {
20+
if err = p.Lock(); err != nil {
21+
return err
22+
}
23+
24+
defer func() {
25+
if e := p.Unlock(); e != nil {
26+
if err == nil {
27+
err = e
28+
} else {
29+
err = fmt.Errorf("unlock error: %v, original error: %v", e, err)
30+
}
31+
}
32+
}()
33+
34+
exists, err := p.tableExists()
35+
if err != nil {
36+
return err
37+
}
38+
39+
if !exists {
40+
return p.createEnhancedTable()
41+
}
42+
43+
return p.addMissingColumns()
44+
}
45+
46+
// tableExists checks if the migrations table exists
47+
func (p *Postgres) tableExists() (bool, error) {
48+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
49+
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
50+
51+
var count int
52+
err := row.Scan(&count)
53+
if err != nil {
54+
return false, &database.Error{OrigErr: err, Query: []byte(query)}
55+
}
56+
57+
return count > 0, nil
58+
}
59+
60+
// createEnhancedTable creates the migrations table with all required columns
61+
func (p *Postgres) createEnhancedTable() error {
62+
query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (
63+
version bigint not null primary key,
64+
dirty boolean not null,
65+
up_script text,
66+
down_script text,
67+
created_at timestamp with time zone default now()
68+
)`
69+
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
70+
return &database.Error{OrigErr: err, Query: []byte(query)}
71+
}
72+
return nil
73+
}
74+
75+
// addMissingColumns adds any missing columns to existing table
76+
func (p *Postgres) addMissingColumns() error {
77+
columns := []string{"up_script", "down_script", "created_at"}
78+
79+
for _, column := range columns {
80+
exists, err := p.columnExists(column)
81+
if err != nil {
82+
return err
83+
}
84+
85+
if !exists {
86+
if err := p.addColumn(column); err != nil {
87+
return err
88+
}
89+
}
90+
}
91+
92+
return nil
93+
}
94+
95+
// columnExists checks if a specific column exists in the migrations table
96+
func (p *Postgres) columnExists(columnName string) (bool, error) {
97+
query := `SELECT COUNT(1) FROM information_schema.columns
98+
WHERE table_schema = $1 AND table_name = $2 AND column_name = $3`
99+
100+
var count int
101+
err := p.conn.QueryRowContext(context.Background(), query,
102+
p.config.migrationsSchemaName, p.config.migrationsTableName, columnName).Scan(&count)
103+
if err != nil {
104+
return false, &database.Error{OrigErr: err, Query: []byte(query)}
105+
}
106+
107+
return count > 0, nil
108+
}
109+
110+
// addColumn adds a specific column to the migrations table
111+
func (p *Postgres) addColumn(columnName string) error {
112+
var columnDef string
113+
switch columnName {
114+
case "up_script", "down_script":
115+
columnDef = "text"
116+
case "created_at":
117+
columnDef = "timestamp with time zone default now()"
118+
default:
119+
return fmt.Errorf("unknown column: %s", columnName)
120+
}
121+
122+
alterQuery := `ALTER TABLE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` +
123+
pq.QuoteIdentifier(p.config.migrationsTableName) + ` ADD COLUMN ` + columnName + ` ` + columnDef
124+
if _, err := p.conn.ExecContext(context.Background(), alterQuery); err != nil {
125+
return &database.Error{OrigErr: err, Query: []byte(alterQuery)}
126+
}
127+
return nil
128+
}
129+
130+
// StoreMigration stores the up and down migration scripts for a given version
131+
func (p *Postgres) StoreMigration(version uint, upScript, downScript []byte) error {
132+
// Ensure the enhanced table exists
133+
if err := p.ensureEnhancedVersionTable(); err != nil {
134+
return err
135+
}
136+
137+
query := `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` +
138+
pq.QuoteIdentifier(p.config.migrationsTableName) +
139+
` (version, dirty, up_script, down_script) VALUES ($1, false, $2, $3)
140+
ON CONFLICT (version) DO UPDATE SET
141+
up_script = EXCLUDED.up_script,
142+
down_script = EXCLUDED.down_script,
143+
created_at = now()`
144+
145+
_, err := p.conn.ExecContext(context.Background(), query, int64(version), string(upScript), string(downScript))
146+
if err != nil {
147+
return &database.Error{OrigErr: err, Query: []byte(query)}
148+
}
149+
150+
return nil
151+
}
152+
153+
// GetMigration retrieves the stored migration scripts for a given version
154+
func (p *Postgres) GetMigration(version uint) (upScript, downScript []byte, err error) {
155+
query := `SELECT up_script, down_script FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` +
156+
pq.QuoteIdentifier(p.config.migrationsTableName) + ` WHERE version = $1`
157+
158+
var upScriptStr, downScriptStr sql.NullString
159+
err = p.conn.QueryRowContext(context.Background(), query, int64(version)).Scan(&upScriptStr, &downScriptStr)
160+
if err != nil {
161+
if err == sql.ErrNoRows {
162+
return nil, nil, fmt.Errorf("migration version %d not found", version)
163+
}
164+
return nil, nil, &database.Error{OrigErr: err, Query: []byte(query)}
165+
}
166+
167+
if upScriptStr.Valid {
168+
upScript = []byte(upScriptStr.String)
169+
}
170+
if downScriptStr.Valid {
171+
downScript = []byte(downScriptStr.String)
172+
}
173+
174+
return upScript, downScript, nil
175+
}
176+
177+
// GetStoredMigrations returns all migration versions that have scripts stored
178+
func (p *Postgres) GetStoredMigrations() ([]uint, error) {
179+
query := `SELECT version FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` +
180+
pq.QuoteIdentifier(p.config.migrationsTableName) +
181+
` WHERE up_script IS NOT NULL OR down_script IS NOT NULL ORDER BY version ASC`
182+
183+
rows, err := p.conn.QueryContext(context.Background(), query)
184+
if err != nil {
185+
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
186+
}
187+
defer rows.Close()
188+
189+
var versions []uint
190+
for rows.Next() {
191+
var version int64
192+
if err := rows.Scan(&version); err != nil {
193+
return nil, err
194+
}
195+
versions = append(versions, uint(version))
196+
}
197+
198+
return versions, rows.Err()
199+
}
200+
201+
// SyncMigrations ensures all available migrations up to maxVersion are stored in the database
202+
func (p *Postgres) SyncMigrations(sourceDriver interface{}, maxVersion uint) error {
203+
srcDriver, ok := sourceDriver.(source.Driver)
204+
if !ok {
205+
return fmt.Errorf("source driver must implement source.Driver interface")
206+
}
207+
208+
versions, err := p.collectVersions(srcDriver, maxVersion)
209+
if err != nil {
210+
return err
211+
}
212+
213+
return p.storeMigrations(srcDriver, versions)
214+
}
215+
216+
// collectVersions gets all migration versions up to maxVersion
217+
func (p *Postgres) collectVersions(srcDriver source.Driver, maxVersion uint) ([]uint, error) {
218+
first, err := srcDriver.First()
219+
if err != nil {
220+
return nil, fmt.Errorf("failed to get first migration: %w", err)
221+
}
222+
223+
var versions []uint
224+
currentVersion := first
225+
226+
for currentVersion <= maxVersion {
227+
versions = append(versions, currentVersion)
228+
229+
next, err := srcDriver.Next(currentVersion)
230+
if err != nil {
231+
if err.Error() == "file does not exist" { // Handle os.ErrNotExist
232+
break
233+
}
234+
return nil, fmt.Errorf("failed to get next migration after %d: %w", currentVersion, err)
235+
}
236+
currentVersion = next
237+
}
238+
239+
return versions, nil
240+
}
241+
242+
// storeMigrations reads and stores migration scripts for the given versions
243+
func (p *Postgres) storeMigrations(srcDriver source.Driver, versions []uint) error {
244+
for _, version := range versions {
245+
upScript, err := p.readMigrationScript(srcDriver, version, true)
246+
if err != nil {
247+
return err
248+
}
249+
250+
downScript, err := p.readMigrationScript(srcDriver, version, false)
251+
if err != nil {
252+
return err
253+
}
254+
255+
// Store the migration if we have at least one script
256+
if len(upScript) > 0 || len(downScript) > 0 {
257+
if err := p.StoreMigration(version, upScript, downScript); err != nil {
258+
return fmt.Errorf("failed to store migration %d: %w", version, err)
259+
}
260+
}
261+
}
262+
263+
return nil
264+
}
265+
266+
// readMigrationScript reads a migration script (up or down) for a given version
267+
func (p *Postgres) readMigrationScript(srcDriver source.Driver, version uint, isUp bool) ([]byte, error) {
268+
var reader io.ReadCloser
269+
var err error
270+
271+
if isUp {
272+
reader, _, err = srcDriver.ReadUp(version)
273+
} else {
274+
reader, _, err = srcDriver.ReadDown(version)
275+
}
276+
277+
if err != nil {
278+
// It's OK if migration doesn't exist
279+
return nil, nil
280+
}
281+
282+
defer reader.Close()
283+
script, err := io.ReadAll(reader)
284+
if err != nil {
285+
direction := "up"
286+
if !isUp {
287+
direction = "down"
288+
}
289+
return nil, fmt.Errorf("failed to read %s migration %d: %w", direction, version, err)
290+
}
291+
292+
return script, nil
293+
}

database/storage.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package database
2+
3+
// MigrationStorageDriver extends the basic Driver interface to support
4+
// storing and retrieving migration scripts in the database itself.
5+
// This is useful for dirty state handling when shared storage isn't available.
6+
type MigrationStorageDriver interface {
7+
Driver
8+
9+
// StoreMigration stores the up and down migration scripts for a given version
10+
// in the database. This allows for dirty state recovery without external files.
11+
StoreMigration(version uint, upScript, downScript []byte) error
12+
13+
// GetMigration retrieves the stored migration scripts for a given version.
14+
// Returns the up and down scripts, or an error if the version doesn't exist.
15+
GetMigration(version uint) (upScript, downScript []byte, err error)
16+
17+
// GetStoredMigrations returns all migration versions that have scripts stored
18+
// in the database, sorted in ascending order.
19+
GetStoredMigrations() ([]uint, error)
20+
21+
// SyncMigrations ensures all available migrations up to maxVersion are stored
22+
// in the database. This should be called during migration runs to keep
23+
// the database in sync with available migration files.
24+
SyncMigrations(sourceDriver interface{}, maxVersion uint) error
25+
}
26+
27+
// SupportsStorage checks if a driver supports migration script storage
28+
func SupportsStorage(driver Driver) bool {
29+
_, ok := driver.(MigrationStorageDriver)
30+
return ok
31+
}

0 commit comments

Comments
 (0)