diff --git a/README.md b/README.md index 7af27c7..202b950 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Prometheus instrumentation, HTTP client/server utilities, and more. - `pkg/log`: Structured logging, audit logs, HTTP request/response logging - `pkg/logging`: Global logger facade for convenience - `pkg/middlewares`: Auth, tenant, tracing, URL filter middlewares -- `pkg/migrate`: Migration runner for PostgreSQL (SQL files) +- `pkg/migrate`: Migration runner for PostgreSQL (SQL files). See [Documentation](docs/migration/README.md) - `pkg/prom`: Prometheus metrics utilities for HTTP client/server - `pkg/standard`: Opinionated server/gateway wiring - `pkg/ticket`: Lightweight JWT ticket verification/claims diff --git a/docs/migration/README.md b/docs/migration/README.md new file mode 100644 index 0000000..bff0e60 --- /dev/null +++ b/docs/migration/README.md @@ -0,0 +1,90 @@ +# Migration Approaches + +This repository supports **two migration flows** for services using `pkg/db`. +Choose the versioned flow if you need per‑version interleaving, or the legacy +flow for a single AutoMigrate pass. + +## 1) Versioned Migration Flow + +Use when you need **interleaving** per migration version. + +**API** +- `WithVersionedMigrationFunc(func(*gorm.DB, uint) error)` +- Optional: `WithMigrationVersion(version)` + +**Behavior** +- Runs `setup.sql` once (idempotent). +- Runs `fdw.up.sql` / `fdw.down.sql` once (optional). +- For each version from current+1 to target: + 1. Versioned before script (optional) + 2. `AutoMigrate(version)` (service implementation) + 3. Versioned after script (optional) + 4. Record version **after** the full sequence completes + +**Notes** +- Missing before/after scripts are skipped. +- The version bump represents completion of before + AutoMigrate + after. +- **Recording a version** means updating the `migrations` table used by + `golang-migrate` with the new version number. +- **Supported naming (versioned flow only):** + - Before: `{version}_{name}.before.sql` **or** `{version}_{name}.before.up.sql` + - After: `{version}_{name}.after.sql` **or** `{version}_{name}.after.up.sql` + +**Diagram** +```mermaid +flowchart TB + %% Versioned Migration Logic + + B0["Start migrations"] --> B1["setup.sql (optional, idempotent)"] + B1 --> B2["fdw.up.sql (optional)"] + B2 --> V1["For each version v = current+1 .. target"] + V1 --> V2["{version}_{name}.before.sql or .before.up.sql (optional)"] + V2 --> V3["AutoMigrate(version)"] + V3 --> V4["{version}_{name}.after.sql or .after.up.sql (optional)"] + V4 --> V5["Record version v (migrations table)"] + V5 -.-> V1 + V5 --> B3["fdw.down.sql (optional)"] +``` + +## 2) Legacy Migration Flow + +Use when you want a single AutoMigrate call and minimal changes to existing +services. + +**API** +- `WithMigrationFunc(func(*gorm.DB) error)` +- Optional: `WithMigrationVersion(version)` + +**Behavior** +- Runs **one** AutoMigrate (latest models). +- Runs `setup.sql` (idempotent), then `fdw.up.sql` (optional). +- Runs numbered SQL migrations via `golang-migrate`: + - `{version}_{name}.up.sql` / `{version}_{name}.down.sql` +- Runs `fdw.down.sql` (optional). + +**Diagram** +```mermaid +flowchart TB + %% Legacy Migration Logic + + A0["Start migrations"] --> A1["AutoMigrate (once, latest models)"] + A1 --> A2["setup.sql (optional, idempotent)"] + A2 --> A3["fdw.up.sql (optional)"] + A3 --> L1["For each version v = current+1 .. target"] + L1 --> L2["{version}_{name}.up.sql"] + L2 -.-> L1 + L2 --> A4["fdw.down.sql (optional)"] +``` + +## File Naming Summary + +Versioned flow: +- Before: `{version}_{name}.before.sql` or `{version}_{name}.before.up.sql` +- After: `{version}_{name}.after.sql` or `{version}_{name}.after.up.sql` + +Legacy flow: +- SQL migrations: `{version}_{name}.up.sql` + +Shared: +- Setup: `setup.sql` +- FDW: `fdw.up.sql`, `fdw.down.sql` diff --git a/docs/migration/migration-legacy-example-v5-to-v7.mmd b/docs/migration/migration-legacy-example-v5-to-v7.mmd new file mode 100644 index 0000000..2b3e589 --- /dev/null +++ b/docs/migration/migration-legacy-example-v5-to-v7.mmd @@ -0,0 +1,10 @@ +flowchart TB + %% Legacy Example: Current v5 -> Target v7 + + A0["Current DB version: 5"] --> A1["Target version: 7"] + A1 --> A2["AutoMigrate (once, latest models)"] + A2 --> A3["setup.sql (optional, idempotent)"] + A3 --> A4["fdw.up.sql (optional)"] + A4 --> A5["006_fix_constraint.up.sql"] + A5 --> A6["007_add_feature.up.sql"] + A6 --> A7["fdw.down.sql (optional)"] diff --git a/docs/migration/migration-legacy.mmd b/docs/migration/migration-legacy.mmd new file mode 100644 index 0000000..98e0709 --- /dev/null +++ b/docs/migration/migration-legacy.mmd @@ -0,0 +1,10 @@ +flowchart TB + %% Legacy Migration Logic (as on main) + + A0["Start migrations"] --> A1["AutoMigrate (once, latest models)"] + A1 --> A2["setup.sql (optional, idempotent)"] + A2 --> A3["fdw.up.sql (optional)"] + A3 --> L1["For each version v = current+1 .. target"] + L1 --> L2["{version}_{name}.up.sql"] + L2 -.-> L1 + L2 --> A4["fdw.down.sql (optional)"] diff --git a/docs/migration/migration-versioned-example-v5-to-v7.mmd b/docs/migration/migration-versioned-example-v5-to-v7.mmd new file mode 100644 index 0000000..51cbc07 --- /dev/null +++ b/docs/migration/migration-versioned-example-v5-to-v7.mmd @@ -0,0 +1,18 @@ +flowchart TB + %% Versioned Example: Current v5 -> Target v7 + + B0["Current DB version: 5"] --> B1["Target version: 7"] + B1 --> B2["setup.sql (optional, idempotent)"] + B2 --> B3["fdw.up.sql (optional)"] + + B3 --> V1["006_fix_constraint.before.sql (optional)"] + V1 --> V2["AutoMigrate(6)"] + V2 --> V3["006_fix_constraint.after.sql (optional)"] + V3 --> V3b["Record version 6"] + + V3b --> V4["007_add_feature.before.sql (optional)"] + V4 --> V5["AutoMigrate(7)"] + V5 --> V6["007_add_feature.after.sql (optional)"] + V6 --> V6b["Record version 7"] + + V6b --> B4["fdw.down.sql (optional)"] diff --git a/docs/migration/migration-versioned.mmd b/docs/migration/migration-versioned.mmd new file mode 100644 index 0000000..4d66d1f --- /dev/null +++ b/docs/migration/migration-versioned.mmd @@ -0,0 +1,12 @@ +flowchart TB + %% Versioned Migration Logic (WithVersionedMigrationFunc) + + B0["Start migrations"] --> B1["setup.sql (optional, idempotent)"] + B1 --> B2["fdw.up.sql (optional)"] + B2 --> V1["For each version v = current+1 .. target"] + V1 --> V2["{version}_{name}.before.sql or .before.up.sql (optional)"] + V2 --> V3["AutoMigrate(version)"] + V3 --> V4["{version}_{name}.after.sql or .after.up.sql (optional)"] + V4 --> V5["Record version v"] + V5 -.-> V1 + V5 --> B3["fdw.down.sql (optional)"] diff --git a/pkg/db/gorm.go b/pkg/db/gorm.go index 4d4b6ab..9135ef5 100644 --- a/pkg/db/gorm.go +++ b/pkg/db/gorm.go @@ -2,6 +2,8 @@ package db import ( "context" + "database/sql" + stderrors "errors" "time" "github.com/pkg/errors" @@ -59,7 +61,7 @@ func Initialize(runCtx context.Context, opts *ConnectionOptions) <-chan struct{} defer logging.LogInfof("database connection closed") }() - err = runMigration(conn, opts.MigrationFunc, opts.MigrationVersion, opts.MigrationStartFromZero) + err = runMigration(conn, opts.MigrationFunc, opts.VersionedMigrationFunc, opts.MigrationVersion, opts.MigrationStartFromZero) if err != nil { if opts.MigrationHaltOnError { logging.LogErrorf(err, "database migration failed - aborting") @@ -149,28 +151,180 @@ func retryExponential(runCtx context.Context, attempts uint, waitPeriod time.Dur } // runMigration Executes Migrations on the database -func runMigration(conn *gorm.DB, migFn MigrationFunc, migrationVersion uint, startFromZero bool) error { +func runMigration( + conn *gorm.DB, + legacyFn MigrationFunc, + versionedFn VersionedMigrationFunc, + migrationVersion uint, + startFromZero bool, +) error { if conn == nil { logging.LogErrorf(ErrDBConnection, "MigrateDB() - db handle is nil") return ErrDBConnection } - // Run GORM automigrations as supplied by service - err := migFn(conn) + if legacyFn != nil && versionedFn != nil { + return errors.New("both MigrationFunc (legacy) and VersionedMigrationFunc are set; please configure only one migration flow") + } + if migrationVersion == 0 { + // No SQL migrations; run whichever automigration function is provided. + if versionedFn != nil { + return versionedFn(conn, 0) + } + if legacyFn != nil { + return legacyFn(conn) + } + return nil + } + + // Prefer explicit versioned flow when configured. + if versionedFn != nil { + return runMigrationVersioned(conn, versionedFn, migrationVersion, startFromZero) + } + + sqlDB, err := conn.DB() if err != nil { + logging.LogErrorf(err, "error getting sql DB") return err } + return runMigrationLegacy(sqlDB, conn, legacyFn, migrationVersion, startFromZero) +} + +func runMigrationLegacy(sqlDB *sql.DB, conn *gorm.DB, legacyFn MigrationFunc, migrationVersion uint, startFromZero bool) error { + // Preserve legacy behavior: AutoMigrate once, then SQL migrations via golang-migrate. + if legacyFn != nil { + if err := legacyFn(conn); err != nil { + return err + } + } + + if migrationVersion == 0 { + return nil + } + migration := migrate.NewMigration(sqlDB, migrationsSource, migrationsTable, logging.Logger()) + return migration.MigrateDB(context.Background(), migrationVersion, startFromZero) +} + +func runMigrationVersioned(conn *gorm.DB, migFn VersionedMigrationFunc, migrationVersion uint, startFromZero bool) error { sqlDB, err := conn.DB() if err != nil { logging.LogErrorf(err, "error getting sql DB") return err } - // Run manual migrations defined in sql scripts if needed - if migrationVersion > 0 { - migration := migrate.NewMigration(sqlDB, migrationsSource, migrationsTable, logging.Logger()) - err = migration.MigrateDB(context.Background(), migrationVersion, startFromZero) + ctx := context.Background() + migration := migrate.NewMigration(sqlDB, migrationsSource, migrationsTable, logging.Logger()) + + if err := migration.ExecuteSetup(ctx); err != nil { + return err + } + if err := migration.ExecuteFdwUp(ctx); err != nil { + return err + } + defer func() { + if err := migration.ExecuteFdwDown(ctx); err != nil { + logging.LogErrorf(err, "error executing fdw down script") + } + }() + + return runMigrationVersions(ctx, conn, migration, migFn, migrationVersion, startFromZero) +} + +func runMigrationVersions( + ctx context.Context, + conn *gorm.DB, + migration *migrate.Migration, + migFn VersionedMigrationFunc, + migrationVersion uint, + startFromZero bool, +) error { + mpg, cleanup, err := migration.MigrateInstanceForVersionTracking() + if err != nil { + return err + } + if cleanup != nil { + defer cleanup() + } + + currentVersion, dirty, needsRecordTarget, err := currentMigrationVersion(mpg, migrationVersion, startFromZero) + if err != nil { + return err + } + if dirty { + return errors.Errorf("database migration is dirty at version %d", currentVersion) + } + + // Legacy behavior: when the database has no version info and startFromZero is false, + // run AutoMigrate once, then record the target version without running per-version hooks. + if needsRecordTarget { + if migFn != nil { + if err := migFn(conn, migrationVersion); err != nil { + logging.LogErrorf(err, "error running auto migration for version %d", migrationVersion) + return err + } + } + if err := migrate.SetVersion(mpg, migrationVersion); err != nil { + logging.LogErrorf(err, "error setting migration version to %d", migrationVersion) + return err + } + return nil + } + + for version := currentVersion + 1; version <= migrationVersion; version++ { + if err := applyMigrationVersion(ctx, conn, migration, migFn, mpg, version); err != nil { + return err + } + logging.LogInfof("migration for version %d executed successfully", version) + } + + return nil +} + +func currentMigrationVersion(mpg migrate.VersionSetter, migrationVersion uint, startFromZero bool) (uint, bool, bool, error) { + currentVersion, dirty, err := mpg.Version() + if err == nil { + return currentVersion, dirty, false, nil + } + if !stderrors.Is(err, migrate.ErrNilVersion) { + return 0, false, false, err + } + if startFromZero { + return 0, false, false, nil + } + // Caller should run a single AutoMigrate and then record the target version. + return migrationVersion, false, true, nil +} + +func applyMigrationVersion( + ctx context.Context, + conn *gorm.DB, + migration *migrate.Migration, + migFn VersionedMigrationFunc, + mpg migrate.VersionSetter, + version uint, +) error { + if _, err := migration.ExecuteBeforeUp(ctx, version); err != nil { + logging.LogErrorf(err, "error running before migration for version %d", version) + return err + } + + if migFn != nil { + if err := migFn(conn, version); err != nil { + logging.LogErrorf(err, "error running auto migration for version %d", version) + return err + } + } + + if _, err := migration.ExecuteAfterUp(ctx, version); err != nil { + logging.LogErrorf(err, "error running after migration for version %d", version) + return err + } + + // Record version after full before/auto/after sequence. + if err := migrate.SetVersion(mpg, version); err != nil { + logging.LogErrorf(err, "error setting migration version to %d", version) + return err } - return err + return nil } diff --git a/pkg/db/gorm_migration_guard_test.go b/pkg/db/gorm_migration_guard_test.go new file mode 100644 index 0000000..c7e8b52 --- /dev/null +++ b/pkg/db/gorm_migration_guard_test.go @@ -0,0 +1,17 @@ +package db + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestRunMigrationRejectsBothLegacyAndVersionedFuncs(t *testing.T) { + legacyFn := func(_ *gorm.DB) error { return nil } + versionedFn := func(_ *gorm.DB, _ uint) error { return nil } + + err := runMigration(&gorm.DB{}, legacyFn, versionedFn, 1, true) + if err == nil { + t.Fatalf("expected error") + } +} diff --git a/pkg/db/migration_version_test.go b/pkg/db/migration_version_test.go new file mode 100644 index 0000000..ba85baf --- /dev/null +++ b/pkg/db/migration_version_test.go @@ -0,0 +1,94 @@ +package db + +import ( + "errors" + "testing" + + "github.com/d4l-data4life/go-svc/pkg/migrate" +) + +type fakeVersionSetter struct { + version uint + dirty bool + err error + forced []int + forceErr error +} + +func (f *fakeVersionSetter) Version() (uint, bool, error) { + return f.version, f.dirty, f.err +} + +func (f *fakeVersionSetter) Force(v int) error { + f.forced = append(f.forced, v) + return f.forceErr +} + +func TestCurrentMigrationVersion_StartFromZero(t *testing.T) { + setter := &fakeVersionSetter{err: migrate.ErrNilVersion} + version, dirty, needsRecordTarget, err := currentMigrationVersion(setter, 5, true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dirty { + t.Fatalf("expected dirty=false") + } + if needsRecordTarget { + t.Fatalf("expected needsRecordTarget=false") + } + if version != 0 { + t.Fatalf("expected version 0, got %d", version) + } + if len(setter.forced) != 0 { + t.Fatalf("expected no Force calls, got %v", setter.forced) + } +} + +func TestCurrentMigrationVersion_NeedsRecordTarget(t *testing.T) { + setter := &fakeVersionSetter{err: migrate.ErrNilVersion} + version, dirty, needsRecordTarget, err := currentMigrationVersion(setter, 7, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dirty { + t.Fatalf("expected dirty=false") + } + if !needsRecordTarget { + t.Fatalf("expected needsRecordTarget=true") + } + if version != 7 { + t.Fatalf("expected version 7, got %d", version) + } + if len(setter.forced) != 0 { + t.Fatalf("expected no Force calls, got %v", setter.forced) + } +} + +func TestCurrentMigrationVersion_PropagatesDirty(t *testing.T) { + setter := &fakeVersionSetter{version: 3, dirty: true} + version, dirty, needsRecordTarget, err := currentMigrationVersion(setter, 7, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !dirty { + t.Fatalf("expected dirty=true") + } + if needsRecordTarget { + t.Fatalf("expected needsRecordTarget=false") + } + if version != 3 { + t.Fatalf("expected version 3, got %d", version) + } +} + +func TestCurrentMigrationVersion_PropagatesError(t *testing.T) { + boom := errors.New("boom") + setter := &fakeVersionSetter{err: boom} + version, dirty, needsRecordTarget, err := currentMigrationVersion(setter, 7, false) + _ = version + _ = dirty + _ = needsRecordTarget + if !errors.Is(err, boom) { + t.Fatalf("expected error %v, got %v", boom, err) + } +} diff --git a/pkg/db/options.go b/pkg/db/options.go index 948d8a7..cdf6af9 100644 --- a/pkg/db/options.go +++ b/pkg/db/options.go @@ -18,6 +18,7 @@ import ( const SSLVerifyFull = "verify-full" type MigrationFunc func(do *gorm.DB) error +type VersionedMigrationFunc func(do *gorm.DB, version uint) error type DriverFunc func(connectString string, opts *ConnectionOptions) (*gorm.DB, error) func NewConnection(opts ...ConnectionOption) *ConnectionOptions { @@ -50,6 +51,7 @@ type ConnectionOptions struct { // The cert is provided by Jenkins on build under default path "/root.ca.pem" SSLRootCertPath string MigrationFunc MigrationFunc + VersionedMigrationFunc VersionedMigrationFunc DriverFunc DriverFunc EnableInstrumentation bool LoggerConfig logger.Config @@ -147,6 +149,12 @@ func WithMigrationFunc(fn MigrationFunc) ConnectionOption { } } +func WithVersionedMigrationFunc(fn VersionedMigrationFunc) ConnectionOption { + return func(c *ConnectionOptions) { + c.VersionedMigrationFunc = fn + } +} + func WithMigrationVersion(version uint) ConnectionOption { return func(c *ConnectionOptions) { c.MigrationVersion = version diff --git a/pkg/db/options_test.go b/pkg/db/options_test.go new file mode 100644 index 0000000..adbd671 --- /dev/null +++ b/pkg/db/options_test.go @@ -0,0 +1,23 @@ +package db + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestWithMigrationFuncDoesNotSetVersionedMigrationFunc(t *testing.T) { + fn := func(_ *gorm.DB) error { return nil } + + opts := NewConnection( + WithMigrationVersion(2), + WithMigrationFunc(fn), + ) + + if opts.MigrationFunc == nil { + t.Fatalf("expected MigrationFunc to be set") + } + if opts.VersionedMigrationFunc != nil { + t.Fatalf("expected VersionedMigrationFunc to be nil when only WithMigrationFunc is used") + } +} diff --git a/pkg/db/testing.go b/pkg/db/testing.go index 128dd57..4f3cb97 100644 --- a/pkg/db/testing.go +++ b/pkg/db/testing.go @@ -38,8 +38,8 @@ func InitializeTestPostgres(opts *ConnectionOptions) { logging.LogErrorf(err, "error connecting to testing postgres") db = nil } - if opts.MigrationFunc != nil { - if err = runMigration(conn, opts.MigrationFunc, opts.MigrationVersion, true); err != nil { + if opts.MigrationFunc != nil || opts.VersionedMigrationFunc != nil || opts.MigrationVersion > 0 { + if err = runMigration(conn, opts.MigrationFunc, opts.VersionedMigrationFunc, opts.MigrationVersion, true); err != nil { logging.LogErrorf(err, "test DB migration error") } } diff --git a/pkg/log/audit.go b/pkg/log/audit.go index fbe6fab..14e6275 100644 --- a/pkg/log/audit.go +++ b/pkg/log/audit.go @@ -68,6 +68,7 @@ func (l *Logger) createBaseAuditLog(ctx context.Context, logType AuditLogType) b // The expected context keys are "trace-id" and "user-id". // This is the log type to use when a message should be accompanied // with an object relevant for auditing, e.g., new set of permissions. +// // Deprecated: use AuditSecurity instead. func (l *Logger) Audit( ctx context.Context, diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go index 4cac7f7..4a1758f 100644 --- a/pkg/logging/logger.go +++ b/pkg/logging/logger.go @@ -174,7 +174,13 @@ func LogWarningfCtx(ctx context.Context, err error, format string, fields ...int // LogAudit logs a generic audit event containing of a message along with an object pertaining to the message. func LogAudit(ctx context.Context, message string, object any) { - if err := Logger().Audit(ctx, message, object); err != nil { + if err := Logger().AuditSecurity( + ctx, + "audit", + true, + golog.Message(message), + golog.AdditionalData(object), + ); err != nil { fmt.Printf("Logging error (LogAudit): %s\n", err.Error()) } } diff --git a/pkg/migrate/README.md b/pkg/migrate/README.md index 1ebbf49..eab13b6 100644 --- a/pkg/migrate/README.md +++ b/pkg/migrate/README.md @@ -1,17 +1,23 @@ # go-pg-migrate -Library for migrating the Postgres Database of PHDP services. It uses [golang-migrate V4](https://github.com/golang-migrate/migrate) for the migration. +Library for migrating Postgres databases in services using `pkg/db`. It is built +on top of [golang-migrate V4](https://github.com/golang-migrate/migrate). ## Setup Script -`go-pg-migrate` allows to run a setup script before the migration steps that will be handled by `golang-migrate`. -The script is optional and must be called `setup.sql` and be placed in the same folder as the other sql scripts. -The main use case for the setup script is creating an schema that will be used by `golang-migrate` for the migration table itself. -The setup script must be idempotent, as it will be run for every migration (unlike the migration steps that are skipped if the version is already present). +`go-pg-migrate` allows running a setup script before the migration steps handled +by `golang-migrate`. The script is optional and must be called `setup.sql` and +placed in the same folder as the other SQL scripts. -## Postgres foreign-data wrapper +The main use case is creating a schema that will be used by `golang-migrate` for +the migration table itself. The setup script must be idempotent, as it will be +run for every migration invocation (unlike migration steps which are skipped if +the version is already present). -`go-pg-migrate` allows to run additional scripts before and after the migration which are golang templated by a ForeignDatabase struct and the following fields: +## Postgres Foreign Data Wrapper (FDW) + +`go-pg-migrate` can run additional scripts before and after the migration which +are templated by a `ForeignDatabase` struct: - LocalUser string - DBName string @@ -20,10 +26,39 @@ The setup script must be idempotent, as it will be run for every migration (unli - User string - Password string -The scripts are optional and must be called `fdw.up.sql` and `fdw.down.sql` and be placed in the same folder as the other sql scripts. The placeholders can be used like this well-known notation within the scripts: `{{.LocalUser}}`. -The main use case for the scripts is to prepare the database for some foreign data migration like described in [Postgres FDW](https://www.postgresql.org/docs/12/postgres-fdw.html). +The scripts are optional and must be called `fdw.up.sql` and `fdw.down.sql`, and +placed in the same folder as the other SQL scripts. Placeholders can be used via +`{{.LocalUser}}` syntax. The main use case is preparing the database for foreign +data migration (see Postgres FDW docs). ## Migration Table -`golang-migrate` needs a table that will contain the migration metadata (current version and the dirty status). This table will be created by the library with the given table name. -However, the schema where the table is created is not configurable for postgres as of version 4 of `golang-migrate`. Instead, the `golang-migrate` library will create the table with the unqualified name, which will have the effect of creating the table in the current schema. Therefore, if the table is intended to be created in a particular schema, that schema needs to be set as the current schema (first element in the search path). +`golang-migrate` uses a table that contains migration metadata (current version +and dirty status). The table is created with the given name. For Postgres, the +schema is not configurable as of v4, so the table is created in the current +schema (first element in the search path). If the table must live in a specific +schema, that schema must be in the search path. + +## Migration Flows + +`pkg/db` supports two flows: legacy and versioned. + +### Legacy Flow + +- Single AutoMigrate (latest models). +- SQL migrations executed via `golang-migrate`: + - `{version}_{name}.up.sql` / `{version}_{name}.down.sql` + +### Versioned Flow + +- Interleaves per migration version: + 1. Versioned before script (optional) + 2. AutoMigrate(version) (service implementation) + 3. Versioned after script (optional) + 4. Record version after the full sequence completes + +Missing before/after scripts are skipped. + +**Supported naming (versioned flow only):** +- Before: `{version}_{name}.before.sql` or `{version}_{name}.before.up.sql` +- After: `{version}_{name}.after.sql` or `{version}_{name}.after.up.sql` diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go index 90fc64e..8838b59 100644 --- a/pkg/migrate/migrate.go +++ b/pkg/migrate/migrate.go @@ -3,8 +3,11 @@ package migrate import ( "context" "database/sql" + stderrors "errors" "fmt" "os" + "path/filepath" + "strconv" "strings" "text/template" @@ -21,8 +24,16 @@ const ( setupScriptName = "setup.sql" fdwUpScriptName = "fdw.up.sql" fdwDownScriptName = "fdw.down.sql" + beforeUpSuffix = ".before.up.sql" + beforeDownSuffix = ".before.down.sql" + afterUpSuffix = ".after.up.sql" + beforeSuffix = ".before.sql" + afterSuffix = ".after.sql" ) +// ErrNilVersion is returned when no migration version is set in the database. +var ErrNilVersion = migrate.ErrNilVersion + // Migration is the struct that holds the information needed for migrating a database. type Migration struct { db *sql.DB @@ -83,11 +94,48 @@ func (m *Migration) MigrateDB(ctx context.Context, migrationVersion uint, startF return errors.Wrap(err, "could not run the fdw.up script") } + if err := m.MigrateToVersion(ctx, migrationVersion, startFromZero); err != nil { + return err + } + + if err := m.execute(ctx, fdwDownScriptName, m.foreignDatabase); err != nil { // execute fdw.down + return errors.Wrap(err, "could not run the fdw.down script") + } + + return nil +} + +// ExecuteSetup runs the setup.sql script if present. +func (m *Migration) ExecuteSetup(ctx context.Context) error { + if err := m.execute(ctx, setupScriptName, nil); err != nil { + return errors.Wrap(err, "could not run the setup script") + } + return nil +} + +// ExecuteFdwUp runs the fdw.up.sql script if present. +func (m *Migration) ExecuteFdwUp(ctx context.Context) error { + if err := m.execute(ctx, fdwUpScriptName, m.foreignDatabase); err != nil { + return errors.Wrap(err, "could not run the fdw.up script") + } + return nil +} + +// ExecuteFdwDown runs the fdw.down.sql script if present. +func (m *Migration) ExecuteFdwDown(ctx context.Context) error { + if err := m.execute(ctx, fdwDownScriptName, m.foreignDatabase); err != nil { + return errors.Wrap(err, "could not run the fdw.down script") + } + return nil +} + +// MigrateInstance creates a golang-migrate instance for the current source folder. +func (m *Migration) MigrateInstance() (*migrate.Migrate, error) { driver, err := postgres.WithInstance(m.db, &postgres.Config{ MigrationsTable: m.migrationTable, }) if err != nil { - return errors.Wrap(err, "error creating database driver") + return nil, errors.Wrap(err, "error creating database driver") } mpg, err := migrate.NewWithDatabaseInstance( @@ -96,11 +144,63 @@ func (m *Migration) MigrateDB(ctx context.Context, migrationVersion uint, startF driver, ) if err != nil { - return errors.Wrap(err, "error creating migrate instance") + return nil, errors.Wrap(err, "error creating migrate instance") + } + return mpg, nil +} + +// MigrateInstanceForVersionTracking creates a migrate instance that excludes before/after scripts. +func (m *Migration) MigrateInstanceForVersionTracking() (*migrate.Migrate, func(), error) { + sourceFolder, cleanup, err := CreateVersionSourceFolder(m.sourceFolder) + if err != nil { + return nil, nil, err + } + driver, err := postgres.WithInstance(m.db, &postgres.Config{ + MigrationsTable: m.migrationTable, + }) + if err != nil { + if cleanup != nil { + cleanup() + } + return nil, nil, errors.Wrap(err, "error creating database driver") + } + mpg, err := migrate.NewWithDatabaseInstance( + "file://"+sourceFolder, + "postgres", + driver, + ) + if err != nil { + if cleanup != nil { + cleanup() + } + return nil, nil, errors.Wrap(err, "error creating migrate instance") + } + return mpg, cleanup, nil +} + +// CurrentVersion returns the current migration version. +func (m *Migration) CurrentVersion() (uint, bool, error) { + mpg, err := m.MigrateInstance() + if err != nil { + return 0, false, err + } + + version, dirty, err := mpg.Version() + if err != nil { + return 0, false, err + } + return version, dirty, nil +} + +// MigrateToVersion runs golang-migrate without setup/fdw scripts. +func (m *Migration) MigrateToVersion(ctx context.Context, migrationVersion uint, startFromZero bool) error { + mpg, err := m.MigrateInstance() + if err != nil { + return err } _, _, err = mpg.Version() - if err == migrate.ErrNilVersion && !startFromZero { + if stderrors.Is(err, migrate.ErrNilVersion) && !startFromZero { // no migration information in the database, so it's a fresh database // and the data model is already the latest one set up Gorm automigrations // nolint: gosec @@ -123,13 +223,24 @@ func (m *Migration) MigrateDB(ctx context.Context, migrationVersion uint, startF return errors.Wrap(err, fmt.Sprintf("error migrating database to v%d", migrationVersion)) } - if err := m.execute(ctx, fdwDownScriptName, m.foreignDatabase); err != nil { // execute fdw.down - return errors.Wrap(err, "could not run the fdw.down script") - } + return nil +} +// SetVersion records the current version in the migrations table using an existing migrate instance. +func SetVersion(mpg VersionSetter, migrationVersion uint) error { + // nolint: gosec + if err := mpg.Force(int(migrationVersion)); err != nil { + return errors.Wrap(err, "error setting migration version") + } return nil } +// VersionSetter abstracts a migrate instance that can report and set version. +type VersionSetter interface { + Version() (uint, bool, error) + Force(int) error +} + func (m *Migration) parseFile(ctx context.Context, filename string, templateData interface{}) (string, error) { path := m.sourceFolder + "/" + filename @@ -181,6 +292,162 @@ func (m *Migration) execute(ctx context.Context, filename string, templateData i return err } +// ExecuteBeforeUp runs a versioned before migration if present. +func (m *Migration) ExecuteBeforeUp(ctx context.Context, migrationVersion uint) (bool, error) { + filename, err := findBeforeUpFile(m.sourceFolder, migrationVersion) + if err != nil { + return false, errors.Wrap(err, "could not scan for before migration") + } + if filename == "" { + _ = m.log.InfoGeneric(ctx, fmt.Sprintf("no before migration found for version %d - skipped", migrationVersion)) + return false, nil + } + if err := m.execute(ctx, filename, nil); err != nil { + return false, errors.Wrap(err, fmt.Sprintf("could not run before migration %q", filename)) + } + return true, nil +} + +// ExecuteAfterUp runs a versioned after migration if present. +func (m *Migration) ExecuteAfterUp(ctx context.Context, migrationVersion uint) (bool, error) { + filename, err := findAfterUpFile(m.sourceFolder, migrationVersion) + if err != nil { + return false, errors.Wrap(err, "could not scan for after migration") + } + if filename == "" { + _ = m.log.InfoGeneric(ctx, fmt.Sprintf("no after migration found for version %d - skipped", migrationVersion)) + return false, nil + } + if err := m.execute(ctx, filename, nil); err != nil { + return false, errors.Wrap(err, fmt.Sprintf("could not run after migration %q", filename)) + } + return true, nil +} + +// CreateAfterSourceFolder returns a temp folder containing only non-before migrations. +func CreateAfterSourceFolder(sourceFolder string) (string, func(), error) { + entries, err := os.ReadDir(sourceFolder) + if err != nil { + return "", nil, errors.Wrap(err, "could not read migrations folder") + } + tempDir, err := os.MkdirTemp("", "migrate-after-*") + if err != nil { + return "", nil, errors.Wrap(err, "could not create temp folder") + } + cleanup := func() { + _ = os.RemoveAll(tempDir) + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if isBeforeMigrationFile(name) { + continue + } + if err := copyFile(filepath.Join(sourceFolder, name), filepath.Join(tempDir, name)); err != nil { + cleanup() + return "", nil, err + } + } + return tempDir, cleanup, nil +} + +// CreateVersionSourceFolder returns a temp folder with only tracked migration files. +// It excludes before/after scripts to avoid duplicate version conflicts. +func CreateVersionSourceFolder(sourceFolder string) (string, func(), error) { + entries, err := os.ReadDir(sourceFolder) + if err != nil { + return "", nil, errors.Wrap(err, "could not read migrations folder") + } + tempDir, err := os.MkdirTemp("", "migrate-version-*") + if err != nil { + return "", nil, errors.Wrap(err, "could not create temp folder") + } + cleanup := func() { + _ = os.RemoveAll(tempDir) + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if isBeforeMigrationFile(name) || isAfterMigrationFile(name) { + continue + } + if !strings.HasSuffix(name, ".up.sql") && !strings.HasSuffix(name, ".down.sql") { + continue + } + if err := copyFile(filepath.Join(sourceFolder, name), filepath.Join(tempDir, name)); err != nil { + cleanup() + return "", nil, err + } + } + return tempDir, cleanup, nil +} + +// CreateAfterSourceFolderForVersion returns a temp folder with the after migration for a single version. +func CreateAfterSourceFolderForVersion(sourceFolder string, migrationVersion uint) (string, func(), error) { + entries, err := os.ReadDir(sourceFolder) + if err != nil { + return "", nil, errors.Wrap(err, "could not read migrations folder") + } + tempDir, err := os.MkdirTemp("", "migrate-after-*") + if err != nil { + return "", nil, errors.Wrap(err, "could not create temp folder") + } + cleanup := func() { + _ = os.RemoveAll(tempDir) + } + copied, err := copyAfterMigrationForVersion(entries, sourceFolder, tempDir, migrationVersion) + if err != nil { + cleanup() + return "", nil, err + } + if !copied { + noopName := fmt.Sprintf("%d_noop.up.sql", migrationVersion) + if err := os.WriteFile(filepath.Join(tempDir, noopName), []byte("SELECT 1;"), 0o600); err != nil { + cleanup() + return "", nil, errors.Wrap(err, fmt.Sprintf("could not write %q", noopName)) + } + } + return tempDir, cleanup, nil +} + +func copyAfterMigrationForVersion(entries []os.DirEntry, sourceFolder, tempDir string, migrationVersion uint) (bool, error) { + copied := false + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + targetName, ok := afterMigrationTargetName(name) + if !ok { + continue + } + version, ok := parseMigrationVersion(name) + if !ok || version != migrationVersion { + continue + } + if err := copyFile(filepath.Join(sourceFolder, name), filepath.Join(tempDir, targetName)); err != nil { + return false, err + } + copied = true + } + return copied, nil +} + +func afterMigrationTargetName(filename string) (string, bool) { + switch { + case strings.HasSuffix(filename, afterUpSuffix): + return strings.TrimSuffix(filename, afterUpSuffix) + ".up.sql", true + case strings.HasSuffix(filename, afterSuffix): + return strings.TrimSuffix(filename, afterSuffix) + ".up.sql", true + default: + return "", false + } +} + func fileExists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { @@ -193,3 +460,130 @@ func fileExists(path string) (bool, error) { // file may exists but os.Stat fails for other reasons (eg. permission, failing disk) return false, err } + +func isBeforeMigrationFile(filename string) bool { + return strings.HasSuffix(filename, beforeUpSuffix) || + strings.HasSuffix(filename, beforeDownSuffix) || + strings.HasSuffix(filename, beforeSuffix) +} + +func isAfterMigrationFile(filename string) bool { + return strings.HasSuffix(filename, afterUpSuffix) || strings.HasSuffix(filename, afterSuffix) +} + +func findBeforeUpFile(sourceFolder string, migrationVersion uint) (string, error) { + return findHookFile(sourceFolder, migrationVersion, beforeUpSuffix, beforeSuffix, "before") +} + +func findAfterUpFile(sourceFolder string, migrationVersion uint) (string, error) { + return findHookFile(sourceFolder, migrationVersion, afterUpSuffix, afterSuffix, "after") +} + +type hookFileKind uint8 + +const ( + hookFileNone hookFileKind = iota + hookFileUp + hookFilePlain +) + +func findHookFile(sourceFolder string, migrationVersion uint, suffixUp, suffixPlain, hookName string) (string, error) { + entries, err := os.ReadDir(sourceFolder) + if err != nil { + return "", err + } + foundUp, foundPlain, err := scanHookFiles(entries, migrationVersion, suffixUp, suffixPlain, hookName) + if err != nil { + return "", err + } + if foundUp != "" && foundPlain != "" { + return "", fmt.Errorf("conflicting %s migrations found for version %d: %q and %q", hookName, migrationVersion, foundUp, foundPlain) + } + if foundUp != "" { + return foundUp, nil + } + return foundPlain, nil +} + +func scanHookFiles(entries []os.DirEntry, migrationVersion uint, suffixUp, suffixPlain, hookName string) (string, string, error) { + var foundUp string + var foundPlain string + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + kind, ok := classifyHookFile(name, migrationVersion, suffixUp, suffixPlain) + if !ok { + continue + } + switch kind { + case hookFileUp: + if foundUp != "" && foundUp != name { + return "", "", fmt.Errorf( + "multiple %s migrations found for version %d: %q and %q", + hookName, + migrationVersion, + foundUp, + name, + ) + } + foundUp = name + case hookFilePlain: + if foundPlain != "" && foundPlain != name { + return "", "", fmt.Errorf( + "multiple %s migrations found for version %d: %q and %q", + hookName, + migrationVersion, + foundPlain, + name, + ) + } + foundPlain = name + default: + continue + } + } + return foundUp, foundPlain, nil +} + +func classifyHookFile(filename string, migrationVersion uint, suffixUp, suffixPlain string) (hookFileKind, bool) { + switch { + case strings.HasSuffix(filename, suffixUp): + version, ok := parseMigrationVersion(filename) + return hookFileUp, ok && version == migrationVersion + case strings.HasSuffix(filename, suffixPlain): + version, ok := parseMigrationVersion(filename) + return hookFilePlain, ok && version == migrationVersion + default: + return hookFileNone, false + } +} + +func parseMigrationVersion(filename string) (uint, bool) { + base := filepath.Base(filename) + sep := strings.Index(base, "_") + if sep <= 0 { + return 0, false + } + versionStr := base[:sep] + if len(versionStr) == 0 { + return 0, false + } + parsed, err := strconv.ParseUint(versionStr, 10, 32) + if err != nil { + return 0, false + } + return uint(parsed), true +} + +func copyFile(src, dest string) error { + data, err := os.ReadFile(src) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("could not read %q", src)) + } + if err := os.WriteFile(dest, data, 0o600); err != nil { + return errors.Wrap(err, fmt.Sprintf("could not write %q", dest)) + } + return nil +} diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index ca7071b..a774b46 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -3,6 +3,8 @@ package migrate import ( "context" "log" + "os" + "path/filepath" "testing" _ "github.com/golang-migrate/migrate/v4/source/file" @@ -102,6 +104,203 @@ func TestMigration_parseFile(t *testing.T) { } } +func TestFindBeforeUpFile(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "001_init.up.sql") + writeMigrationFile(t, dir, "002_add.before.up.sql") + writeMigrationFile(t, dir, "002_add.before.down.sql") + writeMigrationFile(t, dir, "003_other.before.up.sql") + writeMigrationFile(t, dir, "004_new.before.sql") + + found, err := findBeforeUpFile(dir, 2) + if err != nil { + t.Fatalf("findBeforeUpFile() error = %v", err) + } + if found != "002_add.before.up.sql" { + t.Fatalf("findBeforeUpFile() got = %q, want %q", found, "002_add.before.up.sql") + } + + found, err = findBeforeUpFile(dir, 5) + if err != nil { + t.Fatalf("findBeforeUpFile() error = %v", err) + } + if found != "" { + t.Fatalf("findBeforeUpFile() got = %q, want empty", found) + } + + found, err = findBeforeUpFile(dir, 4) + if err != nil { + t.Fatalf("findBeforeUpFile() error = %v", err) + } + if found != "004_new.before.sql" { + t.Fatalf("findBeforeUpFile() got = %q, want %q", found, "004_new.before.sql") + } +} + +func TestFindBeforeUpFile_Conflicts(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "004_new.before.sql") + writeMigrationFile(t, dir, "004_new.before.up.sql") + + _, err := findBeforeUpFile(dir, 4) + if err == nil { + t.Fatalf("expected conflict error") + } +} + +func TestFindAfterUpFile(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "001_init.after.up.sql") + writeMigrationFile(t, dir, "002_add.after.up.sql") + writeMigrationFile(t, dir, "003_other.after.up.sql") + writeMigrationFile(t, dir, "004_new.after.sql") + + found, err := findAfterUpFile(dir, 2) + if err != nil { + t.Fatalf("findAfterUpFile() error = %v", err) + } + if found != "002_add.after.up.sql" { + t.Fatalf("findAfterUpFile() got = %q, want %q", found, "002_add.after.up.sql") + } + + found, err = findAfterUpFile(dir, 5) + if err != nil { + t.Fatalf("findAfterUpFile() error = %v", err) + } + if found != "" { + t.Fatalf("findAfterUpFile() got = %q, want empty", found) + } + + found, err = findAfterUpFile(dir, 4) + if err != nil { + t.Fatalf("findAfterUpFile() error = %v", err) + } + if found != "004_new.after.sql" { + t.Fatalf("findAfterUpFile() got = %q, want %q", found, "004_new.after.sql") + } +} + +func TestFindAfterUpFile_Conflicts(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "004_new.after.sql") + writeMigrationFile(t, dir, "004_new.after.up.sql") + + _, err := findAfterUpFile(dir, 4) + if err == nil { + t.Fatalf("expected conflict error") + } +} + +func TestCreateAfterSourceFolder_excludesBefore(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "001_init.up.sql") + writeMigrationFile(t, dir, "002_add.before.up.sql") + writeMigrationFile(t, dir, "002_add.before.down.sql") + writeMigrationFile(t, dir, "setup.sql") + + afterDir, cleanup, err := CreateAfterSourceFolder(dir) + if err != nil { + t.Fatalf("CreateAfterSourceFolder() error = %v", err) + } + if cleanup != nil { + defer cleanup() + } + + entries, err := os.ReadDir(afterDir) + if err != nil { + t.Fatalf("ReadDir() error = %v", err) + } + found := map[string]bool{} + for _, entry := range entries { + found[entry.Name()] = true + } + if found["002_add.before.up.sql"] || found["002_add.before.down.sql"] { + t.Fatalf("CreateAfterSourceFolder() should exclude before files") + } + if !found["001_init.up.sql"] || !found["setup.sql"] { + t.Fatalf("CreateAfterSourceFolder() should keep non-before files") + } +} + +func TestCreateAfterSourceFolderForVersion(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "001_init.after.up.sql") + writeMigrationFile(t, dir, "002_add.after.up.sql") + writeMigrationFile(t, dir, "003_other.after.up.sql") + + afterDir, cleanup, err := CreateAfterSourceFolderForVersion(dir, 2) + if err != nil { + t.Fatalf("CreateAfterSourceFolderForVersion() error = %v", err) + } + if cleanup != nil { + defer cleanup() + } + if afterDir == "" { + t.Fatalf("CreateAfterSourceFolderForVersion() returned empty dir") + } + + entries, err := os.ReadDir(afterDir) + if err != nil { + t.Fatalf("ReadDir() error = %v", err) + } + found := map[string]bool{} + for _, entry := range entries { + found[entry.Name()] = true + } + if !found["002_add.up.sql"] { + t.Fatalf("CreateAfterSourceFolderForVersion() should include renamed after file") + } + if found["001_init.up.sql"] || found["003_other.up.sql"] { + t.Fatalf("CreateAfterSourceFolderForVersion() should only include the target version") + } + + afterDir, cleanup, err = CreateAfterSourceFolderForVersion(dir, 4) + if err != nil { + t.Fatalf("CreateAfterSourceFolderForVersion() error = %v", err) + } + if cleanup != nil { + defer cleanup() + } + if afterDir == "" { + t.Fatalf("CreateAfterSourceFolderForVersion() got empty dir") + } + entries, err = os.ReadDir(afterDir) + if err != nil { + t.Fatalf("ReadDir() error = %v", err) + } + found = map[string]bool{} + for _, entry := range entries { + found[entry.Name()] = true + } + if !found["4_noop.up.sql"] { + t.Fatalf("CreateAfterSourceFolderForVersion() should create noop migration") + } +} + +func TestCreateAfterSourceFolderForVersion_afterDotSql(t *testing.T) { + dir := t.TempDir() + writeMigrationFile(t, dir, "004_new.after.sql") + + afterDir, cleanup, err := CreateAfterSourceFolderForVersion(dir, 4) + if err != nil { + t.Fatalf("CreateAfterSourceFolderForVersion() error = %v", err) + } + if cleanup != nil { + defer cleanup() + } + entries, err := os.ReadDir(afterDir) + if err != nil { + t.Fatalf("ReadDir() error = %v", err) + } + found := map[string]bool{} + for _, entry := range entries { + found[entry.Name()] = true + } + if !found["004_new.up.sql"] { + t.Fatalf("CreateAfterSourceFolderForVersion() should include renamed after.sql file") + } +} + var wantParsedFdwUp = `BEGIN; CREATE SERVER IF NOT EXISTS keymgmt_server FOREIGN DATA WRAPPER postgres_fdw OPTIONS (host 'myHostname', dbname 'myDBName', port '42'); @@ -116,3 +315,11 @@ var wantParsedSetup = `CREATE TABLE IF NOT EXISTS test_setup.testtable ( PRIMARY KEY (id) ); ` + +func writeMigrationFile(t *testing.T, dir, name string) { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("SELECT 1;"), 0o600); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +} diff --git a/test/migration_versioned_test.go b/test/migration_versioned_test.go new file mode 100644 index 0000000..4888e20 --- /dev/null +++ b/test/migration_versioned_test.go @@ -0,0 +1,154 @@ +package test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/d4l-data4life/go-svc/pkg/db" + "github.com/d4l-data4life/go-svc/pkg/migrate" + "gorm.io/gorm" +) + +func TestVersionedMigrationFlow(t *testing.T) { + cfg, err := parseEnv() + if err != nil { + t.Fatal(err) + } + + sqlDB, err := connectToDB(cfg) + if err != nil { + t.Fatal(err) + } + defer sqlDB.Close() + + ctx := context.Background() + _ = cleanTable(ctx, sqlDB, "migration_steps") + _ = cleanTable(ctx, sqlDB, "migrations") + defer func() { + _ = cleanTable(ctx, sqlDB, "migration_steps") + _ = cleanTable(ctx, sqlDB, "migrations") + }() + + tmpDir := t.TempDir() + sqlDir := filepath.Join(tmpDir, "sql") + if err := os.MkdirAll(sqlDir, 0o755); err != nil { + t.Fatal(err) + } + + writeSQL(t, sqlDir, "001_init.before.sql", ` +CREATE TABLE IF NOT EXISTS migration_steps ( + seq SERIAL PRIMARY KEY, + step TEXT NOT NULL +); +INSERT INTO migration_steps (step) VALUES ('before-1'); +`) + writeSQL(t, sqlDir, "001_init.after.sql", ` +INSERT INTO migration_steps (step) VALUES ('after-1'); +`) + writeSQL(t, sqlDir, "002_add.before.sql", ` +INSERT INTO migration_steps (step) VALUES ('before-2'); +`) + writeSQL(t, sqlDir, "002_add.after.sql", ` +INSERT INTO migration_steps (step) VALUES ('after-2'); +`) + writeSQL(t, sqlDir, "003_more.before.up.sql", ` +INSERT INTO migration_steps (step) VALUES ('before-3'); +`) + writeSQL(t, sqlDir, "003_more.after.up.sql", ` +INSERT INTO migration_steps (step) VALUES ('after-3'); +`) + + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + if err := os.Chdir(tmpDir); err != nil { + t.Fatal(err) + } + defer func() { + _ = os.Chdir(cwd) + }() + + migFn := func(conn *gorm.DB, version uint) error { + return conn.Exec(fmt.Sprintf("INSERT INTO migration_steps (step) VALUES ('auto-%d')", version)).Error + } + + opts := db.NewConnection( + db.WithHost(cfg.PGHost), + db.WithPort(strconv.FormatUint(uint64(cfg.PGPort), 10)), + db.WithDatabaseName(cfg.PGName), + db.WithUser(cfg.PGUser), + db.WithPassword(cfg.PGPassword), + db.WithSSLMode("disable"), + db.WithMigrationStartFromZero(true), + db.WithMigrationVersion(4), + db.WithVersionedMigrationFunc(migFn), + ) + + db.InitializeTestPostgres(opts) + conn := db.Get() + if conn == nil { + t.Fatal("db handle is nil") + } + + type row struct { + Seq int + Step string + } + rows := []row{} + if err := conn.Raw("SELECT seq, step FROM migration_steps ORDER BY seq").Scan(&rows).Error; err != nil { + t.Fatalf("query steps: %v", err) + } + + want := []string{ + "before-1", + "auto-1", + "after-1", + "before-2", + "auto-2", + "after-2", + "before-3", + "auto-3", + "after-3", + "auto-4", + } + if len(rows) != len(want) { + t.Fatalf("got %d steps, want %d", len(rows), len(want)) + } + for i, w := range want { + if rows[i].Step != w { + t.Fatalf("step %d: got %q, want %q", i, rows[i].Step, w) + } + } + + migration := migrate.NewMigration(sqlDB, sqlDir, "migrations", &testLog{}) + mpg, cleanup, err := migration.MigrateInstanceForVersionTracking() + if err != nil { + t.Fatalf("current version: %v", err) + } + if cleanup != nil { + defer cleanup() + } + version, dirty, err := mpg.Version() + if err != nil { + t.Fatalf("current version: %v", err) + } + if dirty { + t.Fatalf("expected clean migrations table") + } + if version != 4 { + t.Fatalf("expected version 4, got %d", version) + } +} + +func writeSQL(t *testing.T, dir, name, content string) { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +}