|
| 1 | +// Package migrations provides an embedded migration runner that applies |
| 2 | +// SQL migration files from an embed.FS to CockroachDB databases. |
| 3 | +// |
| 4 | +// It handles database and user provisioning, migration tracking via |
| 5 | +// a _meridian_migrations table, and idempotent re-runs. |
| 6 | +package migrations |
| 7 | + |
| 8 | +import ( |
| 9 | + "context" |
| 10 | + "errors" |
| 11 | + "fmt" |
| 12 | + "io/fs" |
| 13 | + "log/slog" |
| 14 | + "net/url" |
| 15 | + "sort" |
| 16 | + "strings" |
| 17 | + "time" |
| 18 | + |
| 19 | + "github.com/jackc/pgx/v5" |
| 20 | +) |
| 21 | + |
| 22 | +// ErrUnknownService is returned when a migration file belongs to a service |
| 23 | +// that has no entry in ServiceDatabases. |
| 24 | +var ErrUnknownService = errors.New("unknown service: no database mapping") |
| 25 | + |
| 26 | +// ServiceDatabase maps a service directory name to its target database, user, and password. |
| 27 | +type ServiceDatabase struct { |
| 28 | + Database string |
| 29 | + User string |
| 30 | + Password string |
| 31 | +} |
| 32 | + |
| 33 | +// ServiceDatabases defines the mapping from service directory names to |
| 34 | +// CockroachDB database names, users, and passwords. |
| 35 | +// |
| 36 | +// Two services (tenant, control-plane) share meridian_platform. |
| 37 | +// Their migrations are applied in service-name order (control-plane before tenant). |
| 38 | +var ServiceDatabases = map[string]ServiceDatabase{ |
| 39 | + "control-plane": {Database: "meridian_platform", User: "meridian_platform_user", Password: ""}, |
| 40 | + "tenant": {Database: "meridian_platform", User: "meridian_platform_user", Password: ""}, |
| 41 | + "current-account": {Database: "meridian_current_account", User: "meridian_current_account_user", Password: ""}, |
| 42 | + "financial-accounting": {Database: "meridian_financial_accounting", User: "meridian_financial_accounting_user", Password: ""}, |
| 43 | + "position-keeping": {Database: "meridian_position_keeping", User: "meridian_position_keeping_user", Password: ""}, |
| 44 | + "payment-order": {Database: "meridian_payment_order", User: "meridian_payment_order_user", Password: ""}, |
| 45 | + "party": {Database: "meridian_party", User: "meridian_party_user", Password: ""}, |
| 46 | + "internal-bank-account": {Database: "meridian_internal_bank_account", User: "meridian_internal_bank_account_user", Password: ""}, |
| 47 | + "market-information": {Database: "meridian_market_information", User: "meridian_market_information_user", Password: ""}, |
| 48 | + "reconciliation": {Database: "meridian_reconciliation", User: "meridian_reconciliation_user", Password: ""}, |
| 49 | + "forecasting": {Database: "meridian_forecasting", User: "meridian_forecasting_user", Password: ""}, |
| 50 | + "reference-data": {Database: "meridian_reference_data", User: "meridian_reference_data_user", Password: ""}, |
| 51 | +} |
| 52 | + |
| 53 | +// serviceMigration holds a single migration file for a service. |
| 54 | +type serviceMigration struct { |
| 55 | + Service string |
| 56 | + Filename string |
| 57 | + SQL string |
| 58 | +} |
| 59 | + |
| 60 | +// RunMigrations discovers migration files from the provided embed.FS, provisions |
| 61 | +// databases and users as superuser, then applies unapplied migrations in order. |
| 62 | +// |
| 63 | +// The superuserDSN should connect to CockroachDB as a privileged user (e.g., root) |
| 64 | +// capable of CREATE DATABASE, CREATE USER, and GRANT operations. |
| 65 | +// |
| 66 | +// Migration state is tracked per-database in a _meridian_migrations table. |
| 67 | +// Running this function multiple times is safe (idempotent). |
| 68 | +func RunMigrations(ctx context.Context, migrationFS fs.FS, superuserDSN string, logger *slog.Logger) error { |
| 69 | + migrations, err := discoverMigrations(migrationFS) |
| 70 | + if err != nil { |
| 71 | + return fmt.Errorf("discover migrations: %w", err) |
| 72 | + } |
| 73 | + |
| 74 | + if len(migrations) == 0 { |
| 75 | + logger.Info("no migrations found") |
| 76 | + return nil |
| 77 | + } |
| 78 | + |
| 79 | + // Collect unique databases to provision. |
| 80 | + dbSet := make(map[string]ServiceDatabase) |
| 81 | + for _, m := range migrations { |
| 82 | + sdb, ok := ServiceDatabases[m.Service] |
| 83 | + if !ok { |
| 84 | + return fmt.Errorf("service %q: %w", m.Service, ErrUnknownService) |
| 85 | + } |
| 86 | + dbSet[sdb.Database] = sdb |
| 87 | + } |
| 88 | + |
| 89 | + // Connect as superuser to provision databases and users, then close. |
| 90 | + if err := provisionAll(ctx, superuserDSN, dbSet, logger); err != nil { |
| 91 | + return err |
| 92 | + } |
| 93 | + |
| 94 | + // Group migrations by target database. |
| 95 | + byDB := groupByDatabase(migrations) |
| 96 | + |
| 97 | + // Apply migrations to each database. |
| 98 | + for dbName, dbMigrations := range byDB { |
| 99 | + sdb := dbMigrations[0].sdb |
| 100 | + dsn := buildServiceDSN(superuserDSN, sdb) |
| 101 | + |
| 102 | + if err := applyDatabaseMigrations(ctx, dsn, dbName, dbMigrations, logger); err != nil { |
| 103 | + return fmt.Errorf("apply migrations to %s: %w", dbName, err) |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + return nil |
| 108 | +} |
| 109 | + |
| 110 | +// discoverMigrations reads migration SQL files from the embedded filesystem. |
| 111 | +// It expects paths of the form: <service>/migrations/<filename>.sql |
| 112 | +func discoverMigrations(migrationFS fs.FS) ([]serviceMigration, error) { |
| 113 | + var migrations []serviceMigration |
| 114 | + |
| 115 | + err := fs.WalkDir(migrationFS, ".", func(path string, d fs.DirEntry, err error) error { |
| 116 | + if err != nil { |
| 117 | + return err |
| 118 | + } |
| 119 | + if d.IsDir() { |
| 120 | + return nil |
| 121 | + } |
| 122 | + if !strings.HasSuffix(path, ".sql") { |
| 123 | + return nil |
| 124 | + } |
| 125 | + |
| 126 | + // Expected path: <service>/migrations/<filename>.sql |
| 127 | + parts := strings.Split(path, "/") |
| 128 | + if len(parts) != 3 || parts[1] != "migrations" { |
| 129 | + return nil |
| 130 | + } |
| 131 | + |
| 132 | + service := parts[0] |
| 133 | + filename := parts[2] |
| 134 | + |
| 135 | + content, err := fs.ReadFile(migrationFS, path) |
| 136 | + if err != nil { |
| 137 | + return fmt.Errorf("read %s: %w", path, err) |
| 138 | + } |
| 139 | + |
| 140 | + migrations = append(migrations, serviceMigration{ |
| 141 | + Service: service, |
| 142 | + Filename: filename, |
| 143 | + SQL: string(content), |
| 144 | + }) |
| 145 | + return nil |
| 146 | + }) |
| 147 | + if err != nil { |
| 148 | + return nil, err |
| 149 | + } |
| 150 | + |
| 151 | + // Sort by service name then filename for deterministic ordering. |
| 152 | + sort.Slice(migrations, func(i, j int) bool { |
| 153 | + if migrations[i].Service != migrations[j].Service { |
| 154 | + return migrations[i].Service < migrations[j].Service |
| 155 | + } |
| 156 | + return migrations[i].Filename < migrations[j].Filename |
| 157 | + }) |
| 158 | + |
| 159 | + return migrations, nil |
| 160 | +} |
| 161 | + |
| 162 | +// provisionAll connects as superuser, provisions databases, and closes the connection. |
| 163 | +func provisionAll(ctx context.Context, superuserDSN string, databases map[string]ServiceDatabase, logger *slog.Logger) error { |
| 164 | + superConn, err := pgx.Connect(ctx, superuserDSN) |
| 165 | + if err != nil { |
| 166 | + return fmt.Errorf("connect as superuser: %w", err) |
| 167 | + } |
| 168 | + defer func() { _ = superConn.Close(ctx) }() |
| 169 | + |
| 170 | + return provisionDatabases(ctx, superConn, databases, logger) |
| 171 | +} |
| 172 | + |
| 173 | +// provisionDatabases creates databases and users as needed. |
| 174 | +// Each DDL statement is executed individually because pgx v5's extended |
| 175 | +// protocol does not support multi-statement query strings. |
| 176 | +func provisionDatabases(ctx context.Context, conn *pgx.Conn, databases map[string]ServiceDatabase, logger *slog.Logger) error { |
| 177 | + for dbName, sdb := range databases { |
| 178 | + logger.Info("provisioning database", "database", dbName, "user", sdb.User) |
| 179 | + |
| 180 | + stmts := []string{ |
| 181 | + fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoteIdent(dbName)), |
| 182 | + fmt.Sprintf("CREATE USER IF NOT EXISTS %s", quoteIdent(sdb.User)), |
| 183 | + fmt.Sprintf("GRANT ALL ON DATABASE %s TO %s", quoteIdent(dbName), quoteIdent(sdb.User)), |
| 184 | + } |
| 185 | + for _, stmt := range stmts { |
| 186 | + if _, err := conn.Exec(ctx, stmt); err != nil { |
| 187 | + return fmt.Errorf("provision %s: %w", dbName, err) |
| 188 | + } |
| 189 | + } |
| 190 | + } |
| 191 | + return nil |
| 192 | +} |
| 193 | + |
| 194 | +type dbMigration struct { |
| 195 | + sdb ServiceDatabase |
| 196 | + service string |
| 197 | + filename string |
| 198 | + sql string |
| 199 | +} |
| 200 | + |
| 201 | +// groupByDatabase groups migrations by their target database name. |
| 202 | +// Within each database, migrations are sorted by filename (lexicographic). |
| 203 | +func groupByDatabase(migrations []serviceMigration) map[string][]dbMigration { |
| 204 | + result := make(map[string][]dbMigration) |
| 205 | + |
| 206 | + for _, m := range migrations { |
| 207 | + sdb := ServiceDatabases[m.Service] |
| 208 | + result[sdb.Database] = append(result[sdb.Database], dbMigration{ |
| 209 | + sdb: sdb, |
| 210 | + service: m.Service, |
| 211 | + filename: m.Filename, |
| 212 | + sql: m.SQL, |
| 213 | + }) |
| 214 | + } |
| 215 | + |
| 216 | + // Sort each database's migrations by service then filename. |
| 217 | + for _, dbMigs := range result { |
| 218 | + sort.Slice(dbMigs, func(i, j int) bool { |
| 219 | + if dbMigs[i].service != dbMigs[j].service { |
| 220 | + return dbMigs[i].service < dbMigs[j].service |
| 221 | + } |
| 222 | + return dbMigs[i].filename < dbMigs[j].filename |
| 223 | + }) |
| 224 | + } |
| 225 | + |
| 226 | + return result |
| 227 | +} |
| 228 | + |
| 229 | +// applyDatabaseMigrations connects to a specific database and applies unapplied migrations. |
| 230 | +func applyDatabaseMigrations(ctx context.Context, dsn, dbName string, migrations []dbMigration, logger *slog.Logger) error { |
| 231 | + conn, err := pgx.Connect(ctx, dsn) |
| 232 | + if err != nil { |
| 233 | + return fmt.Errorf("connect to %s: %w", dbName, err) |
| 234 | + } |
| 235 | + defer func() { _ = conn.Close(ctx) }() |
| 236 | + |
| 237 | + if err := ensureMigrationsTable(ctx, conn); err != nil { |
| 238 | + return fmt.Errorf("create tracking table: %w", err) |
| 239 | + } |
| 240 | + |
| 241 | + applied, err := getAppliedMigrations(ctx, conn) |
| 242 | + if err != nil { |
| 243 | + return fmt.Errorf("read applied migrations: %w", err) |
| 244 | + } |
| 245 | + |
| 246 | + for _, m := range migrations { |
| 247 | + key := m.service + "/" + m.filename |
| 248 | + if applied[key] { |
| 249 | + logger.Debug("skipping already applied migration", "database", dbName, "service", m.service, "file", m.filename) |
| 250 | + continue |
| 251 | + } |
| 252 | + |
| 253 | + logger.Info("applying migration", "database", dbName, "service", m.service, "file", m.filename) |
| 254 | + |
| 255 | + if _, err := conn.Exec(ctx, m.sql); err != nil { |
| 256 | + return fmt.Errorf("execute %s/%s: %w", m.service, m.filename, err) |
| 257 | + } |
| 258 | + |
| 259 | + if err := recordMigration(ctx, conn, m.service, m.filename); err != nil { |
| 260 | + return fmt.Errorf("record %s/%s: %w", m.service, m.filename, err) |
| 261 | + } |
| 262 | + } |
| 263 | + |
| 264 | + return nil |
| 265 | +} |
| 266 | + |
| 267 | +// ensureMigrationsTable creates the _meridian_migrations tracking table if it does not exist. |
| 268 | +func ensureMigrationsTable(ctx context.Context, conn *pgx.Conn) error { |
| 269 | + _, err := conn.Exec(ctx, ` |
| 270 | + CREATE TABLE IF NOT EXISTS _meridian_migrations ( |
| 271 | + id INT8 NOT NULL DEFAULT unique_rowid(), |
| 272 | + service VARCHAR(255) NOT NULL, |
| 273 | + filename VARCHAR(255) NOT NULL, |
| 274 | + applied_at TIMESTAMPTZ NOT NULL DEFAULT now(), |
| 275 | + PRIMARY KEY (id), |
| 276 | + UNIQUE (service, filename) |
| 277 | + ) |
| 278 | + `) |
| 279 | + return err |
| 280 | +} |
| 281 | + |
| 282 | +// getAppliedMigrations returns a set of "service/filename" keys for already-applied migrations. |
| 283 | +func getAppliedMigrations(ctx context.Context, conn *pgx.Conn) (map[string]bool, error) { |
| 284 | + rows, err := conn.Query(ctx, `SELECT service, filename FROM _meridian_migrations`) |
| 285 | + if err != nil { |
| 286 | + return nil, err |
| 287 | + } |
| 288 | + defer rows.Close() |
| 289 | + |
| 290 | + applied := make(map[string]bool) |
| 291 | + for rows.Next() { |
| 292 | + var service, filename string |
| 293 | + if err := rows.Scan(&service, &filename); err != nil { |
| 294 | + return nil, err |
| 295 | + } |
| 296 | + applied[service+"/"+filename] = true |
| 297 | + } |
| 298 | + return applied, rows.Err() |
| 299 | +} |
| 300 | + |
| 301 | +// recordMigration inserts a record into _meridian_migrations for a successfully applied migration. |
| 302 | +func recordMigration(ctx context.Context, conn *pgx.Conn, service, filename string) error { |
| 303 | + _, err := conn.Exec(ctx, |
| 304 | + `INSERT INTO _meridian_migrations (service, filename, applied_at) VALUES ($1, $2, $3)`, |
| 305 | + service, filename, time.Now(), |
| 306 | + ) |
| 307 | + return err |
| 308 | +} |
| 309 | + |
| 310 | +// buildServiceDSN modifies a superuser DSN to target a specific database and user. |
| 311 | +// It parses the URL, replaces user/database, and preserves all query parameters |
| 312 | +// (TLS settings, timeouts, etc.). It also sets simple_protocol exec mode so that |
| 313 | +// multi-statement migration files can be executed in a single Exec() call. |
| 314 | +func buildServiceDSN(superuserDSN string, sdb ServiceDatabase) string { |
| 315 | + parsed, err := url.Parse(superuserDSN) |
| 316 | + if err != nil { |
| 317 | + return superuserDSN |
| 318 | + } |
| 319 | + |
| 320 | + // Replace user credentials. |
| 321 | + if sdb.Password != "" { |
| 322 | + parsed.User = url.UserPassword(sdb.User, sdb.Password) |
| 323 | + } else { |
| 324 | + parsed.User = url.User(sdb.User) |
| 325 | + } |
| 326 | + |
| 327 | + // Replace database in path (postgres://user@host:port/database). |
| 328 | + parsed.Path = "/" + sdb.Database |
| 329 | + |
| 330 | + // Ensure default port for CockroachDB if not specified. |
| 331 | + if parsed.Port() == "" { |
| 332 | + parsed.Host = parsed.Hostname() + ":26257" |
| 333 | + } |
| 334 | + |
| 335 | + // Enable simple protocol so multi-statement migration SQL files work with pgx v5. |
| 336 | + q := parsed.Query() |
| 337 | + if q.Get("default_query_exec_mode") == "" { |
| 338 | + q.Set("default_query_exec_mode", "simple_protocol") |
| 339 | + } |
| 340 | + parsed.RawQuery = q.Encode() |
| 341 | + |
| 342 | + return parsed.String() |
| 343 | +} |
| 344 | + |
| 345 | +// quoteIdent wraps a SQL identifier in double quotes, escaping any embedded double quotes. |
| 346 | +func quoteIdent(s string) string { |
| 347 | + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` |
| 348 | +} |
0 commit comments