diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index b84485e1d..4f1a5af97 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -2,18 +2,17 @@ package accounts import ( "context" - "database/sql" "fmt" "testing" "time" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" "pgregory.net/rapid" @@ -36,7 +35,7 @@ func TestAccountStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { testDBStore := NewTestDB(t, clock) @@ -45,13 +44,9 @@ func TestAccountStoreMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLQueriesExecutor(baseDB, queries) } assertMigrationResults := func(t *testing.T, sqlStore *SQLStore, @@ -337,13 +332,17 @@ func TestAccountStoreMigration(t *testing.T) { } // Perform the migration. + // + // TODO(viktor): remove sqldb.MigrationTxOptions once + // sqldb v2 is based on the latest version of lnd/sqldb. + var opts sqldb.MigrationTxOptions err = txEx.ExecTx( - ctx, sqldb.WriteTxOpt(), + ctx, &opts, func(tx SQLQueries) error { return MigrateAccountStoreToSQL( ctx, kvStore.db, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 830f16587..c7e8ab070 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -33,6 +34,8 @@ const ( // //nolint:lll type SQLQueries interface { + sqldb.BaseQuerier + AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error @@ -53,12 +56,13 @@ type SQLQueries interface { GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) } -// BatchedSQLQueries is a version of the SQLQueries that's capable -// of batched database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -68,19 +72,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -157,7 +179,7 @@ func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -299,7 +321,7 @@ func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, acctID) - }) + }, sqldb.NoOpReset) } func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) ( @@ -377,7 +399,7 @@ func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // CreditAccount increases the balance of the account with the given alias by @@ -412,7 +434,7 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DebitAccount decreases the balance of the account with the given alias by the @@ -453,7 +475,7 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // Account retrieves an account from the SQL store and un-marshals it. If the @@ -475,7 +497,7 @@ func (s *SQLStore) Account(ctx context.Context, alias AccountID) ( account, err = getAndMarshalAccount(ctx, db, id) return err - }) + }, sqldb.NoOpReset) return account, err } @@ -507,7 +529,7 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount, } return nil - }) + }, sqldb.NoOpReset) return accounts, err } @@ -524,7 +546,7 @@ func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error { } return db.DeleteAccount(ctx, id) - }) + }, sqldb.NoOpReset) } // UpsertAccountPayment updates or inserts a payment entry for the given @@ -634,7 +656,7 @@ func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DeleteAccountPayment removes a payment entry from the account with the given @@ -677,7 +699,7 @@ func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // LastIndexes returns the last invoice add and settle index or @@ -704,7 +726,7 @@ func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) { } return err - }) + }, sqldb.NoOpReset) return uint64(addIndex), uint64(settleIndex), err } @@ -729,7 +751,7 @@ func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex, Name: settleIndexName, Value: int64(settleIndex), }) - }) + }, sqldb.NoOpReset) } // Close closes the underlying store. diff --git a/accounts/test_sql.go b/accounts/test_sql.go index 3c1ee7f16..ca2f43d6f 100644 --- a/accounts/test_sql.go +++ b/accounts/test_sql.go @@ -5,15 +5,20 @@ package accounts import ( "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index 9d899b3e2..a31f990a6 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/config_dev.go b/config_dev.go index 90b8b290f..cc8a5b319 100644 --- a/config_dev.go +++ b/config_dev.go @@ -8,9 +8,11 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -101,14 +103,36 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { return stores, err } - sqlStore, err := db.NewSqliteStore(cfg.Sqlite) + sqlStore, err := sqldb.NewSqliteStore(&sqldb.SqliteConfig{ + SkipMigrations: cfg.Sqlite.SkipMigrations, + SkipMigrationDbBackup: cfg.Sqlite.SkipMigrationDbBackup, + }, cfg.Sqlite.DatabaseFileName) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Sqlite.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to SQLite store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore @@ -116,14 +140,41 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { stores.closeFns["sqlite"] = sqlStore.BaseDB.Close case DatabaseBackendPostgres: - sqlStore, err := db.NewPostgresStore(cfg.Postgres) + sqlStore, err := sqldb.NewPostgresStore(&sqldb.PostgresConfig{ + Dsn: cfg.Postgres.DSN(false), + MaxOpenConnections: cfg.Postgres.MaxOpenConnections, + MaxIdleConnections: cfg.Postgres.MaxIdleConnections, + ConnMaxLifetime: cfg.Postgres.ConnMaxLifetime, + ConnMaxIdleTime: cfg.Postgres.ConnMaxIdleTime, + RequireSSL: cfg.Postgres.RequireSSL, + SkipMigrations: cfg.Postgres.SkipMigrations, + }) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Postgres.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to Postgres store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore diff --git a/db/interfaces.go b/db/interfaces.go index ba64520b4..d8fd37ad2 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -1,310 +1,5 @@ package db -import ( - "context" - "database/sql" - "math" - prand "math/rand" - "time" - - "github.com/lightninglabs/lightning-terminal/db/sqlc" -) - -var ( - // DefaultStoreTimeout is the default timeout used for any interaction - // with the storage/database. - DefaultStoreTimeout = time.Second * 10 -) - -const ( - // DefaultNumTxRetries is the default number of times we'll retry a - // transaction if it fails with an error that permits transaction - // repetition. - DefaultNumTxRetries = 10 - - // DefaultInitialRetryDelay is the default initial delay between - // retries. This will be used to generate a random delay between -50% - // and +50% of this value, so 20 to 60 milliseconds. The retry will be - // doubled after each attempt until we reach DefaultMaxRetryDelay. We - // start with a random value to avoid multiple goroutines that are - // created at the same time to effectively retry at the same time. - DefaultInitialRetryDelay = time.Millisecond * 40 - - // DefaultMaxRetryDelay is the default maximum delay between retries. - DefaultMaxRetryDelay = time.Second * 3 -) - -// TxOptions represents a set of options one can use to control what type of -// database transaction is created. Transaction can wither be read or write. -type TxOptions interface { - // ReadOnly returns true if the transaction should be read only. - ReadOnly() bool -} - -// BatchedTx is a generic interface that represents the ability to execute -// several operations to a given storage interface in a single atomic -// transaction. Typically, Q here will be some subset of the main sqlc.Querier -// interface allowing it to only depend on the routines it needs to implement -// any additional business logic. -type BatchedTx[Q any] interface { - // ExecTx will execute the passed txBody, operating upon generic - // parameter Q (usually a storage interface) in a single transaction. - // The set of TxOptions are passed in in order to allow the caller to - // specify if a transaction should be read-only and optionally what - // type of concurrency control should be used. - ExecTx(ctx context.Context, txOptions TxOptions, - txBody func(Q) error) error - - // Backend returns the type of the database backend used. - Backend() sqlc.BackendType -} - -// Tx represents a database transaction that can be committed or rolled back. -type Tx interface { - // Commit commits the database transaction, an error should be returned - // if the commit isn't possible. - Commit() error - - // Rollback rolls back an incomplete database transaction. - // Transactions that were able to be committed can still call this as a - // noop. - Rollback() error -} - -// QueryCreator is a generic function that's used to create a Querier, which is -// a type of interface that implements storage related methods from a database -// transaction. This will be used to instantiate an object callers can use to -// apply multiple modifications to an object interface in a single atomic -// transaction. -type QueryCreator[Q any] func(*sql.Tx) Q - -// BatchedQuerier is a generic interface that allows callers to create a new -// database transaction based on an abstract type that implements the TxOptions -// interface. -type BatchedQuerier interface { - // Querier is the underlying query source, this is in place so we can - // pass a BatchedQuerier implementation directly into objects that - // create a batched version of the normal methods they need. - sqlc.Querier - - // CustomQueries is the set of custom queries that we have manually - // defined in addition to the ones generated by sqlc. - sqlc.CustomQueries - - // BeginTx creates a new database transaction given the set of - // transaction options. - BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error) -} - -// txExecutorOptions is a struct that holds the options for the transaction -// executor. This can be used to do things like retry a transaction due to an -// error a certain amount of times. -type txExecutorOptions struct { - numRetries int - initialRetryDelay time.Duration - maxRetryDelay time.Duration -} - -// defaultTxExecutorOptions returns the default options for the transaction -// executor. -func defaultTxExecutorOptions() *txExecutorOptions { - return &txExecutorOptions{ - numRetries: DefaultNumTxRetries, - initialRetryDelay: DefaultInitialRetryDelay, - maxRetryDelay: DefaultMaxRetryDelay, - } -} - -// randRetryDelay returns a random retry delay between -50% and +50% -// of the configured delay that is doubled for each attempt and capped at a max -// value. -func (t *txExecutorOptions) randRetryDelay(attempt int) time.Duration { - halfDelay := t.initialRetryDelay / 2 - randDelay := prand.Int63n(int64(t.initialRetryDelay)) //nolint:gosec - - // 50% plus 0%-100% gives us the range of 50%-150%. - initialDelay := halfDelay + time.Duration(randDelay) - - // If this is the first attempt, we just return the initial delay. - if attempt == 0 { - return initialDelay - } - - // For each subsequent delay, we double the initial delay. This still - // gives us a somewhat random delay, but it still increases with each - // attempt. If we double something n times, that's the same as - // multiplying the value with 2^n. We limit the power to 32 to avoid - // overflows. - factor := time.Duration(math.Pow(2, math.Min(float64(attempt), 32))) - actualDelay := initialDelay * factor - - // Cap the delay at the maximum configured value. - if actualDelay > t.maxRetryDelay { - return t.maxRetryDelay - } - - return actualDelay -} - -// TxExecutorOption is a functional option that allows us to pass in optional -// argument when creating the executor. -type TxExecutorOption func(*txExecutorOptions) - -// WithTxRetries is a functional option that allows us to specify the number of -// times a transaction should be retried if it fails with a repeatable error. -func WithTxRetries(numRetries int) TxExecutorOption { - return func(o *txExecutorOptions) { - o.numRetries = numRetries - } -} - -// WithTxRetryDelay is a functional option that allows us to specify the delay -// to wait before a transaction is retried. -func WithTxRetryDelay(delay time.Duration) TxExecutorOption { - return func(o *txExecutorOptions) { - o.initialRetryDelay = delay - } -} - -// TransactionExecutor is a generic struct that abstracts away from the type of -// query a type needs to run under a database transaction, and also the set of -// options for that transaction. The QueryCreator is used to create a query -// given a database transaction created by the BatchedQuerier. -type TransactionExecutor[Query any] struct { - BatchedQuerier - - createQuery QueryCreator[Query] - - opts *txExecutorOptions -} - -// NewTransactionExecutor creates a new instance of a TransactionExecutor given -// a Querier query object and a concrete type for the type of transactions the -// Querier understands. -func NewTransactionExecutor[Querier any](db BatchedQuerier, - createQuery QueryCreator[Querier], - opts ...TxExecutorOption) *TransactionExecutor[Querier] { - - txOpts := defaultTxExecutorOptions() - for _, optFunc := range opts { - optFunc(txOpts) - } - - return &TransactionExecutor[Querier]{ - BatchedQuerier: db, - createQuery: createQuery, - opts: txOpts, - } -} - -// ExecTx is a wrapper for txBody to abstract the creation and commit of a db -// transaction. The db transaction is embedded in a `*Queries` that txBody -// needs to use when executing each one of the queries that need to be applied -// atomically. This can be used by other storage interfaces to parameterize the -// type of query and options run, in order to have access to batched operations -// related to a storage object. -func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, - txOptions TxOptions, txBody func(Q) error) error { - - waitBeforeRetry := func(attemptNumber int) { - retryDelay := t.opts.randRetryDelay(attemptNumber) - - log.Tracef("Retrying transaction due to tx serialization or "+ - "deadlock error, attempt_number=%v, delay=%v", - attemptNumber, retryDelay) - - // Before we try again, we'll wait with a random backoff based - // on the retry delay. - time.Sleep(retryDelay) - } - - for i := 0; i < t.opts.numRetries; i++ { - // Create the db transaction. - tx, err := t.BatchedQuerier.BeginTx(ctx, txOptions) - if err != nil { - dbErr := MapSQLError(err) - if IsSerializationOrDeadlockError(dbErr) { - // Nothing to roll back here, since we didn't - // even get a transaction yet. - waitBeforeRetry(i) - continue - } - - return dbErr - } - - // Rollback is safe to call even if the tx is already closed, - // so if the tx commits successfully, this is a no-op. - defer func() { - _ = tx.Rollback() - }() - - if err := txBody(t.createQuery(tx)); err != nil { - dbErr := MapSQLError(err) - if IsSerializationOrDeadlockError(dbErr) { - // Roll back the transaction, then pop back up - // to try once again. - _ = tx.Rollback() - - waitBeforeRetry(i) - continue - } - - return dbErr - } - - // Commit transaction. - if err = tx.Commit(); err != nil { - dbErr := MapSQLError(err) - if IsSerializationOrDeadlockError(dbErr) { - // Roll back the transaction, then pop back up - // to try once again. - _ = tx.Rollback() - - waitBeforeRetry(i) - continue - } - - return dbErr - } - - return nil - } - - // If we get to this point, then we weren't able to successfully commit - // a tx given the max number of retries. - return ErrRetriesExceeded -} - -// Backend returns the type of the database backend used. -func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType { - return t.BatchedQuerier.Backend() -} - -// BaseDB is the base database struct that each implementation can embed to -// gain some common functionality. -type BaseDB struct { - *sql.DB - - *sqlc.Queries -} - -// BeginTx wraps the normal sql specific BeginTx method with the TxOptions -// interface. This interface is then mapped to the concrete sql tx options -// struct. -func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) { - sqlOptions := sql.TxOptions{ - ReadOnly: opts.ReadOnly(), - Isolation: sql.LevelSerializable, - } - return s.DB.BeginTx(ctx, &sqlOptions) -} - -// Backend returns the type of the database backend used. -func (s *BaseDB) Backend() sqlc.BackendType { - return s.Queries.Backend() -} - // QueriesTxOptions defines the set of db txn options the SQLQueries // understands. type QueriesTxOptions struct { diff --git a/db/migrations.go b/db/migrations.go index 79d63587e..204bf40f6 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -1,21 +1,5 @@ package db -import ( - "bytes" - "errors" - "fmt" - "io" - "io/fs" - "net/http" - "strings" - - "github.com/btcsuite/btclog/v2" - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database" - "github.com/golang-migrate/migrate/v4/source/httpfs" - "github.com/lightninglabs/taproot-assets/fn" -) - const ( // LatestMigrationVersion is the latest migration version of the // database. This is used to implement downgrade protection for the @@ -24,258 +8,3 @@ const ( // NOTE: This MUST be updated when a new migration is added. LatestMigrationVersion = 5 ) - -// MigrationTarget is a functional option that can be passed to applyMigrations -// to specify a target version to migrate to. `currentDbVersion` is the current -// (migration) version of the database, or None if unknown. -// `maxMigrationVersion` is the maximum migration version known to the driver, -// or None if unknown. -type MigrationTarget func(mig *migrate.Migrate, - currentDbVersion int, maxMigrationVersion uint) error - -var ( - // TargetLatest is a MigrationTarget that migrates to the latest - // version available. - TargetLatest = func(mig *migrate.Migrate, _ int, _ uint) error { - return mig.Up() - } - - // TargetVersion is a MigrationTarget that migrates to the given - // version. - TargetVersion = func(version uint) MigrationTarget { - return func(mig *migrate.Migrate, _ int, _ uint) error { - return mig.Migrate(version) - } - } -) - -var ( - // ErrMigrationDowngrade is returned when a database downgrade is - // detected. - ErrMigrationDowngrade = errors.New("database downgrade detected") -) - -// migrationOption is a functional option that can be passed to migrate related -// methods to modify their behavior. -type migrateOptions struct { - latestVersion fn.Option[uint] -} - -// defaultMigrateOptions returns a new migrateOptions instance with default -// settings. -func defaultMigrateOptions() *migrateOptions { - return &migrateOptions{} -} - -// MigrateOpt is a functional option that can be passed to migrate related -// methods to modify behavior. -type MigrateOpt func(*migrateOptions) - -// WithLatestVersion allows callers to override the default latest version -// setting. -func WithLatestVersion(version uint) MigrateOpt { - return func(o *migrateOptions) { - o.latestVersion = fn.Some(version) - } -} - -// migrationLogger is a logger that wraps the passed btclog.Logger so it can be -// used to log migrations. -type migrationLogger struct { - log btclog.Logger -} - -// Printf is like fmt.Printf. We map this to the target logger based on the -// current log level. -func (m *migrationLogger) Printf(format string, v ...interface{}) { - // Trim trailing newlines from the format. - format = strings.TrimRight(format, "\n") - - switch m.log.Level() { - case btclog.LevelTrace: - m.log.Tracef(format, v...) - case btclog.LevelDebug: - m.log.Debugf(format, v...) - case btclog.LevelInfo: - m.log.Infof(format, v...) - case btclog.LevelWarn: - m.log.Warnf(format, v...) - case btclog.LevelError: - m.log.Errorf(format, v...) - case btclog.LevelCritical: - m.log.Criticalf(format, v...) - case btclog.LevelOff: - } -} - -// Verbose should return true when verbose logging output is wanted -func (m *migrationLogger) Verbose() bool { - return m.log.Level() <= btclog.LevelDebug -} - -// applyMigrations executes database migration files found in the given file -// system under the given path, using the passed database driver and database -// name, up to or down to the given target version. -func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, - targetVersion MigrationTarget, opts *migrateOptions) error { - - // With the migrate instance open, we'll create a new migration source - // using the embedded file system stored in sqlSchemas. The library - // we're using can't handle a raw file system interface, so we wrap it - // in this intermediate layer. - migrateFileServer, err := httpfs.New(http.FS(fs), path) - if err != nil { - return err - } - - // Finally, we'll run the migration with our driver above based on the - // open DB, and also the migration source stored in the file system - // above. - sqlMigrate, err := migrate.NewWithInstance( - "migrations", migrateFileServer, dbName, driver, - ) - if err != nil { - return err - } - - migrationVersion, _, _ := sqlMigrate.Version() - - // As the down migrations may end up *dropping* data, we want to - // prevent that without explicit accounting. - latestVersion := opts.latestVersion.UnwrapOr(LatestMigrationVersion) - if migrationVersion > latestVersion { - return fmt.Errorf("%w: database version is newer than the "+ - "latest migration version, preventing downgrade: "+ - "db_version=%v, latest_migration_version=%v", - ErrMigrationDowngrade, migrationVersion, latestVersion) - } - - // Report the current version of the database before the migration. - currentDbVersion, _, err := driver.Version() - if err != nil { - return fmt.Errorf("unable to get current db version: %w", err) - } - log.Infof("Attempting to apply migration(s) "+ - "(current_db_version=%v, latest_migration_version=%v)", - currentDbVersion, latestVersion) - - // Apply our local logger to the migration instance. - sqlMigrate.Log = &migrationLogger{log} - - // Execute the migration based on the target given. - err = targetVersion(sqlMigrate, currentDbVersion, latestVersion) - if err != nil && !errors.Is(err, migrate.ErrNoChange) { - return err - } - - // Report the current version of the database after the migration. - currentDbVersion, _, err = driver.Version() - if err != nil { - return fmt.Errorf("unable to get current db version: %w", err) - } - log.Infof("Database version after migration: %v", currentDbVersion) - - return nil -} - -// replacerFS is an implementation of a fs.FS virtual file system that wraps an -// existing file system but does a search-and-replace operation on each file -// when it is opened. -type replacerFS struct { - parentFS fs.FS - replaces map[string]string -} - -// A compile-time assertion to make sure replacerFS implements the fs.FS -// interface. -var _ fs.FS = (*replacerFS)(nil) - -// newReplacerFS creates a new replacer file system, wrapping the given parent -// virtual file system. Each file within the file system is undergoing a -// search-and-replace operation when it is opened, using the given map where the -// key denotes the search term and the value the term to replace each occurrence -// with. -func newReplacerFS(parent fs.FS, replaces map[string]string) *replacerFS { - return &replacerFS{ - parentFS: parent, - replaces: replaces, - } -} - -// Open opens a file in the virtual file system. -// -// NOTE: This is part of the fs.FS interface. -func (t *replacerFS) Open(name string) (fs.File, error) { - f, err := t.parentFS.Open(name) - if err != nil { - return nil, err - } - - stat, err := f.Stat() - if err != nil { - return nil, err - } - - if stat.IsDir() { - return f, err - } - - return newReplacerFile(f, t.replaces) -} - -type replacerFile struct { - parentFile fs.File - buf bytes.Buffer -} - -// A compile-time assertion to make sure replacerFile implements the fs.File -// interface. -var _ fs.File = (*replacerFile)(nil) - -func newReplacerFile(parent fs.File, replaces map[string]string) (*replacerFile, - error) { - - content, err := io.ReadAll(parent) - if err != nil { - return nil, err - } - - contentStr := string(content) - for from, to := range replaces { - contentStr = strings.ReplaceAll(contentStr, from, to) - } - - var buf bytes.Buffer - _, err = buf.WriteString(contentStr) - if err != nil { - return nil, err - } - - return &replacerFile{ - parentFile: parent, - buf: buf, - }, nil -} - -// Stat returns statistics/info about the file. -// -// NOTE: This is part of the fs.File interface. -func (t *replacerFile) Stat() (fs.FileInfo, error) { - return t.parentFile.Stat() -} - -// Read reads as many bytes as possible from the file into the given slice. -// -// NOTE: This is part of the fs.File interface. -func (t *replacerFile) Read(bytes []byte) (int, error) { - return t.buf.Read(bytes) -} - -// Close closes the underlying file. -// -// NOTE: This is part of the fs.File interface. -func (t *replacerFile) Close() error { - // We already fully read and then closed the file when creating this - // instance, so there's nothing to do for us here. - return nil -} diff --git a/db/postgres.go b/db/postgres.go index 16e41dc09..780a7a13d 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -1,26 +1,16 @@ package db import ( - "database/sql" "fmt" "testing" "time" - postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" - "github.com/lightninglabs/lightning-terminal/db/sqlc" - "github.com/stretchr/testify/require" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( dsnTemplate = "postgres://%v:%v@%v:%d/%v?sslmode=%v" - - // defaultMaxIdleConns is the number of permitted idle connections. - defaultMaxIdleConns = 6 - - // defaultConnMaxIdleTime is the amount of time a connection can be - // idle before it is closed. - defaultConnMaxIdleTime = 5 * time.Minute ) var ( @@ -30,16 +20,6 @@ var ( // fully executed yet. So this time needs to be chosen correctly to be // longer than the longest expected individual test run time. DefaultPostgresFixtureLifetime = 60 * time.Minute - - // postgresSchemaReplacements is a map of schema strings that need to be - // replaced for postgres. This is needed because we write the schemas - // to work with sqlite primarily, and postgres has some differences. - postgresSchemaReplacements = map[string]string{ - "BLOB": "BYTEA", - "INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY", - "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", - "UNHEX": "DECODE", - } ) // PostgresConfig holds the postgres database configuration. @@ -76,132 +56,17 @@ func (s *PostgresConfig) DSN(hidePassword bool) string { s.DBName, sslMode) } -// PostgresStore is a database store implementation that uses a Postgres -// backend. -type PostgresStore struct { - cfg *PostgresConfig - - *BaseDB -} - -// NewPostgresStore creates a new store that is backed by a Postgres database -// backend. -func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { - log.Infof("Using SQL database '%s'", cfg.DSN(true)) - - rawDb, err := sql.Open("pgx", cfg.DSN(false)) - if err != nil { - return nil, err - } - - maxConns := defaultMaxConns - if cfg.MaxOpenConnections > 0 { - maxConns = cfg.MaxOpenConnections - } - - maxIdleConns := defaultMaxIdleConns - if cfg.MaxIdleConnections > 0 { - maxIdleConns = cfg.MaxIdleConnections - } - - connMaxLifetime := defaultConnMaxLifetime - if cfg.ConnMaxLifetime > 0 { - connMaxLifetime = cfg.ConnMaxLifetime - } - - connMaxIdleTime := defaultConnMaxIdleTime - if cfg.ConnMaxIdleTime > 0 { - connMaxIdleTime = cfg.ConnMaxIdleTime - } - - rawDb.SetMaxOpenConns(maxConns) - rawDb.SetMaxIdleConns(maxIdleConns) - rawDb.SetConnMaxLifetime(connMaxLifetime) - rawDb.SetConnMaxIdleTime(connMaxIdleTime) - - queries := sqlc.NewPostgres(rawDb) - s := &PostgresStore{ - cfg: cfg, - BaseDB: &BaseDB{ - DB: rawDb, - Queries: queries, - }, - } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - - return s, nil -} - -// ExecuteMigrations runs migrations for the Postgres database, depending on the -// target given, either all migrations or up to a given version. -func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, - optFuncs ...MigrateOpt) error { - - opts := defaultMigrateOptions() - for _, optFunc := range optFuncs { - optFunc(opts) - } - - driver, err := postgres_migrate.WithInstance( - s.DB, &postgres_migrate.Config{}, - ) - if err != nil { - return fmt.Errorf("error creating postgres migration: %w", err) - } - - postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) - return applyMigrations( - postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target, - opts, - ) -} - // NewTestPostgresDB is a helper function that creates a Postgres database for // testing. -func NewTestPostgresDB(t *testing.T) *PostgresStore { +func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { t.Helper() t.Logf("Creating new Postgres DB for testing") - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - store, err := NewPostgresStore(sqlFixture.GetConfig()) - require.NoError(t, err) - - t.Cleanup(func() { - sqlFixture.TearDown(t) - }) - - return store -} - -// NewTestPostgresDBWithVersion is a helper function that creates a Postgres -// database for testing and migrates it to the given version. -func NewTestPostgresDBWithVersion(t *testing.T, version uint) *PostgresStore { - t.Helper() - - t.Logf("Creating new Postgres DB for testing, migrating to version %d", - version) - - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - storeCfg := sqlFixture.GetConfig() - storeCfg.SkipMigrations = true - store, err := NewPostgresStore(storeCfg) - require.NoError(t, err) - - err = store.ExecuteMigrations(TargetVersion(version)) - require.NoError(t, err) - + sqlFixture := sqldb.NewTestPgFixture(t, DefaultPostgresFixtureLifetime) t.Cleanup(func() { sqlFixture.TearDown(t) }) - return store + return sqldb.NewTestPostgresDB(t, sqlFixture, LitdMigrationStreams) } diff --git a/db/sql_migrations.go b/db/sql_migrations.go new file mode 100644 index 000000000..57d283aa3 --- /dev/null +++ b/db/sql_migrations.go @@ -0,0 +1,30 @@ +package db + +import ( + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +var ( + LitdMigrationStream = sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: sqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return make(map[uint]migrate.PostStepCallback), nil + }, + } + LitdMigrationStreams = []sqldb.MigrationStream{LitdMigrationStream} +) diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f4bf7f611..af556eae7 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -2,21 +2,8 @@ package sqlc import ( "context" -) - -// BackendType is an enum that represents the type of database backend we're -// using. -type BackendType uint8 - -const ( - // BackendTypeUnknown indicates we're using an unknown backend. - BackendTypeUnknown BackendType = iota - // BackendTypeSqlite indicates we're using a SQLite backend. - BackendTypeSqlite - - // BackendTypePostgres indicates we're using a Postgres backend. - BackendTypePostgres + "github.com/lightningnetwork/lnd/sqldb/v2" ) // wrappedTX is a wrapper around a DBTX that also stores the database backend @@ -24,29 +11,24 @@ const ( type wrappedTX struct { DBTX - backendType BackendType + backendType sqldb.BackendType } // Backend returns the type of database backend we're using. -func (q *Queries) Backend() BackendType { +func (q *Queries) Backend() sqldb.BackendType { wtx, ok := q.db.(*wrappedTX) if !ok { // Shouldn't happen unless a new database backend type is added // but not initialized correctly. - return BackendTypeUnknown + return sqldb.BackendTypeUnknown } return wtx.backendType } -// NewSqlite creates a new Queries instance for a SQLite database. -func NewSqlite(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypeSqlite}} -} - -// NewPostgres creates a new Queries instance for a Postgres database. -func NewPostgres(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypePostgres}} +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} } // CustomQueries defines a set of custom queries that we define in addition @@ -62,5 +44,5 @@ type CustomQueries interface { arg ListActionsParams) ([]Action, error) // Backend returns the type of the database backend used. - Backend() BackendType + Backend() sqldb.BackendType } diff --git a/db/sqlite.go b/db/sqlite.go index 803362fa8..4e831a074 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -1,49 +1,9 @@ package db import ( - "database/sql" - "fmt" - "net/url" - "path/filepath" - "testing" - "time" - - "github.com/golang-migrate/migrate/v4" - sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" - "github.com/lightninglabs/lightning-terminal/db/sqlc" - "github.com/stretchr/testify/require" _ "modernc.org/sqlite" // Register relevant drivers. ) -const ( - // sqliteOptionPrefix is the string prefix sqlite uses to set various - // options. This is used in the following format: - // * sqliteOptionPrefix || option_name = option_value. - sqliteOptionPrefix = "_pragma" - - // sqliteTxLockImmediate is a dsn option used to ensure that write - // transactions are started immediately. - sqliteTxLockImmediate = "_txlock=immediate" - - // defaultMaxConns is the number of permitted active and idle - // connections. We want to limit this so it isn't unlimited. We use the - // same value for the number of idle connections as, this can speed up - // queries given a new connection doesn't need to be established each - // time. - defaultMaxConns = 25 - - // defaultConnMaxLifetime is the maximum amount of time a connection can - // be reused for before it is closed. - defaultConnMaxLifetime = 10 * time.Minute -) - -var ( - // sqliteSchemaReplacements is a map of schema strings that need to be - // replaced for sqlite. There currently aren't any replacements, because - // the SQL files are written with SQLite compatibility in mind. - sqliteSchemaReplacements = map[string]string{} -) - // SqliteConfig holds all the config arguments needed to interact with our // sqlite DB. // @@ -61,248 +21,3 @@ type SqliteConfig struct { // found. DatabaseFileName string `long:"dbfile" description:"The full path to the database."` } - -// SqliteStore is a sqlite3 based database for the Taproot Asset daemon. -type SqliteStore struct { - cfg *SqliteConfig - - *BaseDB -} - -// NewSqliteStore attempts to open a new sqlite database based on the passed -// config. -func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { - // The set of pragma options are accepted using query options. For now - // we only want to ensure that foreign key constraints are properly - // enforced. - pragmaOptions := []struct { - name string - value string - }{ - { - name: "foreign_keys", - value: "on", - }, - { - name: "journal_mode", - value: "WAL", - }, - { - name: "busy_timeout", - value: "5000", - }, - { - // With the WAL mode, this ensures that we also do an - // extra WAL sync after each transaction. The normal - // sync mode skips this and gives better performance, - // but risks durability. - name: "synchronous", - value: "full", - }, - { - // This is used to ensure proper durability for users - // running on Mac OS. It uses the correct fsync system - // call to ensure items are fully flushed to disk. - name: "fullfsync", - value: "true", - }, - } - sqliteOptions := make(url.Values) - for _, option := range pragmaOptions { - sqliteOptions.Add( - sqliteOptionPrefix, - fmt.Sprintf("%v=%v", option.name, option.value), - ) - } - - // Construct the DSN which is just the database file name, appended - // with the series of pragma options as a query URL string. For more - // details on the formatting here, see the modernc.org/sqlite docs: - // https://pkg.go.dev/modernc.org/sqlite#Driver.Open. - dsn := fmt.Sprintf( - "%v?%v&%v", cfg.DatabaseFileName, sqliteOptions.Encode(), - sqliteTxLockImmediate, - ) - db, err := sql.Open("sqlite", dsn) - if err != nil { - return nil, err - } - - db.SetMaxOpenConns(defaultMaxConns) - db.SetMaxIdleConns(defaultMaxConns) - db.SetConnMaxLifetime(defaultConnMaxLifetime) - - queries := sqlc.NewSqlite(db) - s := &SqliteStore{ - cfg: cfg, - BaseDB: &BaseDB{ - DB: db, - Queries: queries, - }, - } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - - return s, nil -} - -// backupSqliteDatabase creates a backup of the given SQLite database. -func backupSqliteDatabase(srcDB *sql.DB, dbFullFilePath string) error { - if srcDB == nil { - return fmt.Errorf("backup source database is nil") - } - - // Create a database backup file full path from the given source - // database full file path. - // - // Get the current time and format it as a Unix timestamp in - // nanoseconds. - timestamp := time.Now().UnixNano() - - // Add the timestamp to the backup name. - backupFullFilePath := fmt.Sprintf( - "%s.%d.backup", dbFullFilePath, timestamp, - ) - - log.Infof("Creating backup of database file: %v -> %v", - dbFullFilePath, backupFullFilePath) - - // Create the database backup. - vacuumIntoQuery := "VACUUM INTO ?;" - stmt, err := srcDB.Prepare(vacuumIntoQuery) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(backupFullFilePath) - if err != nil { - return err - } - - return nil -} - -// backupAndMigrate is a helper function that creates a database backup before -// initiating the migration, and then migrates the database to the latest -// version. -func (s *SqliteStore) backupAndMigrate(mig *migrate.Migrate, - currentDbVersion int, maxMigrationVersion uint) error { - - // Determine if a database migration is necessary given the current - // database version and the maximum migration version. - versionUpgradePending := currentDbVersion < int(maxMigrationVersion) - if !versionUpgradePending { - log.Infof("Current database version is up-to-date, skipping "+ - "migration attempt and backup creation "+ - "(current_db_version=%v, max_migration_version=%v)", - currentDbVersion, maxMigrationVersion) - return nil - } - - // At this point, we know that a database migration is necessary. - // Create a backup of the database before starting the migration. - if !s.cfg.SkipMigrationDbBackup { - log.Infof("Creating database backup (before applying " + - "migration(s))") - - err := backupSqliteDatabase(s.DB, s.cfg.DatabaseFileName) - if err != nil { - return err - } - } else { - log.Infof("Skipping database backup creation before applying " + - "migration(s)") - } - - log.Infof("Applying migrations to database") - return mig.Up() -} - -// ExecuteMigrations runs migrations for the sqlite database, depending on the -// target given, either all migrations or up to a given version. -func (s *SqliteStore) ExecuteMigrations(target MigrationTarget, - optFuncs ...MigrateOpt) error { - - opts := defaultMigrateOptions() - for _, optFunc := range optFuncs { - optFunc(opts) - } - - driver, err := sqlite_migrate.WithInstance( - s.DB, &sqlite_migrate.Config{}, - ) - if err != nil { - return fmt.Errorf("error creating sqlite migration: %w", err) - } - - sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) - return applyMigrations( - sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts, - ) -} - -// NewTestSqliteDB is a helper function that creates an SQLite database for -// testing. -func NewTestSqliteDB(t *testing.T) *SqliteStore { - t.Helper() - - // TODO(roasbeef): if we pass :memory: for the file name, then we get - // an in mem version to speed up tests - dbPath := filepath.Join(t.TempDir(), "tmp.db") - t.Logf("Creating new SQLite DB handle for testing: %s", dbPath) - - return NewTestSqliteDbHandleFromPath(t, dbPath) -} - -// NewTestSqliteDbHandleFromPath is a helper function that creates a SQLite -// database handle given a database file path. -func NewTestSqliteDbHandleFromPath(t *testing.T, dbPath string) *SqliteStore { - t.Helper() - - sqlDB, err := NewSqliteStore(&SqliteConfig{ - DatabaseFileName: dbPath, - SkipMigrations: false, - }) - require.NoError(t, err) - - t.Cleanup(func() { - require.NoError(t, sqlDB.DB.Close()) - }) - - return sqlDB -} - -// NewTestSqliteDBWithVersion is a helper function that creates an SQLite -// database for testing and migrates it to the given version. -func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { - t.Helper() - - t.Logf("Creating new SQLite DB for testing, migrating to version %d", - version) - - // TODO(roasbeef): if we pass :memory: for the file name, then we get - // an in mem version to speed up tests - dbFileName := filepath.Join(t.TempDir(), "tmp.db") - sqlDB, err := NewSqliteStore(&SqliteConfig{ - DatabaseFileName: dbFileName, - SkipMigrations: true, - }) - require.NoError(t, err) - - err = sqlDB.ExecuteMigrations(TargetVersion(version)) - require.NoError(t, err) - - t.Cleanup(func() { - require.NoError(t, sqlDB.DB.Close()) - }) - - return sqlDB -} diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 75c9d0a6d..4d5448313 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -12,7 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLAccountQueries is a subset of the sqlc.Queries interface that can be used @@ -167,7 +167,7 @@ func (s *SQLDB) AddAction(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, Valid: errReason != "", }, }) - }) + }, sqldb.NoOpReset) } // ListActions returns a list of Actions. The query IndexOffset and MaxNum @@ -350,7 +350,7 @@ func (s *SQLDB) ListActions(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) return actions, lastIndex, uint64(totalCount), err } diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 248892130..0c1847706 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -11,6 +11,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be @@ -30,6 +31,7 @@ type SQLKVStoreQueries interface { UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error UpdateGroupKVStoreRecord(ctx context.Context, arg sqlc.UpdateGroupKVStoreRecordParams) error InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error + ListAllKVStoresRecords(ctx context.Context) ([]sqlc.Kvstore, error) DeleteAllTempKVStores(ctx context.Context) error GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) @@ -45,7 +47,7 @@ func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { return tx.DeleteAllTempKVStores(ctx) - }) + }, sqldb.NoOpReset) } // GetKVStores constructs a new rules.KVStores in a namespace defined by the diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index c2442a819..978329cc5 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -9,12 +9,11 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" ) @@ -54,7 +53,7 @@ func TestFirewallDBMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) @@ -63,19 +62,15 @@ func TestFirewallDBMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLQueriesExecutor(baseDB, queries) } // The assertMigrationResults function will currently assert that // the migrated kv stores entries in the SQLDB match the original kv // stores entries in the BoltDB. - assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + assertMigrationResults := func(t *testing.T, store *SQLDB, kvEntries []*kvEntry) { var ( @@ -88,7 +83,9 @@ func TestFirewallDBMigration(t *testing.T) { getRuleID := func(ruleName string) int64 { ruleID, ok := ruleIDs[ruleName] if !ok { - ruleID, err = sqlStore.GetRuleID(ctx, ruleName) + ruleID, err = store.db.GetRuleID( + ctx, ruleName, + ) require.NoError(t, err) ruleIDs[ruleName] = ruleID @@ -100,7 +97,7 @@ func TestFirewallDBMigration(t *testing.T) { getGroupID := func(groupAlias []byte) int64 { groupID, ok := groupIDs[string(groupAlias)] if !ok { - groupID, err = sqlStore.GetSessionIDByAlias( + groupID, err = store.db.GetSessionIDByAlias( ctx, groupAlias, ) require.NoError(t, err) @@ -114,7 +111,7 @@ func TestFirewallDBMigration(t *testing.T) { getFeatureID := func(featureName string) int64 { featureID, ok := featureIDs[featureName] if !ok { - featureID, err = sqlStore.GetFeatureID( + featureID, err = store.db.GetFeatureID( ctx, featureName, ) require.NoError(t, err) @@ -128,7 +125,7 @@ func TestFirewallDBMigration(t *testing.T) { // First we extract all migrated kv entries from the SQLDB, // in order to be able to compare them to the original kv // entries, to ensure that the migration was successful. - sqlKvEntries, err := sqlStore.ListAllKVStoresRecords(ctx) + sqlKvEntries, err := store.db.ListAllKVStoresRecords(ctx) require.NoError(t, err) require.Equal(t, len(kvEntries), len(sqlKvEntries)) @@ -144,7 +141,7 @@ func TestFirewallDBMigration(t *testing.T) { ruleID := getRuleID(entry.ruleName) if entry.groupAlias.IsNone() { - sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + sqlVal, err := store.db.GetGlobalKVStoreRecord( ctx, sqlc.GetGlobalKVStoreRecordParams{ Key: entry.key, @@ -162,7 +159,7 @@ func TestFirewallDBMigration(t *testing.T) { groupAlias := entry.groupAlias.UnwrapOrFail(t) groupID := getGroupID(groupAlias[:]) - v, err := sqlStore.GetGroupKVStoreRecord( + v, err := store.db.GetGroupKVStoreRecord( ctx, sqlc.GetGroupKVStoreRecordParams{ Key: entry.key, @@ -187,7 +184,7 @@ func TestFirewallDBMigration(t *testing.T) { entry.featureName.UnwrapOrFail(t), ) - sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + sqlVal, err := store.db.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: entry.key, @@ -296,12 +293,16 @@ func TestFirewallDBMigration(t *testing.T) { sqlStore, txEx := makeSQLDB(t, sessionsStore) // Perform the migration. - err = txEx.ExecTx(ctx, sqldb.WriteTxOpt(), + // + // TODO(viktor): remove sqldb.MigrationTxOptions once + // sqldb v2 is based on the latest version of lnd/sqldb. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, func(tx SQLQueries) error { return MigrateFirewallDBToSQL( ctx, firewallStore.DB, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index f17010f2c..1be887ace 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -5,7 +5,9 @@ import ( "database/sql" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLSessionQueries is a subset of the sqlc.Queries interface that can be used @@ -18,17 +20,20 @@ type SQLSessionQueries interface { // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { + sqldb.BaseQuerier + SQLKVStoreQueries SQLPrivacyPairQueries SQLActionQueries } -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLDB represents a storage backend. @@ -38,11 +43,31 @@ type SQLDB struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) @@ -53,12 +78,10 @@ var _ ActionDB = (*SQLDB)(nil) // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. -func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) - }, - ) +func NewSQLDB(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLDB { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLDB{ db: executor, @@ -88,7 +111,7 @@ func (e *sqlExecutor[T]) Update(ctx context.Context, var txOpts db.QueriesTxOptions return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } // View opens a database read transaction and executes the function f with the @@ -104,5 +127,5 @@ func (e *sqlExecutor[T]) View(ctx context.Context, return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index a412441f8..b7e3d9052 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -6,10 +6,12 @@ import ( "testing" "time" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -55,8 +57,10 @@ func assertEqualActions(t *testing.T, expected, got *Action) { // createStore is a helper function that creates a new SQLDB and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - store := NewSQLDB(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, clock clock.Clock) *SQLDB { + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLDB(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 49b956d7d..ab184b5a6 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -7,17 +7,23 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // NewTestDB is a helper function that creates an BBolt database for testing. func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/go.mod b/go.mod index 65f98cb62..284f09e85 100644 --- a/go.mod +++ b/go.mod @@ -151,6 +151,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb // indirect github.com/lightningnetwork/lnd/healthcheck v1.2.6 // indirect github.com/lightningnetwork/lnd/queue v1.1.1 // indirect + github.com/lightningnetwork/lnd/sqldb/v2 v2.0.0-00010101000000-000000000000 // indirect github.com/lightningnetwork/lnd/ticker v1.1.1 // indirect github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -251,3 +252,8 @@ replace github.com/golang-migrate/migrate/v4 => github.com/lightninglabs/migrate // tapd wants v0.19.0-12, but loop can't handle that yet. So we'll just use the // previous version for now. replace github.com/lightninglabs/lndclient => github.com/lightninglabs/lndclient v0.19.0-11 + +replace github.com/lightningnetwork/lnd => github.com/ViktorTigerstrom/lnd v0.0.0-20250721232542-5e14695a410c + +// TODO: replace this with your own local fork +replace github.com/lightningnetwork/lnd/sqldb/v2 => ../../lnd_forked/lnd/sqldb diff --git a/go.sum b/go.sum index 4e5da46b6..8204eda1a 100644 --- a/go.sum +++ b/go.sum @@ -616,6 +616,8 @@ github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82/go.mod h1:Gbu github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/ViktorTigerstrom/lnd v0.0.0-20250721232542-5e14695a410c h1:2/XQEHo8QiuwA/TcOhYWuyet5AkxSYwITkC8KdHFs+8= +github.com/ViktorTigerstrom/lnd v0.0.0-20250721232542-5e14695a410c/go.mod h1:54IwnYLMLlBwwzSMvNugIV81WAs4UEFxWvdFzfWwm9w= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 h1:cDVUiFo+npB0ZASqnw4q90ylaVAbnYyx0JYqK4YcGok= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344/go.mod h1:9pIqrY6SXNL8vjRQE5Hd/OL5GyK/9MrGUWs87z/eFfk= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= @@ -1178,8 +1180,6 @@ github.com/lightninglabs/taproot-assets/taprpc v1.0.8-0.20250716163904-2ef55ba74 github.com/lightninglabs/taproot-assets/taprpc v1.0.8-0.20250716163904-2ef55ba74036/go.mod h1:vOM2Ap2wYhEZjiJU7bNNg+e5tDxkvRAuyXwf/KQ4tgo= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb h1:yfM05S8DXKhuCBp5qSMZdtSwvJ+GFzl94KbXMNB1JDY= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb/go.mod h1:c0kvRShutpj3l6B9WtTsNTBUtjSmjZXbJd9ZBRQOSKI= -github.com/lightningnetwork/lnd v0.19.2-beta h1:3SKVrKYFY4IJLlrMf7cDzZcBeT+MxjI9Xy6YpY+EEX4= -github.com/lightningnetwork/lnd v0.19.2-beta/go.mod h1:+yKUfIGKKYRHGewgzQ6xi0S26DIfBiMv1zCdB3m6YxA= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= diff --git a/session/sql_migration_test.go b/session/sql_migration_test.go index 3c3b018e7..2b276bba8 100644 --- a/session/sql_migration_test.go +++ b/session/sql_migration_test.go @@ -2,17 +2,16 @@ package session import ( "context" - "database/sql" "fmt" "testing" "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "go.etcd.io/bbolt" "golang.org/x/exp/rand" @@ -38,7 +37,7 @@ func TestSessionsStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, acctStore accounts.Store) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { // Create a sql store with a linked account store. testDBStore := NewTestDBWithAccounts(t, clock, acctStore) @@ -48,13 +47,9 @@ func TestSessionsStoreMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLQueriesExecutor(baseDB, queries) } // assertMigrationResults asserts that the sql store contains the @@ -369,13 +364,16 @@ func TestSessionsStoreMigration(t *testing.T) { // migration. sqlStore, txEx := makeSQLDB(t, accountStore) + // TODO(viktor): remove sqldb.MigrationTxOptions once + // sqldb v2 is based on the latest version of lnd/sqldb. + var opts sqldb.MigrationTxOptions err = txEx.ExecTx( - ctx, sqldb.WriteTxOpt(), + ctx, &opts, func(tx SQLQueries) error { return MigrateSessionStoreToSQL( ctx, kvStore.DB, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..26662a574 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -14,6 +14,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" ) @@ -21,6 +22,8 @@ import ( // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with session related tables. type SQLQueries interface { + sqldb.BaseQuerier + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error) @@ -51,12 +54,13 @@ type SQLQueries interface { var _ Store = (*SQLStore)(nil) -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -66,19 +70,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -281,7 +303,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type, } return nil - }) + }, sqldb.NoOpReset) if err != nil { mappedSQLErr := db.MapSQLError(err) var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation @@ -325,7 +347,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session, } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -358,7 +380,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) ( } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -417,7 +439,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error { State: int16(dest), }, ) - }) + }, sqldb.NoOpReset) } // DeleteReservedSessions deletes all sessions that are in the StateReserved @@ -428,7 +450,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { var writeTxOpts db.QueriesTxOptions return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { return db.DeleteSessionsWithState(ctx, int16(StateReserved)) - }) + }, sqldb.NoOpReset) } // GetSessionByLocalPub fetches the session with the given local pub key. @@ -458,7 +480,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -491,7 +513,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -521,7 +543,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID, RemotePublicKey: remoteKey, }, ) - }) + }, sqldb.NoOpReset) } // getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local @@ -576,7 +598,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sess, err } @@ -617,7 +639,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { legacyGroupID, err = IDFromBytes(legacyGroupIDB) return err - }) + }, sqldb.NoOpReset) if err != nil { return ID{}, err } @@ -666,7 +688,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } diff --git a/session/test_sql.go b/session/test_sql.go index a83186069..5623c8207 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -22,8 +23,12 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 0ceb0e046..84d946ce2 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) }