diff --git a/accounts/service.go b/accounts/service.go
index 102b9ea84..697cf7518 100644
--- a/accounts/service.go
+++ b/accounts/service.go
@@ -361,7 +361,7 @@ func (s *InterceptorService) CreditAccount(ctx context.Context,
 		return nil, ErrAccountServiceDisabled
 	}
 
-	// Credit the account in the db.
+	// Credit the account in the DB.
 	err := s.store.CreditAccount(ctx, accountID, amount)
 	if err != nil {
 		return nil, fmt.Errorf("unable to credit account: %w", err)
@@ -386,7 +386,7 @@ func (s *InterceptorService) DebitAccount(ctx context.Context,
 		return nil, ErrAccountServiceDisabled
 	}
 
-	// Debit the account in the db.
+	// Debit the account in the DB.
 	err := s.store.DebitAccount(ctx, accountID, amount)
 	if err != nil {
 		return nil, fmt.Errorf("unable to debit account: %w", err)
diff --git a/accounts/service_test.go b/accounts/service_test.go
index 1d4388664..f69d0fe93 100644
--- a/accounts/service_test.go
+++ b/accounts/service_test.go
@@ -246,7 +246,7 @@ func TestAccountService(t *testing.T) {
 
 				// Ensure that the service was started successfully and
 				// still running though, despite the closing of the
-				// db store.
+				// DB store.
 				require.True(t, s.IsRunning())
 
 				// Now let's send the invoice update, which should fail.
diff --git a/accounts/sql_migration.go b/accounts/sql_migration.go
index c36b51c6f..94f88ac28 100644
--- a/accounts/sql_migration.go
+++ b/accounts/sql_migration.go
@@ -11,8 +11,11 @@ import (
 	"time"
 
 	"github.com/davecgh/go-spew/spew"
-	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
 	"github.com/lightningnetwork/lnd/kvdb"
+	"github.com/lightningnetwork/lnd/lnrpc"
+	"github.com/lightningnetwork/lnd/lntypes"
+	"github.com/lightningnetwork/lnd/lnwire"
 	"github.com/pmezard/go-difflib/difflib"
 )
 
@@ -27,7 +30,7 @@ var (
 // the KV database to the SQL database. The migration is done in a single
 // transaction to ensure that all accounts are migrated or none at all.
 func MigrateAccountStoreToSQL(ctx context.Context, kvStore kvdb.Backend,
-	tx SQLQueries) error {
+	tx SQLMig6Queries) error {
 
 	log.Infof("Starting migration of the KV accounts store to SQL")
 
@@ -50,7 +53,7 @@ func MigrateAccountStoreToSQL(ctx context.Context, kvStore kvdb.Backend,
 // to the SQL database. The migration is done in a single transaction to ensure
 // that all accounts are migrated or none at all.
 func migrateAccountsToSQL(ctx context.Context, kvStore kvdb.Backend,
-	tx SQLQueries) error {
+	tx SQLMig6Queries) error {
 
 	log.Infof("Starting migration of accounts from KV to SQL")
 
@@ -68,7 +71,7 @@ func migrateAccountsToSQL(ctx context.Context, kvStore kvdb.Backend,
 				kvAccount.ID, err)
 		}
 
-		migratedAccount, err := getAndMarshalAccount(
+		migratedAccount, err := getAndMarshalMig6Account(
 			ctx, tx, migratedAccountID,
 		)
 		if err != nil {
@@ -151,17 +154,79 @@ func getBBoltAccounts(db kvdb.Backend) ([]*OffChainBalanceAccount, error) {
 	return accounts, nil
 }
 
+// getAndMarshalMig6Account retrieves the account with the given ID. If the
+// account cannot be found, then ErrAccNotFound is returned.
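+// It is the sqlcmig6 counterpart of getAndMarshalAccount in store_sql.go and
+// is only used by the KV-to-SQL code migration.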
+func getAndMarshalMig6Account(ctx context.Context, db SQLMig6Queries,
+	id int64) (*OffChainBalanceAccount, error) {
+
+	dbAcct, err := db.GetAccount(ctx, id)
+	if errors.Is(err, sql.ErrNoRows) {
+		return nil, ErrAccNotFound
+	} else if err != nil {
+		return nil, err
+	}
+
+	return marshalDBMig6Account(ctx, db, dbAcct)
+}
+
+// marshalDBMig6Account converts the given sqlcmig6 account row into an
+// OffChainBalanceAccount, fetching its associated invoices and payments.
+func marshalDBMig6Account(ctx context.Context, db SQLMig6Queries,
+	dbAcct sqlcmig6.Account) (*OffChainBalanceAccount, error) {
+
+	alias, err := AccountIDFromInt64(dbAcct.Alias)
+	if err != nil {
+		return nil, err
+	}
+
+	account := &OffChainBalanceAccount{
+		ID:             alias,
+		Type:           AccountType(dbAcct.Type),
+		InitialBalance: lnwire.MilliSatoshi(dbAcct.InitialBalanceMsat),
+		CurrentBalance: dbAcct.CurrentBalanceMsat,
+		LastUpdate:     dbAcct.LastUpdated.UTC(),
+		ExpirationDate: dbAcct.Expiration.UTC(),
+		Invoices:       make(AccountInvoices),
+		Payments:       make(AccountPayments),
+		Label:          dbAcct.Label.String,
+	}
+
+	invoices, err := db.ListAccountInvoices(ctx, dbAcct.ID)
+	if err != nil {
+		return nil, err
+	}
+	for _, invoice := range invoices {
+		var hash lntypes.Hash
+		copy(hash[:], invoice.Hash)
+		account.Invoices[hash] = struct{}{}
+	}
+
+	payments, err := db.ListAccountPayments(ctx, dbAcct.ID)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, payment := range payments {
+		var hash lntypes.Hash
+		copy(hash[:], payment.Hash)
+		account.Payments[hash] = &PaymentEntry{
+			Status:     lnrpc.Payment_PaymentStatus(payment.Status),
+			FullAmount: lnwire.MilliSatoshi(payment.FullAmountMsat),
+		}
+	}
+
+	return account, nil
+}
+
 // migrateSingleAccountToSQL runs the migration for a single account from the
 // KV database to the SQL database.
 func migrateSingleAccountToSQL(ctx context.Context,
-	tx SQLQueries, account *OffChainBalanceAccount) (int64, error) {
+	tx SQLMig6Queries, account *OffChainBalanceAccount) (int64, error) {
 
 	accountAlias, err := account.ID.ToInt64()
 	if err != nil {
 		return 0, err
 	}
 
-	insertAccountParams := sqlc.InsertAccountParams{
+	insertAccountParams := sqlcmig6.InsertAccountParams{
 		Type:               int16(account.Type),
 		InitialBalanceMsat: int64(account.InitialBalance),
 		CurrentBalanceMsat: account.CurrentBalance,
@@ -180,7 +245,7 @@ func migrateSingleAccountToSQL(ctx context.Context,
 	}
 
 	for hash := range account.Invoices {
-		addInvoiceParams := sqlc.AddAccountInvoiceParams{
+		addInvoiceParams := sqlcmig6.AddAccountInvoiceParams{
 			AccountID: sqlId,
 			Hash:      hash[:],
 		}
@@ -192,7 +257,7 @@ func migrateSingleAccountToSQL(ctx context.Context,
 	}
 
 	for hash, paymentEntry := range account.Payments {
-		upsertPaymentParams := sqlc.UpsertAccountPaymentParams{
+		upsertPaymentParams := sqlcmig6.UpsertAccountPaymentParams{
 			AccountID: sqlId,
 			Hash:      hash[:],
 			Status:    int16(paymentEntry.Status),
@@ -211,7 +276,7 @@
 // migrateAccountsIndicesToSQL runs the migration for the account indices from
 // the KV database to the SQL database.
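+// These indices record the last invoice add and settle index that the account
+// service has processed.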
 func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend,
-	tx SQLQueries) error {
+	tx SQLMig6Queries) error {
 
 	log.Infof("Starting migration of accounts indices from KV to SQL")
 
@@ -233,7 +298,7 @@ func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend,
 			settleIndexName, settleIndex)
 	}
 
-	setAddIndexParams := sqlc.SetAccountIndexParams{
+	setAddIndexParams := sqlcmig6.SetAccountIndexParams{
 		Name:  addIndexName,
 		Value: int64(addIndex),
 	}
@@ -243,7 +308,7 @@ func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend,
 		return err
 	}
 
-	setSettleIndexParams := sqlc.SetAccountIndexParams{
+	setSettleIndexParams := sqlcmig6.SetAccountIndexParams{
 		Name:  settleIndexName,
 		Value: int64(settleIndex),
 	}
diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go
index b84485e1d..fb0cedca1 100644
--- a/accounts/sql_migration_test.go
+++ b/accounts/sql_migration_test.go
@@ -2,18 +2,17 @@ package accounts
 
 import (
 	"context"
-	"database/sql"
 	"fmt"
 	"testing"
 	"time"
 
-	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
 	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/lnrpc"
 	"github.com/lightningnetwork/lnd/lntypes"
 	"github.com/lightningnetwork/lnd/lnwire"
-	"github.com/lightningnetwork/lnd/sqldb"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/exp/rand"
 	"pgregory.net/rapid"
@@ -36,7 +35,7 @@ func TestAccountStoreMigration(t *testing.T) {
 	}
 
 	makeSQLDB := func(t *testing.T) (*SQLStore,
-		*db.TransactionExecutor[SQLQueries]) {
+		*SQLMig6QueriesExecutor[SQLMig6Queries]) {
 
 		testDBStore := NewTestDB(t, clock)
 
@@ -45,13 +44,9 @@ func TestAccountStoreMigration(t *testing.T) {
 
 		baseDB := store.BaseDB
 
-		genericExecutor := db.NewTransactionExecutor(
-			baseDB, func(tx *sql.Tx) SQLQueries {
-				return baseDB.WithTx(tx)
-			},
-		)
+		queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType)
 
-		return store, genericExecutor
+		return store, NewSQLMig6QueriesExecutor(baseDB, queries)
 	}
 
 	assertMigrationResults := func(t *testing.T, sqlStore *SQLStore,
@@ -306,7 +301,7 @@ func TestAccountStoreMigration(t *testing.T) {
 			)
 			require.NoError(t, err)
 			t.Cleanup(func() {
-				require.NoError(t, kvStore.db.Close())
+				require.NoError(t, kvStore.DB.Close())
 			})
 
 			// Populate the kv store.
@@ -337,13 +332,17 @@ func TestAccountStoreMigration(t *testing.T) {
 			}
 
 			// Perform the migration.
+			//
+			// TODO(viktor): remove sqldb.MigrationTxOptions once
+			// sqldb v2 is based on the latest version of lnd/sqldb.
+			var opts sqldb.MigrationTxOptions
 			err = txEx.ExecTx(
-				ctx, sqldb.WriteTxOpt(),
-				func(tx SQLQueries) error {
+				ctx, &opts,
+				func(tx SQLMig6Queries) error {
 					return MigrateAccountStoreToSQL(
-						ctx, kvStore.db, tx,
+						ctx, kvStore.DB, tx,
 					)
-				},
+				}, sqldb.NoOpReset,
 			)
 			require.NoError(t, err)
 
@@ -445,7 +444,7 @@ func rapidRandomizeAccounts(t *testing.T, kvStore *BoltStore) {
 		acct := makeAccountGen().Draw(t, "account")
 
 		// Then proceed to insert the account with its invoices and
-		// payments into the db
+		// payments into the DB
 		newAcct, err := kvStore.NewAccount(
 			ctx, acct.balance, acct.expiry, acct.label,
 		)
diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go
index a419017a8..c8f0282ee 100644
--- a/accounts/store_kvdb.go
+++ b/accounts/store_kvdb.go
@@ -24,7 +24,7 @@
 const (
 	// DBFilename is the filename within the data directory which contains
 	// the macaroon stores.
 	DBFilename = "accounts.db"
 
 	// dbPathPermission is the default permission the account database
 	// directory is created with (if it does not exist).
@@ -60,7 +60,7 @@ var (
 
 // BoltStore wraps the bolt DB that stores all accounts and their balances.
 type BoltStore struct {
-	db    kvdb.Backend
+	DB    kvdb.Backend
 	clock clock.Clock
 }
 
@@ -101,7 +101,7 @@ func NewBoltStore(dir, fileName string, clock clock.Clock) (*BoltStore, error) {
 
 	// Return the DB wrapped in a BoltStore object.
 	return &BoltStore{
-		db:    db,
+		DB:    db,
 		clock: clock,
 	}, nil
 }
@@ -110,7 +110,7 @@ func NewBoltStore(dir, fileName string, clock clock.Clock) (*BoltStore, error) {
 //
 // NOTE: This is part of the Store interface.
 func (s *BoltStore) Close() error {
-	return s.db.Close()
+	return s.DB.Close()
 }
 
 // NewAccount creates a new OffChainBalanceAccount with the given balance and a
@@ -162,7 +162,7 @@ func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi,
 
 	// Try storing the account in the account database, so we can keep track
 	// of its balance.
-	err := s.db.Update(func(tx walletdb.ReadWriteTx) error {
+	err := s.DB.Update(func(tx walletdb.ReadWriteTx) error {
 		bucket := tx.ReadWriteBucket(accountBucketName)
 		if bucket == nil {
 			return ErrAccountBucketNotFound
@@ -364,7 +364,7 @@ func (s *BoltStore) DeleteAccountPayment(_ context.Context, id AccountID,
 func (s *BoltStore) updateAccount(id AccountID,
 	updateFn func(*OffChainBalanceAccount) error) error {
 
-	return s.db.Update(func(tx kvdb.RwTx) error {
+	return s.DB.Update(func(tx kvdb.RwTx) error {
 		bucket := tx.ReadWriteBucket(accountBucketName)
 		if bucket == nil {
 			return ErrAccountBucketNotFound
@@ -451,7 +451,7 @@ func (s *BoltStore) Account(_ context.Context, id AccountID) (
 	// Try looking up and reading the account by its ID from the local
 	// bolt DB.
 	var accountBinary []byte
-	err := s.db.View(func(tx kvdb.RTx) error {
+	err := s.DB.View(func(tx kvdb.RTx) error {
 		bucket := tx.ReadBucket(accountBucketName)
 		if bucket == nil {
 			return ErrAccountBucketNotFound
@@ -487,7 +487,7 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount,
 	error) {
 
 	var accounts []*OffChainBalanceAccount
-	err := s.db.View(func(tx kvdb.RTx) error {
+	err := s.DB.View(func(tx kvdb.RTx) error {
 		// This function will be called in the ForEach and receive
 		// the key and value of each account in the DB. The key, which
 		// is also the ID is not used because it is also marshaled into
@@ -531,7 +531,7 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount,
 //
 // NOTE: This is part of the Store interface.
 func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error {
-	return s.db.Update(func(tx kvdb.RwTx) error {
+	return s.DB.Update(func(tx kvdb.RwTx) error {
 		bucket := tx.ReadWriteBucket(accountBucketName)
 		if bucket == nil {
 			return ErrAccountBucketNotFound
@@ -554,7 +554,7 @@ func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) {
 	var (
 		addValue, settleValue []byte
 	)
-	err := s.db.View(func(tx kvdb.RTx) error {
+	err := s.DB.View(func(tx kvdb.RTx) error {
 		bucket := tx.ReadBucket(accountBucketName)
 		if bucket == nil {
 			return ErrAccountBucketNotFound
@@ -592,7 +592,7 @@ func (s *BoltStore) StoreLastIndexes(_ context.Context, addIndex,
 	byteOrder.PutUint64(addValue, addIndex)
 	byteOrder.PutUint64(settleValue, settleIndex)
 
-	return s.db.Update(func(tx kvdb.RwTx) error {
+	return s.DB.Update(func(tx kvdb.RwTx) error {
 		bucket := tx.ReadWriteBucket(accountBucketName)
 		if bucket == nil {
 			return ErrAccountBucketNotFound
diff --git a/accounts/store_sql.go b/accounts/store_sql.go
index 830f16587..13422315c 100644
--- a/accounts/store_sql.go
+++ b/accounts/store_sql.go
@@ -11,11 +11,13 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
 	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/lnrpc"
 	"github.com/lightningnetwork/lnd/lntypes"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 const (
@@ -33,6 +35,8 @@ const (
 //
 //nolint:lll
 type SQLQueries interface {
+	sqldb.BaseQuerier
+
 	AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error
 	DeleteAccount(ctx context.Context, id int64) error
 	DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error
@@ -53,12 +57,40 @@ type SQLQueries interface {
 	GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error)
 }
 
-// BatchedSQLQueries is a version of the SQLQueries that's capable
-// of batched database operations.
+// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used
+// to interact with accounts related tables.
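+// It mirrors the SQLQueries interface above, but is typed against the
+// sqlcmig6 package so that the KV-to-SQL code migration can run against the
+// schema as it existed at that migration version.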
+//
+//nolint:lll
+type SQLMig6Queries interface {
+	sqldb.BaseQuerier
+
+	AddAccountInvoice(ctx context.Context, arg sqlcmig6.AddAccountInvoiceParams) error
+	DeleteAccount(ctx context.Context, id int64) error
+	DeleteAccountPayment(ctx context.Context, arg sqlcmig6.DeleteAccountPaymentParams) error
+	GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error)
+	GetAccountByLabel(ctx context.Context, label sql.NullString) (sqlcmig6.Account, error)
+	GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error)
+	GetAccountIndex(ctx context.Context, name string) (int64, error)
+	GetAccountPayment(ctx context.Context, arg sqlcmig6.GetAccountPaymentParams) (sqlcmig6.AccountPayment, error)
+	InsertAccount(ctx context.Context, arg sqlcmig6.InsertAccountParams) (int64, error)
+	ListAccountInvoices(ctx context.Context, id int64) ([]sqlcmig6.AccountInvoice, error)
+	ListAccountPayments(ctx context.Context, id int64) ([]sqlcmig6.AccountPayment, error)
+	ListAllAccounts(ctx context.Context) ([]sqlcmig6.Account, error)
+	SetAccountIndex(ctx context.Context, arg sqlcmig6.SetAccountIndexParams) error
+	UpdateAccountBalance(ctx context.Context, arg sqlcmig6.UpdateAccountBalanceParams) (int64, error)
+	UpdateAccountExpiry(ctx context.Context, arg sqlcmig6.UpdateAccountExpiryParams) (int64, error)
+	UpdateAccountLastUpdate(ctx context.Context, arg sqlcmig6.UpdateAccountLastUpdateParams) (int64, error)
+	UpsertAccountPayment(ctx context.Context, arg sqlcmig6.UpsertAccountPaymentParams) error
+	GetAccountInvoice(ctx context.Context, arg sqlcmig6.GetAccountInvoiceParams) (sqlcmig6.AccountInvoice, error)
+}
+
+// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx
+// interface, allowing for multiple queries to be executed in a single SQL
+// transaction.
 type BatchedSQLQueries interface {
 	SQLQueries
 
-	db.BatchedTx[SQLQueries]
+	sqldb.BatchedTx[SQLQueries]
 }
 
 // SQLStore represents a storage backend.
@@ -68,19 +100,57 @@ type SQLStore struct {
 	db BatchedSQLQueries
 
 	// BaseDB represents the underlying database connection.
-	*db.BaseDB
+	*sqldb.BaseDB
 
 	clock clock.Clock
 }
 
-// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
-// storage backend.
-func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
-	executor := db.NewTransactionExecutor(
-		sqlDB, func(tx *sql.Tx) SQLQueries {
-			return sqlDB.WithTx(tx)
+// SQLQueriesExecutor is a transaction executor that exposes the accounts
+// SQLQueries alongside batched transaction support.
+type SQLQueriesExecutor[T sqldb.BaseQuerier] struct {
+	*sqldb.TransactionExecutor[T]
+
+	SQLQueries
+}
+
+// NewSQLQueriesExecutor constructs a new SQLQueriesExecutor from the given
+// base database handle and sqlc queries instance.
+func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB,
+	queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] {
+
+	executor := sqldb.NewTransactionExecutor(
+		baseDB, func(tx *sql.Tx) SQLQueries {
+			return queries.WithTx(tx)
+		},
+	)
+	return &SQLQueriesExecutor[SQLQueries]{
+		TransactionExecutor: executor,
+		SQLQueries:          queries,
+	}
+}
+
+// SQLMig6QueriesExecutor is the sqlcmig6 counterpart of SQLQueriesExecutor,
+// used when running the KV-to-SQL code migration.
+type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct {
+	*sqldb.TransactionExecutor[T]
+
+	SQLMig6Queries
+}
+
+// NewSQLMig6QueriesExecutor constructs a new SQLMig6QueriesExecutor from the
+// given base database handle and sqlcmig6 queries instance.
+func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB,
+	queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] {
+
+	executor := sqldb.NewTransactionExecutor(
+		baseDB, func(tx *sql.Tx) SQLMig6Queries {
+			return queries.WithTx(tx)
 		},
 	)
+	return &SQLMig6QueriesExecutor[SQLMig6Queries]{
+		TransactionExecutor: executor,
+		SQLMig6Queries:      queries,
+	}
+}
+
+// NewSQLStore creates a new SQLStore instance given an open database
+// connection and the sqlc queries to use.
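+// The queries instance is expected to be created via sqlc.NewForType with the
+// backend type of the given database connection.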
+func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries,
+	clock clock.Clock) *SQLStore {
+
+	executor := NewSQLQueriesExecutor(sqlDB, queries)
 
 	return &SQLStore{
 		db: executor,
@@ -157,7 +227,7 @@ func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi,
 		}
 
 		return nil
-	})
+	}, sqldb.NoOpReset)
 	if err != nil {
 		return nil, err
 	}
@@ -299,7 +369,7 @@ func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID,
 		}
 
 		return s.markAccountUpdated(ctx, db, acctID)
-	})
+	}, sqldb.NoOpReset)
 }
 
 func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) (
@@ -377,7 +447,7 @@ func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context,
 		}
 
 		return s.markAccountUpdated(ctx, db, id)
-	})
+	}, sqldb.NoOpReset)
 }
 
 // CreditAccount increases the balance of the account with the given alias by
@@ -412,7 +482,7 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID,
 		}
 
 		return s.markAccountUpdated(ctx, db, id)
-	})
+	}, sqldb.NoOpReset)
 }
 
 // DebitAccount decreases the balance of the account with the given alias by the
@@ -453,7 +523,7 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID,
 		}
 
 		return s.markAccountUpdated(ctx, db, id)
-	})
+	}, sqldb.NoOpReset)
 }
 
 // Account retrieves an account from the SQL store and un-marshals it. If the
@@ -475,7 +545,7 @@ func (s *SQLStore) Account(ctx context.Context, alias AccountID) (
 		account, err = getAndMarshalAccount(ctx, db, id)
 		return err
-	})
+	}, sqldb.NoOpReset)
 
 	return account, err
 }
@@ -507,7 +577,7 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount,
 		}
 
 		return nil
-	})
+	}, sqldb.NoOpReset)
 
 	return accounts, err
 }
@@ -524,7 +594,7 @@ func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error {
 		}
 
 		return db.DeleteAccount(ctx, id)
-	})
+	}, sqldb.NoOpReset)
 }
 
 // UpsertAccountPayment updates or inserts a payment entry for the given
@@ -634,7 +704,7 @@ func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID,
 		}
 
 		return s.markAccountUpdated(ctx, db, id)
-	})
+	}, sqldb.NoOpReset)
 }
 
 // DeleteAccountPayment removes a payment entry from the account with the given
@@ -677,7 +747,7 @@ func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID,
 		}
 
 		return s.markAccountUpdated(ctx, db, id)
-	})
+	}, sqldb.NoOpReset)
 }
 
 // LastIndexes returns the last invoice add and settle index or
@@ -704,7 +774,7 @@ func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) {
 		}
 
 		return err
-	})
+	}, sqldb.NoOpReset)
 
 	return uint64(addIndex), uint64(settleIndex), err
 }
@@ -729,7 +799,7 @@ func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex,
 			Name:  settleIndexName,
 			Value: int64(settleIndex),
 		})
-	})
+	}, sqldb.NoOpReset)
 }
 
 // Close closes the underlying store.
diff --git a/accounts/test_kvdb.go b/accounts/test_kvdb.go
index 546c1eee7..1d181928c 100644
--- a/accounts/test_kvdb.go
+++ b/accounts/test_kvdb.go
@@ -28,7 +28,7 @@ func NewTestDBFromPath(t *testing.T, dbPath string,
 	require.NoError(t, err)
 
 	t.Cleanup(func() {
-		require.NoError(t, store.db.Close())
+		require.NoError(t, store.DB.Close())
 	})
 
 	return store
diff --git a/accounts/test_sql.go b/accounts/test_sql.go
index 3c1ee7f16..ca2f43d6f 100644
--- a/accounts/test_sql.go
+++ b/accounts/test_sql.go
@@ -5,15 +5,20 @@ package accounts
 import (
 	"testing"
 
-	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 	"github.com/stretchr/testify/require"
 )
 
 // createStore is a helper function that creates a new SQLStore and ensure that
 // it is closed when during the test cleanup.
-func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
-	store := NewSQLStore(sqlDB, clock)
+func createStore(t *testing.T, sqlDB *sqldb.BaseDB,
+	clock clock.Clock) *SQLStore {
+
+	queries := sqlc.NewForType(sqlDB, sqlDB.BackendType)
+
+	store := NewSQLStore(sqlDB, queries, clock)
 
 	t.Cleanup(func() {
 		require.NoError(t, store.Close())
diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go
index 9d899b3e2..b1e8be871 100644
--- a/accounts/test_sqlite.go
+++ b/accounts/test_sqlite.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 // ErrDBClosed is an error that is returned when a database operation is
@@ -16,7 +17,11 @@ var ErrDBClosed = errors.New("database is closed")
 
 // NewTestDB is a helper function that creates an SQLStore database for testing.
 func NewTestDB(t *testing.T, clock clock.Clock) Store {
-	return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock)
+	return createStore(
+		t,
+		sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB,
+		clock,
+	)
 }
 
 // NewTestDBFromPath is a helper function that creates a new SQLStore with a
@@ -24,7 +29,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store {
 func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store {
 
-	return createStore(
-		t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
+	tDb := sqldb.NewTestSqliteDBFromPath(
+		t, dbPath, db.MakeTestMigrationStreams(),
 	)
+
+	return createStore(t, tDb.BaseDB, clock)
 }
diff --git a/config_dev.go b/config_dev.go
index 90b8b290f..30a1c54d7 100644
--- a/config_dev.go
+++ b/config_dev.go
@@ -3,14 +3,18 @@
 package terminal
 
 import (
+	"context"
 	"fmt"
 	"path/filepath"
 
 	"github.com/lightninglabs/lightning-terminal/accounts"
 	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightninglabs/lightning-terminal/db/migrationstreams"
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
 	"github.com/lightninglabs/lightning-terminal/firewalldb"
 	"github.com/lightninglabs/lightning-terminal/session"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 const (
@@ -84,7 +88,9 @@ func defaultDevConfig() *DevConfig {
 }
 
 // NewStores creates a new stores instance based on the chosen database backend.
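+//
+// The passed context is used when applying any SQL migrations (including code
+// migrations) that run at startup.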
-func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
+func NewStores(ctx context.Context, cfg *Config,
+	clock clock.Clock) (*stores, error) {
+
 	var (
 		networkDir = filepath.Join(cfg.LitDir, cfg.Network)
 		stores     = &stores{
@@ -101,14 +107,39 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
 			return stores, err
 		}
 
-		sqlStore, err := db.NewSqliteStore(cfg.Sqlite)
+		sqlStore, err := sqldb.NewSqliteStore(&sqldb.SqliteConfig{
+			SkipMigrations:        cfg.Sqlite.SkipMigrations,
+			SkipMigrationDbBackup: cfg.Sqlite.SkipMigrationDbBackup,
+		}, cfg.Sqlite.DatabaseFileName)
 		if err != nil {
 			return stores, err
 		}
 
-		acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock)
-		sessStore := session.NewSQLStore(sqlStore.BaseDB, clock)
-		firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock)
+		if !cfg.Sqlite.SkipMigrations {
+			err = sqldb.ApplyAllMigrations(
+				sqlStore,
+				migrationstreams.MakeMigrationStreams(
+					ctx, cfg.MacaroonPath, clock,
+				),
+			)
+			if err != nil {
+				return stores, fmt.Errorf("error applying "+
+					"migrations to SQLite store: %w", err,
+				)
+			}
+		}
+
+		queries := sqlc.NewForType(sqlStore, sqlStore.BackendType)
+
+		acctStore := accounts.NewSQLStore(
+			sqlStore.BaseDB, queries, clock,
+		)
+		sessStore := session.NewSQLStore(
+			sqlStore.BaseDB, queries, clock,
+		)
+		firewallStore := firewalldb.NewSQLDB(
+			sqlStore.BaseDB, queries, clock,
+		)
 
 		stores.accounts = acctStore
 		stores.sessions = sessStore
@@ -116,14 +147,44 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
 		stores.closeFns["sqlite"] = sqlStore.BaseDB.Close
 
 	case DatabaseBackendPostgres:
-		sqlStore, err := db.NewPostgresStore(cfg.Postgres)
+		sqlStore, err := sqldb.NewPostgresStore(&sqldb.PostgresConfig{
+			Dsn:                cfg.Postgres.DSN(false),
+			MaxOpenConnections: cfg.Postgres.MaxOpenConnections,
+			MaxIdleConnections: cfg.Postgres.MaxIdleConnections,
+			ConnMaxLifetime:    cfg.Postgres.ConnMaxLifetime,
+			ConnMaxIdleTime:    cfg.Postgres.ConnMaxIdleTime,
+			RequireSSL:         cfg.Postgres.RequireSSL,
+			SkipMigrations:     cfg.Postgres.SkipMigrations,
+		})
 		if err != nil {
 			return stores, err
 		}
 
-		acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock)
-		sessStore := session.NewSQLStore(sqlStore.BaseDB, clock)
-		firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock)
+		if !cfg.Postgres.SkipMigrations {
+			err = sqldb.ApplyAllMigrations(
+				sqlStore,
+				migrationstreams.MakeMigrationStreams(
+					ctx, cfg.MacaroonPath, clock,
+				),
+			)
+			if err != nil {
+				return stores, fmt.Errorf("error applying "+
+					"migrations to Postgres store: %w", err,
+				)
+			}
+		}
+
+		queries := sqlc.NewForType(sqlStore, sqlStore.BackendType)
+
+		acctStore := accounts.NewSQLStore(
+			sqlStore.BaseDB, queries, clock,
+		)
+		sessStore := session.NewSQLStore(
+			sqlStore.BaseDB, queries, clock,
+		)
+		firewallStore := firewalldb.NewSQLDB(
+			sqlStore.BaseDB, queries, clock,
+		)
 
 		stores.accounts = acctStore
 		stores.sessions = sessStore
diff --git a/config_prod.go b/config_prod.go
index ac6e6d996..1f11507a6 100644
--- a/config_prod.go
+++ b/config_prod.go
@@ -3,6 +3,7 @@
 package terminal
 
 import (
+	"context"
 	"fmt"
 	"path/filepath"
 
@@ -29,7 +30,9 @@ func (c *DevConfig) Validate(_, _ string) error {
 
 // NewStores creates a new instance of the stores struct using the default Bolt
 // backend since in production, this is currently the only backend supported.
-func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
+func NewStores(_ context.Context, cfg *Config,
+	clock clock.Clock) (*stores, error) {
+
 	networkDir := filepath.Join(cfg.LitDir, cfg.Network)
 
 	stores := &stores{
diff --git a/db/interfaces.go b/db/interfaces.go
index ba64520b4..d8fd37ad2 100644
--- a/db/interfaces.go
+++ b/db/interfaces.go
@@ -1,310 +1,5 @@
 package db
 
-import (
-	"context"
-	"database/sql"
-	"math"
-	prand "math/rand"
-	"time"
-
-	"github.com/lightninglabs/lightning-terminal/db/sqlc"
-)
-
-var (
-	// DefaultStoreTimeout is the default timeout used for any interaction
-	// with the storage/database.
-	DefaultStoreTimeout = time.Second * 10
-)
-
-const (
-	// DefaultNumTxRetries is the default number of times we'll retry a
-	// transaction if it fails with an error that permits transaction
-	// repetition.
-	DefaultNumTxRetries = 10
-
-	// DefaultInitialRetryDelay is the default initial delay between
-	// retries. This will be used to generate a random delay between -50%
-	// and +50% of this value, so 20 to 60 milliseconds. The retry will be
-	// doubled after each attempt until we reach DefaultMaxRetryDelay. We
-	// start with a random value to avoid multiple goroutines that are
-	// created at the same time to effectively retry at the same time.
-	DefaultInitialRetryDelay = time.Millisecond * 40
-
-	// DefaultMaxRetryDelay is the default maximum delay between retries.
-	DefaultMaxRetryDelay = time.Second * 3
-)
-
-// TxOptions represents a set of options one can use to control what type of
-// database transaction is created. Transaction can wither be read or write.
-type TxOptions interface {
-	// ReadOnly returns true if the transaction should be read only.
-	ReadOnly() bool
-}
-
-// BatchedTx is a generic interface that represents the ability to execute
-// several operations to a given storage interface in a single atomic
-// transaction. Typically, Q here will be some subset of the main sqlc.Querier
-// interface allowing it to only depend on the routines it needs to implement
-// any additional business logic.
-type BatchedTx[Q any] interface {
-	// ExecTx will execute the passed txBody, operating upon generic
-	// parameter Q (usually a storage interface) in a single transaction.
-	// The set of TxOptions are passed in in order to allow the caller to
-	// specify if a transaction should be read-only and optionally what
-	// type of concurrency control should be used.
-	ExecTx(ctx context.Context, txOptions TxOptions,
-		txBody func(Q) error) error
-
-	// Backend returns the type of the database backend used.
-	Backend() sqlc.BackendType
-}
-
-// Tx represents a database transaction that can be committed or rolled back.
-type Tx interface {
-	// Commit commits the database transaction, an error should be returned
-	// if the commit isn't possible.
-	Commit() error
-
-	// Rollback rolls back an incomplete database transaction.
-	// Transactions that were able to be committed can still call this as a
-	// noop.
-	Rollback() error
-}
-
-// QueryCreator is a generic function that's used to create a Querier, which is
-// a type of interface that implements storage related methods from a database
-// transaction. This will be used to instantiate an object callers can use to
-// apply multiple modifications to an object interface in a single atomic
-// transaction.
-type QueryCreator[Q any] func(*sql.Tx) Q
-
-// BatchedQuerier is a generic interface that allows callers to create a new
-// database transaction based on an abstract type that implements the TxOptions
-// interface.
-type BatchedQuerier interface {
-	// Querier is the underlying query source, this is in place so we can
-	// pass a BatchedQuerier implementation directly into objects that
-	// create a batched version of the normal methods they need.
-	sqlc.Querier
-
-	// CustomQueries is the set of custom queries that we have manually
-	// defined in addition to the ones generated by sqlc.
-	sqlc.CustomQueries
-
-	// BeginTx creates a new database transaction given the set of
-	// transaction options.
-	BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error)
-}
-
-// txExecutorOptions is a struct that holds the options for the transaction
-// executor. This can be used to do things like retry a transaction due to an
-// error a certain amount of times.
-type txExecutorOptions struct {
-	numRetries        int
-	initialRetryDelay time.Duration
-	maxRetryDelay     time.Duration
-}
-
-// defaultTxExecutorOptions returns the default options for the transaction
-// executor.
-func defaultTxExecutorOptions() *txExecutorOptions {
-	return &txExecutorOptions{
-		numRetries:        DefaultNumTxRetries,
-		initialRetryDelay: DefaultInitialRetryDelay,
-		maxRetryDelay:     DefaultMaxRetryDelay,
-	}
-}
-
-// randRetryDelay returns a random retry delay between -50% and +50%
-// of the configured delay that is doubled for each attempt and capped at a max
-// value.
-func (t *txExecutorOptions) randRetryDelay(attempt int) time.Duration {
-	halfDelay := t.initialRetryDelay / 2
-	randDelay := prand.Int63n(int64(t.initialRetryDelay)) //nolint:gosec
-
-	// 50% plus 0%-100% gives us the range of 50%-150%.
-	initialDelay := halfDelay + time.Duration(randDelay)
-
-	// If this is the first attempt, we just return the initial delay.
-	if attempt == 0 {
-		return initialDelay
-	}
-
-	// For each subsequent delay, we double the initial delay. This still
-	// gives us a somewhat random delay, but it still increases with each
-	// attempt. If we double something n times, that's the same as
-	// multiplying the value with 2^n. We limit the power to 32 to avoid
-	// overflows.
-	factor := time.Duration(math.Pow(2, math.Min(float64(attempt), 32)))
-	actualDelay := initialDelay * factor
-
-	// Cap the delay at the maximum configured value.
-	if actualDelay > t.maxRetryDelay {
-		return t.maxRetryDelay
-	}
-
-	return actualDelay
-}
-
-// TxExecutorOption is a functional option that allows us to pass in optional
-// argument when creating the executor.
-type TxExecutorOption func(*txExecutorOptions)
-
-// WithTxRetries is a functional option that allows us to specify the number of
-// times a transaction should be retried if it fails with a repeatable error.
-func WithTxRetries(numRetries int) TxExecutorOption {
-	return func(o *txExecutorOptions) {
-		o.numRetries = numRetries
-	}
-}
-
-// WithTxRetryDelay is a functional option that allows us to specify the delay
-// to wait before a transaction is retried.
-func WithTxRetryDelay(delay time.Duration) TxExecutorOption {
-	return func(o *txExecutorOptions) {
-		o.initialRetryDelay = delay
-	}
-}
-
-// TransactionExecutor is a generic struct that abstracts away from the type of
-// query a type needs to run under a database transaction, and also the set of
-// options for that transaction. The QueryCreator is used to create a query
-// given a database transaction created by the BatchedQuerier.
-type TransactionExecutor[Query any] struct {
-	BatchedQuerier
-
-	createQuery QueryCreator[Query]
-
-	opts *txExecutorOptions
-}
-
-// NewTransactionExecutor creates a new instance of a TransactionExecutor given
-// a Querier query object and a concrete type for the type of transactions the
-// Querier understands.
-func NewTransactionExecutor[Querier any](db BatchedQuerier,
-	createQuery QueryCreator[Querier],
-	opts ...TxExecutorOption) *TransactionExecutor[Querier] {
-
-	txOpts := defaultTxExecutorOptions()
-	for _, optFunc := range opts {
-		optFunc(txOpts)
-	}
-
-	return &TransactionExecutor[Querier]{
-		BatchedQuerier: db,
-		createQuery:    createQuery,
-		opts:           txOpts,
-	}
-}
-
-// ExecTx is a wrapper for txBody to abstract the creation and commit of a db
-// transaction. The db transaction is embedded in a `*Queries` that txBody
-// needs to use when executing each one of the queries that need to be applied
-// atomically. This can be used by other storage interfaces to parameterize the
-// type of query and options run, in order to have access to batched operations
-// related to a storage object.
-func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context,
-	txOptions TxOptions, txBody func(Q) error) error {
-
-	waitBeforeRetry := func(attemptNumber int) {
-		retryDelay := t.opts.randRetryDelay(attemptNumber)
-
-		log.Tracef("Retrying transaction due to tx serialization or "+
-			"deadlock error, attempt_number=%v, delay=%v",
-			attemptNumber, retryDelay)
-
-		// Before we try again, we'll wait with a random backoff based
-		// on the retry delay.
-		time.Sleep(retryDelay)
-	}
-
-	for i := 0; i < t.opts.numRetries; i++ {
-		// Create the db transaction.
-		tx, err := t.BatchedQuerier.BeginTx(ctx, txOptions)
-		if err != nil {
-			dbErr := MapSQLError(err)
-			if IsSerializationOrDeadlockError(dbErr) {
-				// Nothing to roll back here, since we didn't
-				// even get a transaction yet.
-				waitBeforeRetry(i)
-				continue
-			}
-
-			return dbErr
-		}
-
-		// Rollback is safe to call even if the tx is already closed,
-		// so if the tx commits successfully, this is a no-op.
-		defer func() {
-			_ = tx.Rollback()
-		}()
-
-		if err := txBody(t.createQuery(tx)); err != nil {
-			dbErr := MapSQLError(err)
-			if IsSerializationOrDeadlockError(dbErr) {
-				// Roll back the transaction, then pop back up
-				// to try once again.
-				_ = tx.Rollback()
-
-				waitBeforeRetry(i)
-				continue
-			}
-
-			return dbErr
-		}
-
-		// Commit transaction.
-		if err = tx.Commit(); err != nil {
-			dbErr := MapSQLError(err)
-			if IsSerializationOrDeadlockError(dbErr) {
-				// Roll back the transaction, then pop back up
-				// to try once again.
-				_ = tx.Rollback()
-
-				waitBeforeRetry(i)
-				continue
-			}
-
-			return dbErr
-		}
-
-		return nil
-	}
-
-	// If we get to this point, then we weren't able to successfully commit
-	// a tx given the max number of retries.
-	return ErrRetriesExceeded
-}
-
-// Backend returns the type of the database backend used.
-func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType {
-	return t.BatchedQuerier.Backend()
-}
-
-// BaseDB is the base database struct that each implementation can embed to
-// gain some common functionality.
-type BaseDB struct {
-	*sql.DB
-
-	*sqlc.Queries
-}
-
-// BeginTx wraps the normal sql specific BeginTx method with the TxOptions
-// interface. This interface is then mapped to the concrete sql tx options
-// struct.
-func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) {
-	sqlOptions := sql.TxOptions{
-		ReadOnly:  opts.ReadOnly(),
-		Isolation: sql.LevelSerializable,
-	}
-	return s.DB.BeginTx(ctx, &sqlOptions)
-}
-
-// Backend returns the type of the database backend used.
-func (s *BaseDB) Backend() sqlc.BackendType {
-	return s.Queries.Backend()
-}
-
 // QueriesTxOptions defines the set of db txn options the SQLQueries
 // understands.
 type QueriesTxOptions struct {
diff --git a/db/migrations.go b/db/migrations.go
index 79d63587e..bbf35642f 100644
--- a/db/migrations.go
+++ b/db/migrations.go
@@ -1,19 +1,9 @@
 package db
 
 import (
-	"bytes"
-	"errors"
-	"fmt"
-	"io"
-	"io/fs"
-	"net/http"
-	"strings"
-
-	"github.com/btcsuite/btclog/v2"
 	"github.com/golang-migrate/migrate/v4"
-	"github.com/golang-migrate/migrate/v4/database"
-	"github.com/golang-migrate/migrate/v4/source/httpfs"
-	"github.com/lightninglabs/taproot-assets/fn"
+	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 const (
 	// LatestMigrationVersion is the latest migration version of the
 	// database.
 	//
 	// NOTE: This MUST be updated when a new migration is added.
 	LatestMigrationVersion = 5
-)
-
-// MigrationTarget is a functional option that can be passed to applyMigrations
-// to specify a target version to migrate to. `currentDbVersion` is the current
-// (migration) version of the database, or None if unknown.
-// `maxMigrationVersion` is the maximum migration version known to the driver,
-// or None if unknown.
-type MigrationTarget func(mig *migrate.Migrate,
-	currentDbVersion int, maxMigrationVersion uint) error
-
-var (
-	// TargetLatest is a MigrationTarget that migrates to the latest
-	// version available.
-	TargetLatest = func(mig *migrate.Migrate, _ int, _ uint) error {
-		return mig.Up()
-	}
-
-	// TargetVersion is a MigrationTarget that migrates to the given
-	// version.
-	TargetVersion = func(version uint) MigrationTarget {
-		return func(mig *migrate.Migrate, _ int, _ uint) error {
-			return mig.Migrate(version)
-		}
-	}
-)
-
-var (
-	// ErrMigrationDowngrade is returned when a database downgrade is
-	// detected.
-	ErrMigrationDowngrade = errors.New("database downgrade detected")
+
+	// LatestDevMigrationVersion is the latest dev migration version of the
+	// database. This is used to implement downgrade protection for the
+	// daemon. This represents the latest number used in the migrations_dev
+	// directory.
+	//
+	// NOTE: This MUST be updated when a migration is added to or removed
+	// from the migrations_dev directory.
+	LatestDevMigrationVersion = 1
 )
 
-// migrationOption is a functional option that can be passed to migrate related
-// methods to modify their behavior.
-type migrateOptions struct {
-	latestVersion fn.Option[uint]
-}
-
-// defaultMigrateOptions returns a new migrateOptions instance with default
-// settings.
-func defaultMigrateOptions() *migrateOptions {
-	return &migrateOptions{}
-}
-
-// MigrateOpt is a functional option that can be passed to migrate related
-// methods to modify behavior.
-type MigrateOpt func(*migrateOptions)
-
-// WithLatestVersion allows callers to override the default latest version
-// setting.
-func WithLatestVersion(version uint) MigrateOpt {
-	return func(o *migrateOptions) {
-		o.latestVersion = fn.Some(version)
-	}
-}
-
-// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
-// used to log migrations.
-type migrationLogger struct {
-	log btclog.Logger
-}
-
-// Printf is like fmt.Printf. We map this to the target logger based on the
-// current log level.
-func (m *migrationLogger) Printf(format string, v ...interface{}) {
-	// Trim trailing newlines from the format.
-	format = strings.TrimRight(format, "\n")
-
-	switch m.log.Level() {
-	case btclog.LevelTrace:
-		m.log.Tracef(format, v...)
-	case btclog.LevelDebug:
-		m.log.Debugf(format, v...)
-	case btclog.LevelInfo:
-		m.log.Infof(format, v...)
-	case btclog.LevelWarn:
-		m.log.Warnf(format, v...)
-	case btclog.LevelError:
-		m.log.Errorf(format, v...)
-	case btclog.LevelCritical:
-		m.log.Criticalf(format, v...)
-	case btclog.LevelOff:
-	}
-}
-
-// Verbose should return true when verbose logging output is wanted
-func (m *migrationLogger) Verbose() bool {
-	return m.log.Level() <= btclog.LevelDebug
-}
-
-// applyMigrations executes database migration files found in the given file
-// system under the given path, using the passed database driver and database
-// name, up to or down to the given target version.
-func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string,
-	targetVersion MigrationTarget, opts *migrateOptions) error {
-
-	// With the migrate instance open, we'll create a new migration source
-	// using the embedded file system stored in sqlSchemas. The library
-	// we're using can't handle a raw file system interface, so we wrap it
-	// in this intermediate layer.
-	migrateFileServer, err := httpfs.New(http.FS(fs), path)
-	if err != nil {
-		return err
-	}
-
-	// Finally, we'll run the migration with our driver above based on the
-	// open DB, and also the migration source stored in the file system
-	// above.
-	sqlMigrate, err := migrate.NewWithInstance(
-		"migrations", migrateFileServer, dbName, driver,
-	)
-	if err != nil {
-		return err
-	}
-
-	migrationVersion, _, _ := sqlMigrate.Version()
-
-	// As the down migrations may end up *dropping* data, we want to
-	// prevent that without explicit accounting.
-	latestVersion := opts.latestVersion.UnwrapOr(LatestMigrationVersion)
-	if migrationVersion > latestVersion {
-		return fmt.Errorf("%w: database version is newer than the "+
-			"latest migration version, preventing downgrade: "+
-			"db_version=%v, latest_migration_version=%v",
-			ErrMigrationDowngrade, migrationVersion, latestVersion)
-	}
-
-	// Report the current version of the database before the migration.
-	currentDbVersion, _, err := driver.Version()
-	if err != nil {
-		return fmt.Errorf("unable to get current db version: %w", err)
-	}
-	log.Infof("Attempting to apply migration(s) "+
-		"(current_db_version=%v, latest_migration_version=%v)",
-		currentDbVersion, latestVersion)
-
-	// Apply our local logger to the migration instance.
-	sqlMigrate.Log = &migrationLogger{log}
-
-	// Execute the migration based on the target given.
-	err = targetVersion(sqlMigrate, currentDbVersion, latestVersion)
-	if err != nil && !errors.Is(err, migrate.ErrNoChange) {
-		return err
-	}
-
-	// Report the current version of the database after the migration.
-	currentDbVersion, _, err = driver.Version()
-	if err != nil {
-		return fmt.Errorf("unable to get current db version: %w", err)
-	}
-	log.Infof("Database version after migration: %v", currentDbVersion)
-
-	return nil
-}
-
-// replacerFS is an implementation of a fs.FS virtual file system that wraps an
-// existing file system but does a search-and-replace operation on each file
-// when it is opened.
-type replacerFS struct {
-	parentFS fs.FS
-	replaces map[string]string
-}
-
-// A compile-time assertion to make sure replacerFS implements the fs.FS
-// interface.
-var _ fs.FS = (*replacerFS)(nil)
-
-// newReplacerFS creates a new replacer file system, wrapping the given parent
-// virtual file system. Each file within the file system is undergoing a
-// search-and-replace operation when it is opened, using the given map where the
-// key denotes the search term and the value the term to replace each occurrence
-// with.
-func newReplacerFS(parent fs.FS, replaces map[string]string) *replacerFS {
-	return &replacerFS{
-		parentFS: parent,
-		replaces: replaces,
-	}
-}
-
-// Open opens a file in the virtual file system.
-//
-// NOTE: This is part of the fs.FS interface.
-func (t *replacerFS) Open(name string) (fs.File, error) {
-	f, err := t.parentFS.Open(name)
-	if err != nil {
-		return nil, err
-	}
-
-	stat, err := f.Stat()
-	if err != nil {
-		return nil, err
-	}
-
-	if stat.IsDir() {
-		return f, err
-	}
-
-	return newReplacerFile(f, t.replaces)
-}
-
-type replacerFile struct {
-	parentFile fs.File
-	buf        bytes.Buffer
-}
-
-// A compile-time assertion to make sure replacerFile implements the fs.File
-// interface.
-var _ fs.File = (*replacerFile)(nil)
-
-func newReplacerFile(parent fs.File, replaces map[string]string) (*replacerFile,
-	error) {
-
-	content, err := io.ReadAll(parent)
-	if err != nil {
-		return nil, err
-	}
-
-	contentStr := string(content)
-	for from, to := range replaces {
-		contentStr = strings.ReplaceAll(contentStr, from, to)
-	}
-
-	var buf bytes.Buffer
-	_, err = buf.WriteString(contentStr)
-	if err != nil {
-		return nil, err
-	}
-
-	return &replacerFile{
-		parentFile: parent,
-		buf:        buf,
-	}, nil
-}
-
-// Stat returns statistics/info about the file.
-//
-// NOTE: This is part of the fs.File interface.
-func (t *replacerFile) Stat() (fs.FileInfo, error) {
-	return t.parentFile.Stat()
-}
-
-// Read reads as many bytes as possible from the file into the given slice.
-//
-// NOTE: This is part of the fs.File interface.
-func (t *replacerFile) Read(bytes []byte) (int, error) {
-	return t.buf.Read(bytes)
-}
-
-// Close closes the underlying file.
+// MakeTestMigrationStreams creates the migration streams for the unit test
+// environment.
 //
-// NOTE: This is part of the fs.File interface.
-func (t *replacerFile) Close() error {
-	// We already fully read and then closed the file when creating this
-	// instance, so there's nothing to do for us here.
-	return nil
+// NOTE: This function is not located in the migrationstreams package to avoid
+// cyclic dependencies. This test migration stream does not run the kvdb to sql
+// migration, as we already have separate unit tests that test the migration.
+func MakeTestMigrationStreams() []sqldb.MigrationStream {
+	migStream := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable,
+		SQLFileDirectory: "sqlc/migrations",
+		Schemas:          SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// database. This is used to implement downgrade protection for
+		// the daemon.
+		//
+		// NOTE: This MUST be updated when a new migration is added.
+		LatestMigrationVersion: LatestMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	migStreamDev := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable + "_dev",
+		SQLFileDirectory: "sqlc/migrations_dev",
+		Schemas:          SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// dev migrations database. This is used to implement downgrade
+		// protection for the daemon.
+		//
+		// NOTE: This MUST be updated when a new dev migration is added.
+		LatestMigrationVersion: LatestDevMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	return []sqldb.MigrationStream{migStream, migStreamDev}
+}
diff --git a/db/migrationstreams/log.go b/db/migrationstreams/log.go
new file mode 100644
index 000000000..0a3e80731
--- /dev/null
+++ b/db/migrationstreams/log.go
@@ -0,0 +1,25 @@
+package migrationstreams
+
+import (
+	"github.com/btcsuite/btclog/v2"
+	"github.com/lightningnetwork/lnd/build"
+)
+
+const Subsystem = "MIGS"
+
+// log is a logger that is initialized with no output filters. This
+// means the package will not perform any logging by default until the caller
+// requests it.
+var log btclog.Logger
+
+// The default amount of logging is none.
+func init() {
+	UseLogger(build.NewSubLogger(Subsystem, nil))
+}
+
+// UseLogger uses a specified Logger to output package logging info.
+// This should be used in preference to SetLogWriter if the caller is also
+// using btclog.
+func UseLogger(logger btclog.Logger) {
+	log = logger
+}
diff --git a/db/migrationstreams/post_migration_callbacks_dev.go b/db/migrationstreams/post_migration_callbacks_dev.go
new file mode 100644
index 000000000..7bfe5dbfb
--- /dev/null
+++ b/db/migrationstreams/post_migration_callbacks_dev.go
@@ -0,0 +1,107 @@
+//go:build dev
+
+package migrationstreams
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"path/filepath"
+	"time"
+
+	"github.com/golang-migrate/migrate/v4"
+	"github.com/golang-migrate/migrate/v4/database"
+	"github.com/lightninglabs/lightning-terminal/accounts"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
+	"github.com/lightninglabs/lightning-terminal/firewalldb"
+	"github.com/lightninglabs/lightning-terminal/session"
+	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
+)
+
+// MakePostStepCallbacksMig6 returns a post step callback, to be used with the
+// migrate package, that runs the KVDB to SQL code migration. The returned
+// callback is executed after the migration with the given migVersion has been
+// applied.
+func MakePostStepCallbacksMig6(ctx context.Context, db *sqldb.BaseDB,
+	macPath string, clock clock.Clock,
+	migVersion uint) migrate.PostStepCallback {
+
+	mig6queries := sqlcmig6.NewForType(db, db.BackendType)
+	mig6executor := sqldb.NewTransactionExecutor(
+		db, func(tx *sql.Tx) *sqlcmig6.Queries {
+			return mig6queries.WithTx(tx)
+		},
+	)
+
+	return func(_ *migrate.Migration, _ database.Driver) error {
+		// We ignore the actual driver that's being returned here, since
+		// we use migrate.NewWithInstance() to create the migration
+		// instance from our already instantiated database backend that
+		// is also passed into this function.
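+		//
+		// The whole code migration is run inside a single write
+		// transaction so that it is applied atomically: either all
+		// stores are migrated, or none are.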
+		return mig6executor.ExecTx(
+			ctx, sqldb.NewWriteTx(),
+			func(q6 *sqlcmig6.Queries) error {
+				log.Infof("Running post migration callback "+
+					"for migration version %d", migVersion)
+
+				return kvdbToSqlMigrationCallback(
+					ctx, macPath, db, clock, q6,
+				)
+			}, sqldb.NoOpReset,
+		)
+	}
+}
+
+func kvdbToSqlMigrationCallback(ctx context.Context, macPath string,
+	_ *sqldb.BaseDB, clock clock.Clock, q *sqlcmig6.Queries) error {
+
+	start := time.Now()
+	log.Infof("Starting KVDB to SQL migration for all stores")
+
+	accountStore, err := accounts.NewBoltStore(
+		filepath.Dir(macPath), accounts.DBFilename, clock,
+	)
+	if err != nil {
+		return err
+	}
+
+	err = accounts.MigrateAccountStoreToSQL(ctx, accountStore.DB, q)
+	if err != nil {
+		return fmt.Errorf("error migrating account store to "+
+			"SQL: %v", err)
+	}
+
+	sessionStore, err := session.NewDB(
+		filepath.Dir(macPath), session.DBFilename,
+		clock, accountStore,
+	)
+	if err != nil {
+		return err
+	}
+
+	err = session.MigrateSessionStoreToSQL(ctx, sessionStore.DB, q)
+	if err != nil {
+		return fmt.Errorf("error migrating session store to "+
+			"SQL: %v", err)
+	}
+
+	firewallStore, err := firewalldb.NewBoltDB(
+		filepath.Dir(macPath), firewalldb.DBFilename,
+		sessionStore, accountStore, clock,
+	)
+	if err != nil {
+		return err
+	}
+
+	err = firewalldb.MigrateFirewallDBToSQL(ctx, firewallStore.DB, q)
+	if err != nil {
+		return fmt.Errorf("error migrating firewalldb store "+
+			"to SQL: %v", err)
+	}
+
+	log.Infof("Successfully migrated all KVDB stores to SQL in: %v",
+		time.Since(start))
+
+	return nil
+}
diff --git a/db/migrationstreams/sql_migrations.go b/db/migrationstreams/sql_migrations.go
new file mode 100644
index 000000000..41d6ccb88
--- /dev/null
+++ b/db/migrationstreams/sql_migrations.go
@@ -0,0 +1,40 @@
+//go:build !dev
+
+package migrationstreams
+
+import (
+	"context"
+	"github.com/golang-migrate/migrate/v4"
+	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
+	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
+)
+
+// MakeMigrationStreams creates the migration streams for production
+// environments.
+func MakeMigrationStreams(_ context.Context, _ string,
+	_ clock.Clock) []sqldb.MigrationStream {
+
+	migStream := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable,
+		SQLFileDirectory: "sqlc/migrations",
+		Schemas:          db.SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// database. This is used to implement downgrade protection for
+		// the daemon.
+		//
+		// NOTE: This MUST be updated when a new migration is added.
+		LatestMigrationVersion: db.LatestMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	return []sqldb.MigrationStream{migStream}
+}
diff --git a/db/migrationstreams/sql_migrations_dev.go b/db/migrationstreams/sql_migrations_dev.go
new file mode 100644
index 000000000..76c9dd096
--- /dev/null
+++ b/db/migrationstreams/sql_migrations_dev.go
@@ -0,0 +1,85 @@
+//go:build dev
+
+package migrationstreams
+
+import (
+	"context"
+	"github.com/golang-migrate/migrate/v4"
+	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
+	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
+)
+
+const (
+	// KVDBtoSQLMigVersion is the version of the migration that migrates
+	// the kvdb to the sql database.
+	//
+	// TODO: When the kvdb to sql migration goes live in prod, this should
+	// be moved to the non-dev db/migrations.go file, and this constant
+	// value should be updated to reflect the real migration number.
+	KVDBtoSQLMigVersion = 1
+)
+
+// MakeMigrationStreams creates the migration streams for the dev environments.
+func MakeMigrationStreams(ctx context.Context, macPath string,
+	clock clock.Clock) []sqldb.MigrationStream {
+
+	// Create the prod migration stream.
+	migStream := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable,
+		SQLFileDirectory: "sqlc/migrations",
+		Schemas:          db.SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// database. This is used to implement downgrade protection for
+		// the daemon.
+		//
+		// NOTE: This MUST be updated when a new migration is added.
+		LatestMigrationVersion: db.LatestMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	// Create the dev migration stream.
+	migStreamDev := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable + "_dev",
+		SQLFileDirectory: "sqlc/migrations_dev",
+		Schemas:          db.SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// dev migrations database. This is used to implement downgrade
+		// protection for the daemon.
+		//
+		// NOTE: This MUST be updated when a new dev migration is added.
+		LatestMigrationVersion: db.LatestDevMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			// Any callbacks added to this map will be executed
+			// after the dev migration with the matching uint key
+			// has been applied. If no entry exists for a given
+			// uint, then no callback will be executed for that
+			// migration number. This is useful for adding a code
+			// migration step as a callback to be run after a
+			// specific migration number has been applied.
+			res := make(map[uint]migrate.PostStepCallback)
+
+			res[KVDBtoSQLMigVersion] = MakePostStepCallbacksMig6(
+				ctx, db, macPath, clock, KVDBtoSQLMigVersion,
+			)
+
+			return res, nil
+		},
+	}
+
+	return []sqldb.MigrationStream{migStream, migStreamDev}
+}
diff --git a/db/postgres.go b/db/postgres.go
index 16e41dc09..21fc0edb1 100644
--- a/db/postgres.go
+++ b/db/postgres.go
@@ -1,26 +1,16 @@
 package db
 
 import (
-	"database/sql"
 	"fmt"
 	"testing"
 	"time"
 
-	postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres"
 	_ "github.com/golang-migrate/migrate/v4/source/file"
-	"github.com/lightninglabs/lightning-terminal/db/sqlc"
-	"github.com/stretchr/testify/require"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 const (
 	dsnTemplate = "postgres://%v:%v@%v:%d/%v?sslmode=%v"
-
-	// defaultMaxIdleConns is the number of permitted idle connections.
-	defaultMaxIdleConns = 6
-
-	// defaultConnMaxIdleTime is the amount of time a connection can be
-	// idle before it is closed.
-	defaultConnMaxIdleTime = 5 * time.Minute
 )
 
 var (
@@ -30,16 +20,6 @@ var (
 	// fully executed yet. So this time needs to be chosen correctly to be
 	// longer than the longest expected individual test run time.
 	DefaultPostgresFixtureLifetime = 60 * time.Minute
-
-	// postgresSchemaReplacements is a map of schema strings that need to be
-	// replaced for postgres. This is needed because we write the schemas
-	// to work with sqlite primarily, and postgres has some differences.
-	postgresSchemaReplacements = map[string]string{
-		"BLOB":                "BYTEA",
-		"INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
-		"TIMESTAMP":           "TIMESTAMP WITHOUT TIME ZONE",
-		"UNHEX":               "DECODE",
-	}
 )
 
 // PostgresConfig holds the postgres database configuration.
@@ -76,132 +56,17 @@ func (s *PostgresConfig) DSN(hidePassword bool) string {
 		s.DBName, sslMode)
 }
 
-// PostgresStore is a database store implementation that uses a Postgres
-// backend.
-type PostgresStore struct {
-	cfg *PostgresConfig
-
-	*BaseDB
-}
-
-// NewPostgresStore creates a new store that is backed by a Postgres database
-// backend.
-func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
-	log.Infof("Using SQL database '%s'", cfg.DSN(true))
-
-	rawDb, err := sql.Open("pgx", cfg.DSN(false))
-	if err != nil {
-		return nil, err
-	}
-
-	maxConns := defaultMaxConns
-	if cfg.MaxOpenConnections > 0 {
-		maxConns = cfg.MaxOpenConnections
-	}
-
-	maxIdleConns := defaultMaxIdleConns
-	if cfg.MaxIdleConnections > 0 {
-		maxIdleConns = cfg.MaxIdleConnections
-	}
-
-	connMaxLifetime := defaultConnMaxLifetime
-	if cfg.ConnMaxLifetime > 0 {
-		connMaxLifetime = cfg.ConnMaxLifetime
-	}
-
-	connMaxIdleTime := defaultConnMaxIdleTime
-	if cfg.ConnMaxIdleTime > 0 {
-		connMaxIdleTime = cfg.ConnMaxIdleTime
-	}
-
-	rawDb.SetMaxOpenConns(maxConns)
-	rawDb.SetMaxIdleConns(maxIdleConns)
-	rawDb.SetConnMaxLifetime(connMaxLifetime)
-	rawDb.SetConnMaxIdleTime(connMaxIdleTime)
-
-	queries := sqlc.NewPostgres(rawDb)
-	s := &PostgresStore{
-		cfg: cfg,
-		BaseDB: &BaseDB{
-			DB:      rawDb,
-			Queries: queries,
-		},
-	}
-
-	// Now that the database is open, populate the database with our set of
-	// schemas based on our embedded in-memory file system.
-	if !cfg.SkipMigrations {
-		if err := s.ExecuteMigrations(TargetLatest); err != nil {
-			return nil, fmt.Errorf("error executing migrations: "+
-				"%w", err)
-		}
-	}
-
-	return s, nil
-}
-
-// ExecuteMigrations runs migrations for the Postgres database, depending on the
-// target given, either all migrations or up to a given version.
-func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, - optFuncs ...MigrateOpt) error { - - opts := defaultMigrateOptions() - for _, optFunc := range optFuncs { - optFunc(opts) - } - - driver, err := postgres_migrate.WithInstance( - s.DB, &postgres_migrate.Config{}, - ) - if err != nil { - return fmt.Errorf("error creating postgres migration: %w", err) - } - - postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) - return applyMigrations( - postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target, - opts, - ) -} - // NewTestPostgresDB is a helper function that creates a Postgres database for // testing. -func NewTestPostgresDB(t *testing.T) *PostgresStore { +func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { t.Helper() t.Logf("Creating new Postgres DB for testing") - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - store, err := NewPostgresStore(sqlFixture.GetConfig()) - require.NoError(t, err) - - t.Cleanup(func() { - sqlFixture.TearDown(t) - }) - - return store -} - -// NewTestPostgresDBWithVersion is a helper function that creates a Postgres -// database for testing and migrates it to the given version. -func NewTestPostgresDBWithVersion(t *testing.T, version uint) *PostgresStore { - t.Helper() - - t.Logf("Creating new Postgres DB for testing, migrating to version %d", - version) - - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - storeCfg := sqlFixture.GetConfig() - storeCfg.SkipMigrations = true - store, err := NewPostgresStore(storeCfg) - require.NoError(t, err) - - err = store.ExecuteMigrations(TargetVersion(version)) - require.NoError(t, err) - + sqlFixture := sqldb.NewTestPgFixture(t, DefaultPostgresFixtureLifetime) t.Cleanup(func() { sqlFixture.TearDown(t) }) - return store + return sqldb.NewTestPostgresDB(t, sqlFixture, MakeTestMigrationStreams()) } diff --git a/db/schemas.go b/db/schemas.go index 1a7a2096f..565fb5615 100644 --- a/db/schemas.go +++ b/db/schemas.go @@ -5,5 +5,5 @@ import ( _ "embed" ) -//go:embed sqlc/migrations/*.*.sql -var sqlSchemas embed.FS +//go:embed sqlc/migration*/*.*.sql +var SqlSchemas embed.FS diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f4bf7f611..af556eae7 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -2,21 +2,8 @@ package sqlc import ( "context" -) - -// BackendType is an enum that represents the type of database backend we're -// using. -type BackendType uint8 - -const ( - // BackendTypeUnknown indicates we're using an unknown backend. - BackendTypeUnknown BackendType = iota - // BackendTypeSqlite indicates we're using a SQLite backend. - BackendTypeSqlite - - // BackendTypePostgres indicates we're using a Postgres backend. - BackendTypePostgres + "github.com/lightningnetwork/lnd/sqldb/v2" ) // wrappedTX is a wrapper around a DBTX that also stores the database backend @@ -24,29 +11,24 @@ const ( type wrappedTX struct { DBTX - backendType BackendType + backendType sqldb.BackendType } // Backend returns the type of database backend we're using. -func (q *Queries) Backend() BackendType { +func (q *Queries) Backend() sqldb.BackendType { wtx, ok := q.db.(*wrappedTX) if !ok { // Shouldn't happen unless a new database backend type is added // but not initialized correctly. - return BackendTypeUnknown + return sqldb.BackendTypeUnknown } return wtx.backendType } -// NewSqlite creates a new Queries instance for a SQLite database. 
-func NewSqlite(db DBTX) *Queries {
-	return &Queries{db: &wrappedTX{db, BackendTypeSqlite}}
-}
-
-// NewPostgres creates a new Queries instance for a Postgres database.
-func NewPostgres(db DBTX) *Queries {
-	return &Queries{db: &wrappedTX{db, BackendTypePostgres}}
+// NewForType creates a new Queries instance for the given database type.
+func NewForType(db DBTX, typ sqldb.BackendType) *Queries {
+	return &Queries{db: &wrappedTX{db, typ}}
 }
 
 // CustomQueries defines a set of custom queries that we define in addition
@@ -62,5 +44,5 @@ type CustomQueries interface {
 		arg ListActionsParams) ([]Action, error)
 
 	// Backend returns the type of the database backend used.
-	Backend() BackendType
+	Backend() sqldb.BackendType
 }
diff --git a/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql
new file mode 100644
index 000000000..0d246b2dd
--- /dev/null
+++ b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql
@@ -0,0 +1 @@
+-- Comment to ensure the file is created and picked up in the migration stream.
\ No newline at end of file
diff --git a/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql
new file mode 100644
index 000000000..0d246b2dd
--- /dev/null
+++ b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql
@@ -0,0 +1 @@
+-- Comment to ensure the file is created and picked up in the migration stream.
\ No newline at end of file
diff --git a/db/sqlcmig6/accounts.sql.go b/db/sqlcmig6/accounts.sql.go
new file mode 100644
index 000000000..479c82b36
--- /dev/null
+++ b/db/sqlcmig6/accounts.sql.go
@@ -0,0 +1,390 @@
+package sqlcmig6
+
+import (
+	"context"
+	"database/sql"
+	"time"
+)
+
+const addAccountInvoice = `-- name: AddAccountInvoice :exec
+INSERT INTO account_invoices (account_id, hash)
+VALUES ($1, $2)
+`
+
+type AddAccountInvoiceParams struct {
+	AccountID int64
+	Hash      []byte
+}
+
+func (q *Queries) AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error {
+	_, err := q.db.ExecContext(ctx, addAccountInvoice, arg.AccountID, arg.Hash)
+	return err
+}
+
+const deleteAccount = `-- name: DeleteAccount :exec
+DELETE FROM accounts
+WHERE id = $1
+`
+
+func (q *Queries) DeleteAccount(ctx context.Context, id int64) error {
+	_, err := q.db.ExecContext(ctx, deleteAccount, id)
+	return err
+}
+
+const deleteAccountPayment = `-- name: DeleteAccountPayment :exec
+DELETE FROM account_payments
+WHERE hash = $1
+AND account_id = $2
+`
+
+type DeleteAccountPaymentParams struct {
+	Hash      []byte
+	AccountID int64
+}
+
+func (q *Queries) DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error {
+	_, err := q.db.ExecContext(ctx, deleteAccountPayment, arg.Hash, arg.AccountID)
+	return err
+}
+
+const getAccount = `-- name: GetAccount :one
+SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration
+FROM accounts
+WHERE id = $1
+`
+
+func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) {
+	row := q.db.QueryRowContext(ctx, getAccount, id)
+	var i Account
+	err := row.Scan(
+		&i.ID,
+		&i.Alias,
+		&i.Label,
+		&i.Type,
+		&i.InitialBalanceMsat,
+		&i.CurrentBalanceMsat,
+		&i.LastUpdated,
+		&i.Expiration,
+	)
+	return i, err
+}
+
+const getAccountByLabel = `-- name: GetAccountByLabel :one
+SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration
+FROM accounts
+WHERE label = $1
+`
+
+func (q
*Queries) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccountByLabel, label) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountIDByAlias = `-- name: GetAccountIDByAlias :one +SELECT id +FROM accounts +WHERE alias = $1 +` + +func (q *Queries) GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIDByAlias, alias) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getAccountIndex = `-- name: GetAccountIndex :one +SELECT value +FROM account_indices +WHERE name = $1 +` + +func (q *Queries) GetAccountIndex(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIndex, name) + var value int64 + err := row.Scan(&value) + return value, err +} + +const getAccountInvoice = `-- name: GetAccountInvoice :one +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 + AND hash = $2 +` + +type GetAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) { + row := q.db.QueryRowContext(ctx, getAccountInvoice, arg.AccountID, arg.Hash) + var i AccountInvoice + err := row.Scan(&i.AccountID, &i.Hash) + return i, err +} + +const getAccountPayment = `-- name: GetAccountPayment :one +SELECT account_id, hash, status, full_amount_msat FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type GetAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) { + row := q.db.QueryRowContext(ctx, getAccountPayment, arg.Hash, arg.AccountID) + var i AccountPayment + err := row.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ) + return i, err +} + +const insertAccount = `-- name: InsertAccount :one +INSERT INTO accounts (type, initial_balance_msat, current_balance_msat, last_updated, label, alias, expiration) +VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id +` + +type InsertAccountParams struct { + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Label sql.NullString + Alias int64 + Expiration time.Time +} + +func (q *Queries) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAccount, + arg.Type, + arg.InitialBalanceMsat, + arg.CurrentBalanceMsat, + arg.LastUpdated, + arg.Label, + arg.Alias, + arg.Expiration, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const listAccountInvoices = `-- name: ListAccountInvoices :many +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 +` + +func (q *Queries) ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) { + rows, err := q.db.QueryContext(ctx, listAccountInvoices, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountInvoice + for rows.Next() { + var i AccountInvoice + if err := rows.Scan(&i.AccountID, &i.Hash); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAccountPayments = 
`-- name: ListAccountPayments :many +SELECT account_id, hash, status, full_amount_msat +FROM account_payments +WHERE account_id = $1 +` + +func (q *Queries) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) { + rows, err := q.db.QueryContext(ctx, listAccountPayments, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountPayment + for rows.Next() { + var i AccountPayment + if err := rows.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAllAccounts = `-- name: ListAllAccounts :many +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +ORDER BY id +` + +func (q *Queries) ListAllAccounts(ctx context.Context) ([]Account, error) { + rows, err := q.db.QueryContext(ctx, listAllAccounts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Account + for rows.Next() { + var i Account + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setAccountIndex = `-- name: SetAccountIndex :exec +INSERT INTO account_indices (name, value) +VALUES ($1, $2) + ON CONFLICT (name) +DO UPDATE SET value = $2 +` + +type SetAccountIndexParams struct { + Name string + Value int64 +} + +func (q *Queries) SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error { + _, err := q.db.ExecContext(ctx, setAccountIndex, arg.Name, arg.Value) + return err +} + +const updateAccountBalance = `-- name: UpdateAccountBalance :one +UPDATE accounts +SET current_balance_msat = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountBalanceParams struct { + CurrentBalanceMsat int64 + ID int64 +} + +func (q *Queries) UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountBalance, arg.CurrentBalanceMsat, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountExpiry = `-- name: UpdateAccountExpiry :one +UPDATE accounts +SET expiration = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountExpiryParams struct { + Expiration time.Time + ID int64 +} + +func (q *Queries) UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountExpiry, arg.Expiration, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountLastUpdate = `-- name: UpdateAccountLastUpdate :one +UPDATE accounts +SET last_updated = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountLastUpdateParams struct { + LastUpdated time.Time + ID int64 +} + +func (q *Queries) UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountLastUpdate, arg.LastUpdated, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const upsertAccountPayment = `-- name: UpsertAccountPayment :exec +INSERT INTO account_payments (account_id, hash, status, 
full_amount_msat) +VALUES ($1, $2, $3, $4) +ON CONFLICT (account_id, hash) +DO UPDATE SET status = $3, full_amount_msat = $4 +` + +type UpsertAccountPaymentParams struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +func (q *Queries) UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, upsertAccountPayment, + arg.AccountID, + arg.Hash, + arg.Status, + arg.FullAmountMsat, + ) + return err +} diff --git a/db/sqlcmig6/actions.sql.go b/db/sqlcmig6/actions.sql.go new file mode 100644 index 000000000..a39d51e5d --- /dev/null +++ b/db/sqlcmig6/actions.sql.go @@ -0,0 +1,73 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const insertAction = `-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id +` + +type InsertActionParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +func (q *Queries) InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAction, + arg.SessionID, + arg.AccountID, + arg.MacaroonIdentifier, + arg.ActorName, + arg.FeatureName, + arg.ActionTrigger, + arg.Intent, + arg.StructuredJsonData, + arg.RpcMethod, + arg.RpcParamsJson, + arg.CreatedAt, + arg.ActionState, + arg.ErrorReason, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const setActionState = `-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3 +` + +type SetActionStateParams struct { + ActionState int16 + ErrorReason sql.NullString + ID int64 +} + +func (q *Queries) SetActionState(ctx context.Context, arg SetActionStateParams) error { + _, err := q.db.ExecContext(ctx, setActionState, arg.ActionState, arg.ErrorReason, arg.ID) + return err +} diff --git a/db/sqlcmig6/actions_custom.go b/db/sqlcmig6/actions_custom.go new file mode 100644 index 000000000..f01772d51 --- /dev/null +++ b/db/sqlcmig6/actions_custom.go @@ -0,0 +1,210 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "strconv" + "strings" +) + +// ActionQueryParams defines the parameters for querying actions. +type ActionQueryParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + FeatureName sql.NullString + ActorName sql.NullString + RpcMethod sql.NullString + State sql.NullInt16 + EndTime sql.NullTime + StartTime sql.NullTime + GroupID sql.NullInt64 +} + +// ListActionsParams defines the parameters for listing actions, including +// the ActionQueryParams for filtering and a Pagination struct for +// pagination. The Reversed field indicates whether the results should be +// returned in reverse order based on the created_at timestamp. +type ListActionsParams struct { + ActionQueryParams + Reversed bool + *Pagination +} + +// Pagination defines the pagination parameters for listing actions. 
+type Pagination struct {
+	NumOffset int32
+	NumLimit  int32
+}
+
+// ListActions retrieves a list of actions based on the provided
+// ListActionsParams.
+func (q *Queries) ListActions(ctx context.Context,
+	arg ListActionsParams) ([]Action, error) {
+
+	query, args := buildListActionsQuery(arg)
+	rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Action
+	for rows.Next() {
+		var i Action
+		if err := rows.Scan(
+			&i.ID,
+			&i.SessionID,
+			&i.AccountID,
+			&i.MacaroonIdentifier,
+			&i.ActorName,
+			&i.FeatureName,
+			&i.ActionTrigger,
+			&i.Intent,
+			&i.StructuredJsonData,
+			&i.RpcMethod,
+			&i.RpcParamsJson,
+			&i.CreatedAt,
+			&i.ActionState,
+			&i.ErrorReason,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+// CountActions returns the number of actions that match the provided
+// ActionQueryParams.
+func (q *Queries) CountActions(ctx context.Context,
+	arg ActionQueryParams) (int64, error) {
+
+	query, args := buildActionsQuery(arg, true)
+	row := q.db.QueryRowContext(ctx, query, args...)
+
+	var count int64
+	err := row.Scan(&count)
+
+	return count, err
+}
+
+// buildActionsQuery constructs a SQL query to retrieve actions based on the
+// provided parameters. We do this manually so that if, for example, we have
+// a sessionID we are filtering by, then this appears in the query as:
+// `WHERE a.session_id = ?` which will properly make use of the underlying
+// index. If we were instead to use a single SQLC query, it would include many
+// WHERE clauses like:
+// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)".
+// This would use the index if run against postgres but not when run against
+// sqlite.
+//
+// The 'count' param indicates whether the query should return a count of
+// actions that match the criteria or the actions themselves.
+func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) {
+	var (
+		conditions []string
+		args       []any
+	)
+
+	if params.SessionID.Valid {
+		conditions = append(conditions, "a.session_id = ?")
+		args = append(args, params.SessionID.Int64)
+	}
+	if params.AccountID.Valid {
+		conditions = append(conditions, "a.account_id = ?")
+		args = append(args, params.AccountID.Int64)
+	}
+	if params.FeatureName.Valid {
+		conditions = append(conditions, "a.feature_name = ?")
+		args = append(args, params.FeatureName.String)
+	}
+	if params.ActorName.Valid {
+		conditions = append(conditions, "a.actor_name = ?")
+		args = append(args, params.ActorName.String)
+	}
+	if params.RpcMethod.Valid {
+		conditions = append(conditions, "a.rpc_method = ?")
+		args = append(args, params.RpcMethod.String)
+	}
+	if params.State.Valid {
+		conditions = append(conditions, "a.action_state = ?")
+		args = append(args, params.State.Int16)
+	}
+	if params.EndTime.Valid {
+		conditions = append(conditions, "a.created_at <= ?")
+		args = append(args, params.EndTime.Time)
+	}
+	if params.StartTime.Valid {
+		conditions = append(conditions, "a.created_at >= ?")
+		args = append(args, params.StartTime.Time)
+	}
+	if params.GroupID.Valid {
+		conditions = append(conditions, `
+			EXISTS (
+				SELECT 1
+				FROM sessions s
+				WHERE s.id = a.session_id AND s.group_id = ?
+ )`) + args = append(args, params.GroupID.Int64) + } + + query := "SELECT a.* FROM actions a" + if count { + query = "SELECT COUNT(*) FROM actions a" + } + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + + return query, args +} + +// buildListActionsQuery constructs a SQL query to retrieve a list of actions +// based on the provided parameters. It builds upon the `buildActionsQuery` +// function, adding pagination and ordering based on the reversed parameter. +func buildListActionsQuery(params ListActionsParams) (string, []interface{}) { + query, args := buildActionsQuery(params.ActionQueryParams, false) + + // Determine order direction. + order := "ASC" + if params.Reversed { + order = "DESC" + } + query += " ORDER BY a.created_at " + order + + // Maybe paginate. + if params.Pagination != nil { + query += " LIMIT ? OFFSET ?" + args = append(args, params.NumLimit, params.NumOffset) + } + + return query, args +} + +// fillPlaceHolders replaces all '?' placeholders in the SQL query with +// positional placeholders like $1, $2, etc. This is necessary for +// compatibility with Postgres. +func fillPlaceHolders(query string) string { + var ( + sb strings.Builder + argNum = 1 + ) + + for i := range len(query) { + if query[i] != '?' { + sb.WriteByte(query[i]) + continue + } + + sb.WriteString("$") + sb.WriteString(strconv.Itoa(argNum)) + argNum++ + } + + return sb.String() +} diff --git a/db/sqlcmig6/db.go b/db/sqlcmig6/db.go new file mode 100644 index 000000000..82ff72dd8 --- /dev/null +++ b/db/sqlcmig6/db.go @@ -0,0 +1,27 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/db/sqlcmig6/db_custom.go b/db/sqlcmig6/db_custom.go new file mode 100644 index 000000000..e128cb290 --- /dev/null +++ b/db/sqlcmig6/db_custom.go @@ -0,0 +1,48 @@ +package sqlcmig6 + +import ( + "context" + + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +// wrappedTX is a wrapper around a DBTX that also stores the database backend +// type. +type wrappedTX struct { + DBTX + + backendType sqldb.BackendType +} + +// Backend returns the type of database backend we're using. +func (q *Queries) Backend() sqldb.BackendType { + wtx, ok := q.db.(*wrappedTX) + if !ok { + // Shouldn't happen unless a new database backend type is added + // but not initialized correctly. + return sqldb.BackendTypeUnknown + } + + return wtx.backendType +} + +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} +} + +// CustomQueries defines a set of custom queries that we define in addition +// to the ones generated by sqlc. +type CustomQueries interface { + // CountActions returns the number of actions that match the provided + // ActionQueryParams. + CountActions(ctx context.Context, arg ActionQueryParams) (int64, error) + + // ListActions retrieves a list of actions based on the provided + // ListActionsParams. 
+ ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) + + // Backend returns the type of the database backend used. + Backend() sqldb.BackendType +} diff --git a/db/sqlcmig6/kvstores.sql.go b/db/sqlcmig6/kvstores.sql.go new file mode 100644 index 000000000..3c076c5a4 --- /dev/null +++ b/db/sqlcmig6/kvstores.sql.go @@ -0,0 +1,376 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +const deleteAllTempKVStores = `-- name: DeleteAllTempKVStores :exec +DELETE FROM kvstores +WHERE perm = false +` + +func (q *Queries) DeleteAllTempKVStores(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, deleteAllTempKVStores) + return err +} + +const deleteFeatureKVStoreRecord = `-- name: DeleteFeatureKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id = $5 +` + +type DeleteFeatureKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteFeatureKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + return err +} + +const deleteGlobalKVStoreRecord = `-- name: DeleteGlobalKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id IS NULL + AND feature_id IS NULL +` + +type DeleteGlobalKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGlobalKVStoreRecord, arg.Key, arg.RuleID, arg.Perm) + return err +} + +const deleteGroupKVStoreRecord = `-- name: DeleteGroupKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type DeleteGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + return err +} + +const getFeatureID = `-- name: GetFeatureID :one +SELECT id +FROM features +WHERE name = $1 +` + +func (q *Queries) GetFeatureID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getFeatureID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getFeatureKVStoreRecord = `-- name: GetFeatureKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id = $5 +` + +type GetFeatureKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getFeatureKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getGlobalKVStoreRecord = `-- name: GetGlobalKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id IS NULL + AND feature_id IS NULL +` + 
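+// Note: records in the kvstores table are scoped by the group_id and
+// feature_id columns. A record is global when both columns are NULL,
+// group-scoped when only group_id is set, and feature-scoped when both
+// group_id and feature_id are set. The Get/Update/Delete query variants in
+// this file differ only in how they constrain these two columns.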
+type GetGlobalKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGlobalKVStoreRecord, arg.Key, arg.RuleID, arg.Perm) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getGroupKVStoreRecord = `-- name: GetGroupKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type GetGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getOrInsertFeatureID = `-- name: GetOrInsertFeatureID :one +INSERT INTO features (name) +VALUES ($1) +ON CONFLICT(name) DO UPDATE SET name = excluded.name +RETURNING id +` + +func (q *Queries) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getOrInsertFeatureID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getOrInsertRuleID = `-- name: GetOrInsertRuleID :one +INSERT INTO rules (name) +VALUES ($1) +ON CONFLICT(name) DO UPDATE SET name = excluded.name +RETURNING id +` + +func (q *Queries) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getOrInsertRuleID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getRuleID = `-- name: GetRuleID :one +SELECT id +FROM rules +WHERE name = $1 +` + +func (q *Queries) GetRuleID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getRuleID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertKVStoreRecord = `-- name: InsertKVStoreRecord :exec +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) +VALUES ($1, $2, $3, $4, $5, $6) +` + +type InsertKVStoreRecordParams struct { + Perm bool + RuleID int64 + GroupID sql.NullInt64 + FeatureID sql.NullInt64 + EntryKey string + Value []byte +} + +func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, insertKVStoreRecord, + arg.Perm, + arg.RuleID, + arg.GroupID, + arg.FeatureID, + arg.EntryKey, + arg.Value, + ) + return err +} + +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, group_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.GroupID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + 
AND group_id = $5 + AND feature_id = $6 +` + +type UpdateFeatureKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateFeatureKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + return err +} + +const updateGlobalKVStoreRecord = `-- name: UpdateGlobalKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id IS NULL + AND feature_id IS NULL +` + +type UpdateGlobalKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGlobalKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + ) + return err +} + +const updateGroupKVStoreRecord = `-- name: UpdateGroupKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id = $5 + AND feature_id IS NULL +` + +type UpdateGroupKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGroupKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + return err +} diff --git a/db/sqlcmig6/models.go b/db/sqlcmig6/models.go new file mode 100644 index 000000000..9e57c28eb --- /dev/null +++ b/db/sqlcmig6/models.go @@ -0,0 +1,124 @@ +package sqlcmig6 + +import ( + "database/sql" + "time" +) + +type Account struct { + ID int64 + Alias int64 + Label sql.NullString + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Expiration time.Time +} + +type AccountIndex struct { + Name string + Value int64 +} + +type AccountInvoice struct { + AccountID int64 + Hash []byte +} + +type AccountPayment struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +type Action struct { + ID int64 + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +type Feature struct { + ID int64 + Name string +} + +type Kvstore struct { + ID int64 + Perm bool + RuleID int64 + GroupID sql.NullInt64 + FeatureID sql.NullInt64 + EntryKey string + Value []byte +} + +type PrivacyPair struct { + GroupID int64 + RealVal string + PseudoVal string +} + +type Rule struct { + ID int64 + Name string +} + +type Session struct { + ID int64 + Alias []byte + Label string + State int16 + Type int16 + Expiry time.Time + CreatedAt time.Time + RevokedAt sql.NullTime + ServerAddress string + DevServer bool + MacaroonRootKey int64 + PairingSecret []byte + LocalPrivateKey []byte + LocalPublicKey []byte + RemotePublicKey []byte + Privacy bool + AccountID sql.NullInt64 + GroupID sql.NullInt64 +} + +type SessionFeatureConfig struct { + SessionID int64 + FeatureName string + Config []byte +} + +type SessionMacaroonCaveat struct { + ID int64 + SessionID 
int64 + CaveatID []byte + VerificationID []byte + Location sql.NullString +} + +type SessionMacaroonPermission struct { + ID int64 + SessionID int64 + Entity string + Action string +} + +type SessionPrivacyFlag struct { + SessionID int64 + Flag int32 +} diff --git a/db/sqlcmig6/privacy_paris.sql.go b/db/sqlcmig6/privacy_paris.sql.go new file mode 100644 index 000000000..20af09a2f --- /dev/null +++ b/db/sqlcmig6/privacy_paris.sql.go @@ -0,0 +1,91 @@ +package sqlcmig6 + +import ( + "context" +) + +const getAllPrivacyPairs = `-- name: GetAllPrivacyPairs :many +SELECT real_val, pseudo_val +FROM privacy_pairs +WHERE group_id = $1 +` + +type GetAllPrivacyPairsRow struct { + RealVal string + PseudoVal string +} + +func (q *Queries) GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) { + rows, err := q.db.QueryContext(ctx, getAllPrivacyPairs, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllPrivacyPairsRow + for rows.Next() { + var i GetAllPrivacyPairsRow + if err := rows.Scan(&i.RealVal, &i.PseudoVal); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getPseudoForReal = `-- name: GetPseudoForReal :one +SELECT pseudo_val +FROM privacy_pairs +WHERE group_id = $1 AND real_val = $2 +` + +type GetPseudoForRealParams struct { + GroupID int64 + RealVal string +} + +func (q *Queries) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) { + row := q.db.QueryRowContext(ctx, getPseudoForReal, arg.GroupID, arg.RealVal) + var pseudo_val string + err := row.Scan(&pseudo_val) + return pseudo_val, err +} + +const getRealForPseudo = `-- name: GetRealForPseudo :one +SELECT real_val +FROM privacy_pairs +WHERE group_id = $1 AND pseudo_val = $2 +` + +type GetRealForPseudoParams struct { + GroupID int64 + PseudoVal string +} + +func (q *Queries) GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) { + row := q.db.QueryRowContext(ctx, getRealForPseudo, arg.GroupID, arg.PseudoVal) + var real_val string + err := row.Scan(&real_val) + return real_val, err +} + +const insertPrivacyPair = `-- name: InsertPrivacyPair :exec +INSERT INTO privacy_pairs (group_id, real_val, pseudo_val) +VALUES ($1, $2, $3) +` + +type InsertPrivacyPairParams struct { + GroupID int64 + RealVal string + PseudoVal string +} + +func (q *Queries) InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error { + _, err := q.db.ExecContext(ctx, insertPrivacyPair, arg.GroupID, arg.RealVal, arg.PseudoVal) + return err +} diff --git a/db/sqlcmig6/querier.go b/db/sqlcmig6/querier.go new file mode 100644 index 000000000..57e229b5f --- /dev/null +++ b/db/sqlcmig6/querier.go @@ -0,0 +1,75 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +type Querier interface { + AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error + DeleteAllTempKVStores(ctx context.Context) error + DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error + DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error + DeleteSessionsWithState(ctx 
context.Context, state int16) error + GetAccount(ctx context.Context, id int64) (Account, error) + GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) + GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) + GetFeatureID(ctx context.Context, name string) (int64, error) + GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) + GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) + GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) + GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) + GetRuleID(ctx context.Context, name string) (int64, error) + GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) + GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) + GetSessionByID(ctx context.Context, id int64) (Session, error) + GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) + GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) + GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) + GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) + GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) + GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) + GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) + InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) + InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) + InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error + InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error + InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) + InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error + InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error + InsertSessionMacaroonPermission(ctx context.Context, arg InsertSessionMacaroonPermissionParams) error + InsertSessionPrivacyFlag(ctx context.Context, arg InsertSessionPrivacyFlagParams) error + ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) + ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) + ListSessions(ctx context.Context) ([]Session, error) + ListSessionsByState(ctx context.Context, state int16) ([]Session, error) + ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) + SetAccountIndex(ctx 
context.Context, arg SetAccountIndexParams) error + SetActionState(ctx context.Context, arg SetActionStateParams) error + SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error + SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error + SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error + UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) + UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error + UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error + UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error + UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error +} + +var _ Querier = (*Queries)(nil) diff --git a/db/sqlcmig6/sessions.sql.go b/db/sqlcmig6/sessions.sql.go new file mode 100644 index 000000000..bc492043e --- /dev/null +++ b/db/sqlcmig6/sessions.sql.go @@ -0,0 +1,675 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const deleteSessionsWithState = `-- name: DeleteSessionsWithState :exec +DELETE FROM sessions +WHERE state = $1 +` + +func (q *Queries) DeleteSessionsWithState(ctx context.Context, state int16) error { + _, err := q.db.ExecContext(ctx, deleteSessionsWithState, state) + return err +} + +const getAliasBySessionID = `-- name: GetAliasBySessionID :one +SELECT alias FROM sessions +WHERE id = $1 +` + +func (q *Queries) GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getAliasBySessionID, id) + var alias []byte + err := row.Scan(&alias) + return alias, err +} + +const getSessionAliasesInGroup = `-- name: GetSessionAliasesInGroup :many +SELECT alias FROM sessions +WHERE group_id = $1 +` + +func (q *Queries) GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) { + rows, err := q.db.QueryContext(ctx, getSessionAliasesInGroup, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items [][]byte + for rows.Next() { + var alias []byte + if err := rows.Scan(&alias); err != nil { + return nil, err + } + items = append(items, alias) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionByAlias = `-- name: GetSessionByAlias :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE alias = $1 +` + +func (q *Queries) GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByAlias, alias) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionByID = `-- name: GetSessionByID :one +SELECT id, 
alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE id = $1 +` + +func (q *Queries) GetSessionByID(ctx context.Context, id int64) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByID, id) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionByLocalPublicKey = `-- name: GetSessionByLocalPublicKey :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE local_public_key = $1 +` + +func (q *Queries) GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByLocalPublicKey, localPublicKey) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionFeatureConfigs = `-- name: GetSessionFeatureConfigs :many +SELECT session_id, feature_name, config FROM session_feature_configs +WHERE session_id = $1 +` + +func (q *Queries) GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) { + rows, err := q.db.QueryContext(ctx, getSessionFeatureConfigs, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionFeatureConfig + for rows.Next() { + var i SessionFeatureConfig + if err := rows.Scan(&i.SessionID, &i.FeatureName, &i.Config); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionIDByAlias = `-- name: GetSessionIDByAlias :one +SELECT id FROM sessions +WHERE alias = $1 +` + +func (q *Queries) GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) { + row := q.db.QueryRowContext(ctx, getSessionIDByAlias, alias) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getSessionMacaroonCaveats = `-- name: GetSessionMacaroonCaveats :many +SELECT id, session_id, caveat_id, verification_id, location FROM session_macaroon_caveats +WHERE session_id = $1 +` + +func (q *Queries) GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) { + rows, err := q.db.QueryContext(ctx, getSessionMacaroonCaveats, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionMacaroonCaveat + for rows.Next() { + var i SessionMacaroonCaveat + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.CaveatID, + &i.VerificationID, + &i.Location, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + 
return items, nil +} + +const getSessionMacaroonPermissions = `-- name: GetSessionMacaroonPermissions :many +SELECT id, session_id, entity, action FROM session_macaroon_permissions +WHERE session_id = $1 +` + +func (q *Queries) GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) { + rows, err := q.db.QueryContext(ctx, getSessionMacaroonPermissions, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionMacaroonPermission + for rows.Next() { + var i SessionMacaroonPermission + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Entity, + &i.Action, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionPrivacyFlags = `-- name: GetSessionPrivacyFlags :many +SELECT session_id, flag FROM session_privacy_flags +WHERE session_id = $1 +` + +func (q *Queries) GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) { + rows, err := q.db.QueryContext(ctx, getSessionPrivacyFlags, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionPrivacyFlag + for rows.Next() { + var i SessionPrivacyFlag + if err := rows.Scan(&i.SessionID, &i.Flag); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionsInGroup = `-- name: GetSessionsInGroup :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE group_id = $1 +` + +func (q *Queries) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, getSessionsInGroup, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertSession = `-- name: InsertSession :one +INSERT INTO sessions ( + alias, label, state, type, expiry, created_at, + server_address, dev_server, macaroon_root_key, pairing_secret, + local_private_key, local_public_key, remote_public_key, privacy, group_id, account_id +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, $11, $12, + $13, $14, $15, $16 +) RETURNING id +` + +type InsertSessionParams struct { + Alias []byte + Label string + State int16 + Type int16 + Expiry time.Time + CreatedAt time.Time + ServerAddress string + DevServer bool + MacaroonRootKey int64 + PairingSecret []byte + LocalPrivateKey []byte + LocalPublicKey []byte + RemotePublicKey []byte + Privacy bool + GroupID sql.NullInt64 + AccountID sql.NullInt64 +} + +func (q *Queries) InsertSession(ctx context.Context, arg 
InsertSessionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertSession, + arg.Alias, + arg.Label, + arg.State, + arg.Type, + arg.Expiry, + arg.CreatedAt, + arg.ServerAddress, + arg.DevServer, + arg.MacaroonRootKey, + arg.PairingSecret, + arg.LocalPrivateKey, + arg.LocalPublicKey, + arg.RemotePublicKey, + arg.Privacy, + arg.GroupID, + arg.AccountID, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertSessionFeatureConfig = `-- name: InsertSessionFeatureConfig :exec +INSERT INTO session_feature_configs ( + session_id, feature_name, config +) VALUES ( + $1, $2, $3 +) +` + +type InsertSessionFeatureConfigParams struct { + SessionID int64 + FeatureName string + Config []byte +} + +func (q *Queries) InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error { + _, err := q.db.ExecContext(ctx, insertSessionFeatureConfig, arg.SessionID, arg.FeatureName, arg.Config) + return err +} + +const insertSessionMacaroonCaveat = `-- name: InsertSessionMacaroonCaveat :exec +INSERT INTO session_macaroon_caveats ( + session_id, caveat_id, verification_id, location +) VALUES ( + $1, $2, $3, $4 +) +` + +type InsertSessionMacaroonCaveatParams struct { + SessionID int64 + CaveatID []byte + VerificationID []byte + Location sql.NullString +} + +func (q *Queries) InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error { + _, err := q.db.ExecContext(ctx, insertSessionMacaroonCaveat, + arg.SessionID, + arg.CaveatID, + arg.VerificationID, + arg.Location, + ) + return err +} + +const insertSessionMacaroonPermission = `-- name: InsertSessionMacaroonPermission :exec +INSERT INTO session_macaroon_permissions ( + session_id, entity, action +) VALUES ( + $1, $2, $3 +) +` + +type InsertSessionMacaroonPermissionParams struct { + SessionID int64 + Entity string + Action string +} + +func (q *Queries) InsertSessionMacaroonPermission(ctx context.Context, arg InsertSessionMacaroonPermissionParams) error { + _, err := q.db.ExecContext(ctx, insertSessionMacaroonPermission, arg.SessionID, arg.Entity, arg.Action) + return err +} + +const insertSessionPrivacyFlag = `-- name: InsertSessionPrivacyFlag :exec +INSERT INTO session_privacy_flags ( + session_id, flag +) VALUES ( + $1, $2 +) +` + +type InsertSessionPrivacyFlagParams struct { + SessionID int64 + Flag int32 +} + +func (q *Queries) InsertSessionPrivacyFlag(ctx context.Context, arg InsertSessionPrivacyFlagParams) error { + _, err := q.db.ExecContext(ctx, insertSessionPrivacyFlag, arg.SessionID, arg.Flag) + return err +} + +const listSessions = `-- name: ListSessions :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +ORDER BY created_at +` + +func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err 
!= nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessionsByState = `-- name: ListSessionsByState :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE state = $1 +ORDER BY created_at +` + +func (q *Queries) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessionsByState, state) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessionsByType = `-- name: ListSessionsByType :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE type = $1 +ORDER BY created_at +` + +func (q *Queries) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessionsByType, type_) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setSessionGroupID = `-- name: SetSessionGroupID :exec +UPDATE sessions +SET group_id = $1 +WHERE id = $2 +` + +type SetSessionGroupIDParams struct { + GroupID sql.NullInt64 + ID int64 +} + +func (q *Queries) SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error { + _, err := q.db.ExecContext(ctx, setSessionGroupID, arg.GroupID, arg.ID) + return err +} + +const setSessionRemotePublicKey = `-- name: SetSessionRemotePublicKey :exec +UPDATE sessions +SET remote_public_key = $1 +WHERE id = $2 +` + +type SetSessionRemotePublicKeyParams struct { + RemotePublicKey []byte + ID int64 +} + +func (q *Queries) SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error { + _, err := q.db.ExecContext(ctx, setSessionRemotePublicKey, arg.RemotePublicKey, arg.ID) + return err +} + +const setSessionRevokedAt = `-- name: SetSessionRevokedAt :exec +UPDATE sessions +SET revoked_at = $1 +WHERE id = $2 +` + +type SetSessionRevokedAtParams struct { + RevokedAt sql.NullTime + ID int64 +} + +func (q *Queries) SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error { + _, err 
:= q.db.ExecContext(ctx, setSessionRevokedAt, arg.RevokedAt, arg.ID) + return err +} + +const updateSessionState = `-- name: UpdateSessionState :exec +UPDATE sessions +SET state = $1 +WHERE id = $2 +` + +type UpdateSessionStateParams struct { + State int16 + ID int64 +} + +func (q *Queries) UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error { + _, err := q.db.ExecContext(ctx, updateSessionState, arg.State, arg.ID) + return err +} diff --git a/db/sqlite.go b/db/sqlite.go index 803362fa8..4e831a074 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -1,49 +1,9 @@ package db import ( - "database/sql" - "fmt" - "net/url" - "path/filepath" - "testing" - "time" - - "github.com/golang-migrate/migrate/v4" - sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" - "github.com/lightninglabs/lightning-terminal/db/sqlc" - "github.com/stretchr/testify/require" _ "modernc.org/sqlite" // Register relevant drivers. ) -const ( - // sqliteOptionPrefix is the string prefix sqlite uses to set various - // options. This is used in the following format: - // * sqliteOptionPrefix || option_name = option_value. - sqliteOptionPrefix = "_pragma" - - // sqliteTxLockImmediate is a dsn option used to ensure that write - // transactions are started immediately. - sqliteTxLockImmediate = "_txlock=immediate" - - // defaultMaxConns is the number of permitted active and idle - // connections. We want to limit this so it isn't unlimited. We use the - // same value for the number of idle connections as, this can speed up - // queries given a new connection doesn't need to be established each - // time. - defaultMaxConns = 25 - - // defaultConnMaxLifetime is the maximum amount of time a connection can - // be reused for before it is closed. - defaultConnMaxLifetime = 10 * time.Minute -) - -var ( - // sqliteSchemaReplacements is a map of schema strings that need to be - // replaced for sqlite. There currently aren't any replacements, because - // the SQL files are written with SQLite compatibility in mind. - sqliteSchemaReplacements = map[string]string{} -) - // SqliteConfig holds all the config arguments needed to interact with our // sqlite DB. // @@ -61,248 +21,3 @@ type SqliteConfig struct { // found. DatabaseFileName string `long:"dbfile" description:"The full path to the database."` } - -// SqliteStore is a sqlite3 based database for the Taproot Asset daemon. -type SqliteStore struct { - cfg *SqliteConfig - - *BaseDB -} - -// NewSqliteStore attempts to open a new sqlite database based on the passed -// config. -func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { - // The set of pragma options are accepted using query options. For now - // we only want to ensure that foreign key constraints are properly - // enforced. - pragmaOptions := []struct { - name string - value string - }{ - { - name: "foreign_keys", - value: "on", - }, - { - name: "journal_mode", - value: "WAL", - }, - { - name: "busy_timeout", - value: "5000", - }, - { - // With the WAL mode, this ensures that we also do an - // extra WAL sync after each transaction. The normal - // sync mode skips this and gives better performance, - // but risks durability. - name: "synchronous", - value: "full", - }, - { - // This is used to ensure proper durability for users - // running on Mac OS. It uses the correct fsync system - // call to ensure items are fully flushed to disk. 
- name: "fullfsync", - value: "true", - }, - } - sqliteOptions := make(url.Values) - for _, option := range pragmaOptions { - sqliteOptions.Add( - sqliteOptionPrefix, - fmt.Sprintf("%v=%v", option.name, option.value), - ) - } - - // Construct the DSN which is just the database file name, appended - // with the series of pragma options as a query URL string. For more - // details on the formatting here, see the modernc.org/sqlite docs: - // https://pkg.go.dev/modernc.org/sqlite#Driver.Open. - dsn := fmt.Sprintf( - "%v?%v&%v", cfg.DatabaseFileName, sqliteOptions.Encode(), - sqliteTxLockImmediate, - ) - db, err := sql.Open("sqlite", dsn) - if err != nil { - return nil, err - } - - db.SetMaxOpenConns(defaultMaxConns) - db.SetMaxIdleConns(defaultMaxConns) - db.SetConnMaxLifetime(defaultConnMaxLifetime) - - queries := sqlc.NewSqlite(db) - s := &SqliteStore{ - cfg: cfg, - BaseDB: &BaseDB{ - DB: db, - Queries: queries, - }, - } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - - return s, nil -} - -// backupSqliteDatabase creates a backup of the given SQLite database. -func backupSqliteDatabase(srcDB *sql.DB, dbFullFilePath string) error { - if srcDB == nil { - return fmt.Errorf("backup source database is nil") - } - - // Create a database backup file full path from the given source - // database full file path. - // - // Get the current time and format it as a Unix timestamp in - // nanoseconds. - timestamp := time.Now().UnixNano() - - // Add the timestamp to the backup name. - backupFullFilePath := fmt.Sprintf( - "%s.%d.backup", dbFullFilePath, timestamp, - ) - - log.Infof("Creating backup of database file: %v -> %v", - dbFullFilePath, backupFullFilePath) - - // Create the database backup. - vacuumIntoQuery := "VACUUM INTO ?;" - stmt, err := srcDB.Prepare(vacuumIntoQuery) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(backupFullFilePath) - if err != nil { - return err - } - - return nil -} - -// backupAndMigrate is a helper function that creates a database backup before -// initiating the migration, and then migrates the database to the latest -// version. -func (s *SqliteStore) backupAndMigrate(mig *migrate.Migrate, - currentDbVersion int, maxMigrationVersion uint) error { - - // Determine if a database migration is necessary given the current - // database version and the maximum migration version. - versionUpgradePending := currentDbVersion < int(maxMigrationVersion) - if !versionUpgradePending { - log.Infof("Current database version is up-to-date, skipping "+ - "migration attempt and backup creation "+ - "(current_db_version=%v, max_migration_version=%v)", - currentDbVersion, maxMigrationVersion) - return nil - } - - // At this point, we know that a database migration is necessary. - // Create a backup of the database before starting the migration. 
- if !s.cfg.SkipMigrationDbBackup { - log.Infof("Creating database backup (before applying " + - "migration(s))") - - err := backupSqliteDatabase(s.DB, s.cfg.DatabaseFileName) - if err != nil { - return err - } - } else { - log.Infof("Skipping database backup creation before applying " + - "migration(s)") - } - - log.Infof("Applying migrations to database") - return mig.Up() -} - -// ExecuteMigrations runs migrations for the sqlite database, depending on the -// target given, either all migrations or up to a given version. -func (s *SqliteStore) ExecuteMigrations(target MigrationTarget, - optFuncs ...MigrateOpt) error { - - opts := defaultMigrateOptions() - for _, optFunc := range optFuncs { - optFunc(opts) - } - - driver, err := sqlite_migrate.WithInstance( - s.DB, &sqlite_migrate.Config{}, - ) - if err != nil { - return fmt.Errorf("error creating sqlite migration: %w", err) - } - - sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) - return applyMigrations( - sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts, - ) -} - -// NewTestSqliteDB is a helper function that creates an SQLite database for -// testing. -func NewTestSqliteDB(t *testing.T) *SqliteStore { - t.Helper() - - // TODO(roasbeef): if we pass :memory: for the file name, then we get - // an in mem version to speed up tests - dbPath := filepath.Join(t.TempDir(), "tmp.db") - t.Logf("Creating new SQLite DB handle for testing: %s", dbPath) - - return NewTestSqliteDbHandleFromPath(t, dbPath) -} - -// NewTestSqliteDbHandleFromPath is a helper function that creates a SQLite -// database handle given a database file path. -func NewTestSqliteDbHandleFromPath(t *testing.T, dbPath string) *SqliteStore { - t.Helper() - - sqlDB, err := NewSqliteStore(&SqliteConfig{ - DatabaseFileName: dbPath, - SkipMigrations: false, - }) - require.NoError(t, err) - - t.Cleanup(func() { - require.NoError(t, sqlDB.DB.Close()) - }) - - return sqlDB -} - -// NewTestSqliteDBWithVersion is a helper function that creates an SQLite -// database for testing and migrates it to the given version. 
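The version-pinned test helper removed below ultimately drives golang-migrate to an exact schema version. A hedged sketch of that underlying step, assuming an initialized *migrate.Migrate; treating migrate.ErrNoChange as success is an assumption here that mirrors common usage:

```go
import (
	"errors"
	"fmt"

	"github.com/golang-migrate/migrate/v4"
)

// migrateToVersion moves the schema to exactly the given version.
// ErrNoChange means the database is already at that version, so it is
// not treated as a failure.
func migrateToVersion(m *migrate.Migrate, version uint) error {
	err := m.Migrate(version)
	if err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("migrating to version %d: %w", version, err)
	}

	return nil
}
```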
-func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { - t.Helper() - - t.Logf("Creating new SQLite DB for testing, migrating to version %d", - version) - - // TODO(roasbeef): if we pass :memory: for the file name, then we get - // an in mem version to speed up tests - dbFileName := filepath.Join(t.TempDir(), "tmp.db") - sqlDB, err := NewSqliteStore(&SqliteConfig{ - DatabaseFileName: dbFileName, - SkipMigrations: true, - }) - require.NoError(t, err) - - err = sqlDB.ExecuteMigrations(TargetVersion(version)) - require.NoError(t, err) - - t.Cleanup(func() { - require.NoError(t, sqlDB.DB.Close()) - }) - - return sqlDB -} diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 75c9d0a6d..eb0294af7 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -10,9 +10,10 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLAccountQueries is a subset of the sqlc.Queries interface that can be used @@ -22,6 +23,13 @@ type SQLAccountQueries interface { GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) } +// SQLMig6AccountQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with the accounts table. +type SQLMig6AccountQueries interface { + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) +} + // SQLActionQueries is a subset of the sqlc.Queries interface that can be used // to interact with action related tables. // @@ -36,6 +44,20 @@ type SQLActionQueries interface { CountActions(ctx context.Context, arg sqlc.ActionQueryParams) (int64, error) } +// SQLMig6ActionQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with action related tables. +// +//nolint:lll +type SQLMig6ActionQueries interface { + SQLSessionQueries + SQLMig6AccountQueries + + InsertAction(ctx context.Context, arg sqlcmig6.InsertActionParams) (int64, error) + SetActionState(ctx context.Context, arg sqlcmig6.SetActionStateParams) error + ListActions(ctx context.Context, arg sqlcmig6.ListActionsParams) ([]sqlcmig6.Action, error) + CountActions(ctx context.Context, arg sqlcmig6.ActionQueryParams) (int64, error) +} + // sqlActionLocator helps us find an action in the SQL DB. type sqlActionLocator struct { // id is the DB level ID of the action. @@ -167,7 +189,7 @@ func (s *SQLDB) AddAction(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -202,7 +224,7 @@ func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, Valid: errReason != "", }, }) - }) + }, sqldb.NoOpReset) } // ListActions returns a list of Actions. 
The query IndexOffset and MaxNum @@ -350,7 +372,7 @@ func (s *SQLDB) ListActions(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) return actions, lastIndex, uint64(totalCount), err } diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 248892130..96951db80 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -9,8 +9,10 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be @@ -30,6 +32,32 @@ type SQLKVStoreQueries interface { UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error UpdateGroupKVStoreRecord(ctx context.Context, arg sqlc.UpdateGroupKVStoreRecordParams) error InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error + ListAllKVStoresRecords(ctx context.Context) ([]sqlc.Kvstore, error) + DeleteAllTempKVStores(ctx context.Context) error + GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) + GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetFeatureID(ctx context.Context, name string) (int64, error) + GetRuleID(ctx context.Context, name string) (int64, error) +} + +// SQLMig6KVStoreQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with the kvstore tables. +// +//nolint:lll +type SQLMig6KVStoreQueries interface { + SQLSessionQueries + + DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteFeatureKVStoreRecordParams) error + DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteGlobalKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteGroupKVStoreRecordParams) error + GetFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.GetFeatureKVStoreRecordParams) ([]byte, error) + GetGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.GetGroupKVStoreRecordParams) ([]byte, error) + UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateFeatureKVStoreRecordParams) error + UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateGlobalKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateGroupKVStoreRecordParams) error + InsertKVStoreRecord(ctx context.Context, arg sqlcmig6.InsertKVStoreRecordParams) error + ListAllKVStoresRecords(ctx context.Context) ([]sqlcmig6.Kvstore, error) DeleteAllTempKVStores(ctx context.Context) error GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) @@ -45,7 +73,7 @@ func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { return tx.DeleteAllTempKVStores(ctx) - }) + }, sqldb.NoOpReset) } // GetKVStores constructs a new rules.KVStores in a namespace defined by the diff --git a/firewalldb/privacy_mapper_sql.go b/firewalldb/privacy_mapper_sql.go index 8a4863a6c..ff7f70f94 100644 --- a/firewalldb/privacy_mapper_sql.go +++ b/firewalldb/privacy_mapper_sql.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/lightninglabs/lightning-terminal/db/sqlc" + 
"github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" ) @@ -22,6 +23,19 @@ type SQLPrivacyPairQueries interface { GetRealForPseudo(ctx context.Context, arg sqlc.GetRealForPseudoParams) (string, error) } +// SQLMig6PrivacyPairQueries is a subset of the sqlcmig6.Queries interface that +// can be used to interact with the privacy map table. +// +//nolint:lll +type SQLMig6PrivacyPairQueries interface { + SQLSessionQueries + + InsertPrivacyPair(ctx context.Context, arg sqlcmig6.InsertPrivacyPairParams) error + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]sqlcmig6.GetAllPrivacyPairsRow, error) + GetPseudoForReal(ctx context.Context, arg sqlcmig6.GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg sqlcmig6.GetRealForPseudoParams) (string, error) +} + // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. // diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go index 1e114c12c..f898993e3 100644 --- a/firewalldb/sql_migration.go +++ b/firewalldb/sql_migration.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb" "go.etcd.io/bbolt" @@ -78,7 +78,7 @@ func (e *kvEntry) namespacedKey() string { // NOTE: As sessions may contain linked sessions and accounts, the sessions and // accounts sql migration MUST be run prior to this migration. func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, - sqlTx SQLQueries) error { + sqlTx SQLMig6Queries) error { log.Infof("Starting migration of the rules DB to SQL") @@ -99,7 +99,7 @@ func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, // database to the SQL database. The function also asserts that the // migrated values match the original values in the KV store. func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, - sqlTx SQLQueries) error { + sqlTx SQLMig6Queries) error { log.Infof("Starting migration of the KV stores to SQL") @@ -361,7 +361,7 @@ func collectKVPairs(bkt *bbolt.Bucket, errorOnBuckets, perm bool, } // insertPair inserts a single key-value pair into the SQL database. -func insertPair(ctx context.Context, tx SQLQueries, +func insertPair(ctx context.Context, tx SQLMig6Queries, entry *kvEntry) (*sqlKvEntry, error) { ruleID, err := tx.GetOrInsertRuleID(ctx, entry.ruleName) @@ -369,7 +369,7 @@ func insertPair(ctx context.Context, tx SQLQueries, return nil, err } - p := sqlc.InsertKVStoreRecordParams{ + p := sqlcmig6.InsertKVStoreRecordParams{ Perm: entry.perm, RuleID: ruleID, EntryKey: entry.key, @@ -421,13 +421,13 @@ func insertPair(ctx context.Context, tx SQLQueries, // getSQLValue retrieves the key value for the given kvEntry from the SQL // database. 
-func getSQLValue(ctx context.Context, tx SQLQueries, +func getSQLValue(ctx context.Context, tx SQLMig6Queries, entry *sqlKvEntry) ([]byte, error) { switch { case entry.featureID.Valid && entry.groupID.Valid: return tx.GetFeatureKVStoreRecord( - ctx, sqlc.GetFeatureKVStoreRecordParams{ + ctx, sqlcmig6.GetFeatureKVStoreRecordParams{ Perm: entry.perm, RuleID: entry.ruleID, GroupID: entry.groupID, @@ -437,7 +437,7 @@ func getSQLValue(ctx context.Context, tx SQLQueries, ) case entry.groupID.Valid: return tx.GetGroupKVStoreRecord( - ctx, sqlc.GetGroupKVStoreRecordParams{ + ctx, sqlcmig6.GetGroupKVStoreRecordParams{ Perm: entry.perm, RuleID: entry.ruleID, GroupID: entry.groupID, @@ -446,7 +446,7 @@ func getSQLValue(ctx context.Context, tx SQLQueries, ) case !entry.featureID.Valid && !entry.groupID.Valid: return tx.GetGlobalKVStoreRecord( - ctx, sqlc.GetGlobalKVStoreRecordParams{ + ctx, sqlcmig6.GetGlobalKVStoreRecordParams{ Perm: entry.perm, RuleID: entry.ruleID, Key: entry.key, diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index c2442a819..139b45189 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -9,12 +9,11 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" ) @@ -53,29 +52,25 @@ func TestFirewallDBMigration(t *testing.T) { t.Skipf("Skipping Firewall DB migration test for kvdb build") } - makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, - *db.TransactionExecutor[SQLQueries]) { + makeSQLDB := func(t *testing.T, + sStore session.Store) *SQLMig6QueriesExecutor[SQLMig6Queries] { - testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) + testDBStore := NewTestDBWithSessions(t, sStore, clock) store, ok := testDBStore.(*SQLDB) require.True(t, ok) baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return NewSQLMig6QueriesExecutor(baseDB, queries) } // The assertMigrationResults function will currently assert that // the migrated kv stores entries in the SQLDB match the original kv // stores entries in the BoltDB. 
- assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + assertMigrationResults := func(t *testing.T, store SQLMig6Queries, kvEntries []*kvEntry) { var ( @@ -88,7 +83,7 @@ func TestFirewallDBMigration(t *testing.T) { getRuleID := func(ruleName string) int64 { ruleID, ok := ruleIDs[ruleName] if !ok { - ruleID, err = sqlStore.GetRuleID(ctx, ruleName) + ruleID, err = store.GetRuleID(ctx, ruleName) require.NoError(t, err) ruleIDs[ruleName] = ruleID @@ -100,7 +95,7 @@ func TestFirewallDBMigration(t *testing.T) { getGroupID := func(groupAlias []byte) int64 { groupID, ok := groupIDs[string(groupAlias)] if !ok { - groupID, err = sqlStore.GetSessionIDByAlias( + groupID, err = store.GetSessionIDByAlias( ctx, groupAlias, ) require.NoError(t, err) @@ -114,7 +109,7 @@ func TestFirewallDBMigration(t *testing.T) { getFeatureID := func(featureName string) int64 { featureID, ok := featureIDs[featureName] if !ok { - featureID, err = sqlStore.GetFeatureID( + featureID, err = store.GetFeatureID( ctx, featureName, ) require.NoError(t, err) @@ -128,7 +123,7 @@ func TestFirewallDBMigration(t *testing.T) { // First we extract all migrated kv entries from the SQLDB, // in order to be able to compare them to the original kv // entries, to ensure that the migration was successful. - sqlKvEntries, err := sqlStore.ListAllKVStoresRecords(ctx) + sqlKvEntries, err := store.ListAllKVStoresRecords(ctx) require.NoError(t, err) require.Equal(t, len(kvEntries), len(sqlKvEntries)) @@ -144,9 +139,9 @@ func TestFirewallDBMigration(t *testing.T) { ruleID := getRuleID(entry.ruleName) if entry.groupAlias.IsNone() { - sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + sqlVal, err := store.GetGlobalKVStoreRecord( ctx, - sqlc.GetGlobalKVStoreRecordParams{ + sqlcmig6.GetGlobalKVStoreRecordParams{ Key: entry.key, Perm: entry.perm, RuleID: ruleID, @@ -162,9 +157,9 @@ func TestFirewallDBMigration(t *testing.T) { groupAlias := entry.groupAlias.UnwrapOrFail(t) groupID := getGroupID(groupAlias[:]) - v, err := sqlStore.GetGroupKVStoreRecord( + v, err := store.GetGroupKVStoreRecord( ctx, - sqlc.GetGroupKVStoreRecordParams{ + sqlcmig6.GetGroupKVStoreRecordParams{ Key: entry.key, Perm: entry.perm, RuleID: ruleID, @@ -187,9 +182,9 @@ func TestFirewallDBMigration(t *testing.T) { entry.featureName.UnwrapOrFail(t), ) - sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + sqlVal, err := store.GetFeatureKVStoreRecord( ctx, - sqlc.GetFeatureKVStoreRecordParams{ + sqlcmig6.GetFeatureKVStoreRecordParams{ Key: entry.key, Perm: entry.perm, RuleID: ruleID, @@ -293,20 +288,29 @@ func TestFirewallDBMigration(t *testing.T) { // Create the SQL store that we will migrate the data // to. - sqlStore, txEx := makeSQLDB(t, sessionsStore) + txEx := makeSQLDB(t, sessionsStore) // Perform the migration. - err = txEx.ExecTx(ctx, sqldb.WriteTxOpt(), - func(tx SQLQueries) error { - return MigrateFirewallDBToSQL( + // + // TODO(viktor): remove sqldb.MigrationTxOptions once + // sqldb v2 is based on the latest version of lnd/sqldb. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLMig6Queries) error { + err = MigrateFirewallDBToSQL( ctx, firewallStore.DB, tx, ) - }, + if err != nil { + return err + } + + // Assert migration results. + assertMigrationResults(t, tx, entries) + + return nil + }, sqldb.NoOpReset, ) require.NoError(t, err) - - // Assert migration results. 
- assertMigrationResults(t, sqlStore, entries) }) } } diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index f17010f2c..dcf201f1c 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -5,7 +5,10 @@ import ( "database/sql" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLSessionQueries is a subset of the sqlc.Queries interface that can be used @@ -18,17 +21,30 @@ type SQLSessionQueries interface { // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { + sqldb.BaseQuerier + SQLKVStoreQueries SQLPrivacyPairQueries SQLActionQueries } -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be +// used to interact with various firewalldb tables. +type SQLMig6Queries interface { + sqldb.BaseQuerier + + SQLMig6KVStoreQueries + SQLMig6PrivacyPairQueries + SQLMig6ActionQueries +} + +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in a single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLDB represents a storage backend. @@ -38,11 +54,51 @@ type SQLDB struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) + }, + ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) @@ -53,12 +109,10 @@ var _ ActionDB = (*SQLDB)(nil) // NewSQLDB creates a new SQLDB instance given an open SQLQueries // storage backend.
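Note the recurring mechanical change in this file: every ExecTx call gains a trailing sqldb.NoOpReset argument, which (per lnd's sqldb, an assumption here) is a reset callback invoked if the transaction is retried; these closures hold no state that would need rewinding. The resulting call shape, as a sketch, with the closure body taken from the DeleteTempKVStores path earlier in this diff:

```go
// Sketch of the updated transaction call shape used throughout this
// change.
var writeTxOpts db.QueriesTxOptions
err := s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error {
	// Everything in this closure runs in a single SQL transaction.
	return tx.DeleteAllTempKVStores(ctx)
}, sqldb.NoOpReset)
```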
-func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) - }, - ) +func NewSQLDB(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLDB { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLDB{ db: executor, @@ -88,7 +142,7 @@ func (e *sqlExecutor[T]) Update(ctx context.Context, var txOpts db.QueriesTxOptions return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } // View opens a database read transaction and executes the function f with the @@ -104,5 +158,5 @@ func (e *sqlExecutor[T]) View(ctx context.Context, return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index a412441f8..b7e3d9052 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -6,10 +6,12 @@ import ( "testing" "time" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -55,8 +57,10 @@ func assertEqualActions(t *testing.T, expected, got *Action) { // createStore is a helper function that creates a new SQLDB and ensures that // it is closed during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - store := NewSQLDB(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, clock clock.Clock) *SQLDB { + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLDB(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 49b956d7d..3f91546af 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -7,17 +7,26 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // NewTestDB is a helper function that creates a BBolt database for testing. func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing.
-func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, db.MakeTestMigrationStreams(), ) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/go.mod b/go.mod index 65f98cb62..284f09e85 100644 --- a/go.mod +++ b/go.mod @@ -151,6 +151,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb // indirect github.com/lightningnetwork/lnd/healthcheck v1.2.6 // indirect github.com/lightningnetwork/lnd/queue v1.1.1 // indirect + github.com/lightningnetwork/lnd/sqldb/v2 v2.0.0-00010101000000-000000000000 // indirect github.com/lightningnetwork/lnd/ticker v1.1.1 // indirect github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -251,3 +252,8 @@ replace github.com/golang-migrate/migrate/v4 => github.com/lightninglabs/migrate // tapd wants v0.19.0-12, but loop can't handle that yet. So we'll just use the // previous version for now. replace github.com/lightninglabs/lndclient => github.com/lightninglabs/lndclient v0.19.0-11 + +replace github.com/lightningnetwork/lnd => github.com/ViktorTigerstrom/lnd v0.0.0-20250721232542-5e14695a410c + +// TODO: replace this with your own local fork +replace github.com/lightningnetwork/lnd/sqldb/v2 => ../../lnd_forked/lnd/sqldb diff --git a/go.sum b/go.sum index 4e5da46b6..8204eda1a 100644 --- a/go.sum +++ b/go.sum @@ -616,6 +616,8 @@ github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82/go.mod h1:Gbu github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/ViktorTigerstrom/lnd v0.0.0-20250721232542-5e14695a410c h1:2/XQEHo8QiuwA/TcOhYWuyet5AkxSYwITkC8KdHFs+8= +github.com/ViktorTigerstrom/lnd v0.0.0-20250721232542-5e14695a410c/go.mod h1:54IwnYLMLlBwwzSMvNugIV81WAs4UEFxWvdFzfWwm9w= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 h1:cDVUiFo+npB0ZASqnw4q90ylaVAbnYyx0JYqK4YcGok= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344/go.mod h1:9pIqrY6SXNL8vjRQE5Hd/OL5GyK/9MrGUWs87z/eFfk= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= @@ -1178,8 +1180,6 @@ github.com/lightninglabs/taproot-assets/taprpc v1.0.8-0.20250716163904-2ef55ba74 github.com/lightninglabs/taproot-assets/taprpc v1.0.8-0.20250716163904-2ef55ba74036/go.mod h1:vOM2Ap2wYhEZjiJU7bNNg+e5tDxkvRAuyXwf/KQ4tgo= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb h1:yfM05S8DXKhuCBp5qSMZdtSwvJ+GFzl94KbXMNB1JDY= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb/go.mod h1:c0kvRShutpj3l6B9WtTsNTBUtjSmjZXbJd9ZBRQOSKI= -github.com/lightningnetwork/lnd v0.19.2-beta h1:3SKVrKYFY4IJLlrMf7cDzZcBeT+MxjI9Xy6YpY+EEX4= -github.com/lightningnetwork/lnd v0.19.2-beta/go.mod h1:+yKUfIGKKYRHGewgzQ6xi0S26DIfBiMv1zCdB3m6YxA= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 
h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= diff --git a/log.go b/log.go index b803f72d2..0535a819a 100644 --- a/log.go +++ b/log.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/autopilotserver" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightninglabs/lightning-terminal/firewall" "github.com/lightninglabs/lightning-terminal/firewalldb" mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware" @@ -91,6 +92,10 @@ func SetupLoggers(root *build.SubLoggerManager, intercept signal.Interceptor) { root, subservers.Subsystem, intercept, subservers.UseLogger, ) lnd.AddSubLogger(root, db.Subsystem, intercept, db.UseLogger) + lnd.AddSubLogger( + root, migrationstreams.Subsystem, intercept, + migrationstreams.UseLogger, + ) // Add daemon loggers to lnd's root logger. faraday.SetupLoggers(root, intercept) diff --git a/scripts/gen_sqlc_docker.sh b/scripts/gen_sqlc_docker.sh index 16db97f2c..3d93f37ff 100755 --- a/scripts/gen_sqlc_docker.sh +++ b/scripts/gen_sqlc_docker.sh @@ -5,7 +5,7 @@ set -e # restore_files is a function to restore original schema files. restore_files() { echo "Restoring SQLite bigint patch..." - for file in db/sqlc/migrations/*.up.sql.bak; do + for file in db/sqlc/{migrations,migrations_dev}/*.up.sql.bak; do mv "$file" "${file%.bak}" done } @@ -30,7 +30,7 @@ GOMODCACHE=$(go env GOMODCACHE) # source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER # PRIMARY KEY". echo "Applying SQLite bigint patch..." -for file in db/sqlc/migrations/*.up.sql; do +for file in db/sqlc/{migrations,migrations_dev}/*.up.sql; do echo "Patching $file" sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file" done diff --git a/session/sql_migration.go b/session/sql_migration.go index 428cc0fce..b1caebeb4 100644 --- a/session/sql_migration.go +++ b/session/sql_migration.go @@ -9,12 +9,17 @@ import ( "reflect" "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" + "github.com/lightninglabs/lightning-node-connect/mailbox" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb" "github.com/pmezard/go-difflib/difflib" "go.etcd.io/bbolt" + "gopkg.in/macaroon-bakery.v2/bakery" + "gopkg.in/macaroon.v2" ) var ( @@ -31,7 +36,7 @@ var ( // NOTE: As sessions may contain linked accounts, the accounts sql migration // MUST be run prior to this migration. func MigrateSessionStoreToSQL(ctx context.Context, kvStore *bbolt.DB, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of the KV sessions store to SQL") @@ -118,7 +123,7 @@ func getBBoltSessions(db *bbolt.DB) ([]*Session, error) { // from the KV database to the SQL database, and validates that the migrated // sessions match the original sessions. func migrateSessionsToSQLAndValidate(ctx context.Context, - tx SQLQueries, kvSessions []*Session) error { + tx SQLMig6Queries, kvSessions []*Session) error { for _, kvSession := range kvSessions { err := migrateSingleSessionToSQL(ctx, tx, kvSession) @@ -127,18 +132,9 @@ func migrateSessionsToSQLAndValidate(ctx context.Context, kvSession.ID, err) } - // Validate that the session was correctly migrated and matches - // the original session in the kv store. 
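The hunk below folds this lookup-and-unmarshal step into a getAndUnmarshalSession helper so the same code can serve the validation pass. The surrounding validation idiom, in outline (a sketch only; the real code also normalizes time zones and prints a difflib diff on mismatch):

```go
// After copying a session to SQL, read it back and require deep
// equality with the original KV-store session.
migrated, err := getAndUnmarshalSession(ctx, tx, kvSession.ID[:])
if err != nil {
	return err
}

if !reflect.DeepEqual(kvSession, migrated) {
	return fmt.Errorf("migrated session %x does not match original",
		kvSession.ID[:])
}
```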
- sqlSess, err := tx.GetSessionByAlias(ctx, kvSession.ID[:]) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - err = ErrSessionNotFound - } - return fmt.Errorf("unable to get migrated session "+ - "from sql store: %w", err) - } - - migratedSession, err := unmarshalSession(ctx, tx, sqlSess) + migratedSession, err := getAndUnmarshalSession( + ctx, tx, kvSession.ID[:], + ) if err != nil { return fmt.Errorf("unable to unmarshal migrated "+ "session: %w", err) @@ -172,12 +168,206 @@ func migrateSessionsToSQLAndValidate(ctx context.Context, return nil } +func getAndUnmarshalSession(ctx context.Context, + tx SQLMig6Queries, legacyID []byte) (*Session, error) { + + // Validate that the session was correctly migrated and matches + // the original session in the kv store. + sqlSess, err := tx.GetSessionByAlias(ctx, legacyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + err = ErrSessionNotFound + } + + return nil, fmt.Errorf("unable to get migrated session "+ + "from sql store: %w", err) + } + + migratedSession, err := unmarshalMig6Session(ctx, tx, sqlSess) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal migrated "+ + "session: %w", err) + } + + return migratedSession, nil + +} + +func unmarshalMig6Session(ctx context.Context, db SQLMig6Queries, + dbSess sqlcmig6.Session) (*Session, error) { + + var legacyGroupID ID + if dbSess.GroupID.Valid { + groupID, err := db.GetAliasBySessionID( + ctx, dbSess.GroupID.Int64, + ) + if err != nil { + return nil, fmt.Errorf("unable to get legacy group "+ + "Alias: %v", err) + } + + legacyGroupID, err = IDFromBytes(groupID) + if err != nil { + return nil, fmt.Errorf("unable to get legacy Alias: %v", + err) + } + } + + var acctAlias fn.Option[accounts.AccountID] + if dbSess.AccountID.Valid { + account, err := db.GetAccount(ctx, dbSess.AccountID.Int64) + if err != nil { + return nil, fmt.Errorf("unable to get account: %v", err) + } + + accountAlias, err := accounts.AccountIDFromInt64(account.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get account ID: %v", err) + } + acctAlias = fn.Some(accountAlias) + } + + legacyID, err := IDFromBytes(dbSess.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get legacy Alias: %v", err) + } + + var revokedAt time.Time + if dbSess.RevokedAt.Valid { + revokedAt = dbSess.RevokedAt.Time + } + + localPriv, localPub := btcec.PrivKeyFromBytes(dbSess.LocalPrivateKey) + + var remotePub *btcec.PublicKey + if len(dbSess.RemotePublicKey) != 0 { + remotePub, err = btcec.ParsePubKey(dbSess.RemotePublicKey) + if err != nil { + return nil, fmt.Errorf("unable to parse remote "+ + "public key: %v", err) + } + } + + // Get the macaroon permissions if they exist. + perms, err := db.GetSessionMacaroonPermissions(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get macaroon "+ + "permissions: %v", err) + } + + // Get the macaroon caveats if they exist. + caveats, err := db.GetSessionMacaroonCaveats(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get macaroon "+ + "caveats: %v", err) + } + + var macRecipe *MacaroonRecipe + if perms != nil || caveats != nil { + macRecipe = &MacaroonRecipe{ + Permissions: unmarshalMig6MacPerms(perms), + Caveats: unmarshalMig6MacCaveats(caveats), + } + } + + // Get the feature configs if they exist. 
+ featureConfigs, err := db.GetSessionFeatureConfigs(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get feature configs: %v", err) + } + + var featureCfgs *FeaturesConfig + if featureConfigs != nil { + featureCfgs = unmarshalMig6FeatureConfigs(featureConfigs) + } + + // Get the privacy flags if they exist. + privacyFlags, err := db.GetSessionPrivacyFlags(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get privacy flags: %v", err) + } + + var privFlags PrivacyFlags + if privacyFlags != nil { + privFlags = unmarshalMig6PrivacyFlags(privacyFlags) + } + + var pairingSecret [mailbox.NumPassphraseEntropyBytes]byte + copy(pairingSecret[:], dbSess.PairingSecret) + + return &Session{ + ID: legacyID, + Label: dbSess.Label, + State: State(dbSess.State), + Type: Type(dbSess.Type), + Expiry: dbSess.Expiry, + CreatedAt: dbSess.CreatedAt, + RevokedAt: revokedAt, + ServerAddr: dbSess.ServerAddress, + DevServer: dbSess.DevServer, + MacaroonRootKey: uint64(dbSess.MacaroonRootKey), + PairingSecret: pairingSecret, + LocalPrivateKey: localPriv, + LocalPublicKey: localPub, + RemotePublicKey: remotePub, + WithPrivacyMapper: dbSess.Privacy, + GroupID: legacyGroupID, + PrivacyFlags: privFlags, + MacaroonRecipe: macRecipe, + FeatureConfig: featureCfgs, + AccountID: acctAlias, + }, nil +} + +func unmarshalMig6MacPerms(dbPerms []sqlcmig6.SessionMacaroonPermission) []bakery.Op { + ops := make([]bakery.Op, len(dbPerms)) + for i, dbPerm := range dbPerms { + ops[i] = bakery.Op{ + Entity: dbPerm.Entity, + Action: dbPerm.Action, + } + } + + return ops +} + +func unmarshalMig6MacCaveats(dbCaveats []sqlcmig6.SessionMacaroonCaveat) []macaroon.Caveat { + caveats := make([]macaroon.Caveat, len(dbCaveats)) + for i, dbCaveat := range dbCaveats { + caveats[i] = macaroon.Caveat{ + Id: dbCaveat.CaveatID, + VerificationId: dbCaveat.VerificationID, + Location: dbCaveat.Location.String, + } + } + + return caveats +} + +func unmarshalMig6FeatureConfigs(dbConfigs []sqlcmig6.SessionFeatureConfig) *FeaturesConfig { + configs := make(FeaturesConfig, len(dbConfigs)) + for _, dbConfig := range dbConfigs { + configs[dbConfig.FeatureName] = dbConfig.Config + } + + return &configs +} + +func unmarshalMig6PrivacyFlags(dbFlags []sqlcmig6.SessionPrivacyFlag) PrivacyFlags { + flags := make(PrivacyFlags, len(dbFlags)) + for i, dbFlag := range dbFlags { + flags[i] = PrivacyFlag(dbFlag.Flag) + } + + return flags +} + // migrateSingleSessionToSQL runs the migration for a single session from the // KV database to the SQL database. Note that if the session links to an // account, the linked accounts store MUST have been migrated before that // session is migrated. func migrateSingleSessionToSQL(ctx context.Context, - tx SQLQueries, session *Session) error { + tx SQLMig6Queries, session *Session) error { var ( acctID sql.NullInt64 @@ -213,7 +403,7 @@ func migrateSingleSessionToSQL(ctx context.Context, } // Proceed to insert the session into the sql db. - sqlId, err := tx.InsertSession(ctx, sqlc.InsertSessionParams{ + sqlId, err := tx.InsertSession(ctx, sqlcmig6.InsertSessionParams{ Alias: session.ID[:], Label: session.Label, State: int16(session.State), @@ -239,7 +429,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // has been created. 
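The revocation and group-ID hunks below populate nullable columns through lnd's sqldb helpers. In short, these wrap plain Go values into their database/sql nullable forms with Valid set; the values here are illustrative:

```go
revokedAt := sqldb.SQLTime(time.Now().UTC()) // sql.NullTime with Valid: true
groupID := sqldb.SQLInt64(int64(42))         // sql.NullInt64 with Valid: true
```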
if !session.RevokedAt.IsZero() { err = tx.SetSessionRevokedAt( - ctx, sqlc.SetSessionRevokedAtParams{ + ctx, sqlcmig6.SetSessionRevokedAtParams{ ID: sqlId, RevokedAt: sqldb.SQLTime( session.RevokedAt.UTC(), @@ -265,7 +455,7 @@ func migrateSingleSessionToSQL(ctx context.Context, } // Now lets set the group ID for the session. - err = tx.SetSessionGroupID(ctx, sqlc.SetSessionGroupIDParams{ + err = tx.SetSessionGroupID(ctx, sqlcmig6.SetSessionGroupIDParams{ ID: sqlId, GroupID: sqldb.SQLInt64(groupID), }) @@ -279,7 +469,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // We start by inserting the macaroon permissions. for _, sessionPerm := range session.MacaroonRecipe.Permissions { err = tx.InsertSessionMacaroonPermission( - ctx, sqlc.InsertSessionMacaroonPermissionParams{ + ctx, sqlcmig6.InsertSessionMacaroonPermissionParams{ SessionID: sqlId, Entity: sessionPerm.Entity, Action: sessionPerm.Action, @@ -293,7 +483,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // Next we insert the macaroon caveats. for _, caveat := range session.MacaroonRecipe.Caveats { err = tx.InsertSessionMacaroonCaveat( - ctx, sqlc.InsertSessionMacaroonCaveatParams{ + ctx, sqlcmig6.InsertSessionMacaroonCaveatParams{ SessionID: sqlId, CaveatID: caveat.Id, VerificationID: caveat.VerificationId, @@ -312,7 +502,7 @@ func migrateSingleSessionToSQL(ctx context.Context, if session.FeatureConfig != nil { for featureName, config := range *session.FeatureConfig { err = tx.InsertSessionFeatureConfig( - ctx, sqlc.InsertSessionFeatureConfigParams{ + ctx, sqlcmig6.InsertSessionFeatureConfigParams{ SessionID: sqlId, FeatureName: featureName, Config: config, @@ -327,7 +517,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // Finally we insert the privacy flags. for _, privacyFlag := range session.PrivacyFlags { err = tx.InsertSessionPrivacyFlag( - ctx, sqlc.InsertSessionPrivacyFlagParams{ + ctx, sqlcmig6.InsertSessionPrivacyFlagParams{ SessionID: sqlId, Flag: int32(privacyFlag), }, diff --git a/session/sql_migration_test.go b/session/sql_migration_test.go index 3c3b018e7..508a2607a 100644 --- a/session/sql_migration_test.go +++ b/session/sql_migration_test.go @@ -2,17 +2,16 @@ package session import ( "context" - "database/sql" "fmt" "testing" "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "go.etcd.io/bbolt" "golang.org/x/exp/rand" @@ -38,7 +37,7 @@ func TestSessionsStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, acctStore accounts.Store) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLMig6QueriesExecutor[SQLMig6Queries]) { // Create a sql store with a linked account store. 
testDBStore := NewTestDBWithAccounts(t, clock, acctStore) @@ -48,13 +47,9 @@ func TestSessionsStoreMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLMig6QueriesExecutor(baseDB, queries) } // assertMigrationResults asserts that the sql store contains the @@ -369,13 +364,16 @@ func TestSessionsStoreMigration(t *testing.T) { // migration. sqlStore, txEx := makeSQLDB(t, accountStore) + // TODO(viktor): remove sqldb.MigrationTxOptions once + // sqldb v2 is based on the latest version of lnd/sqldb. + var opts sqldb.MigrationTxOptions err = txEx.ExecTx( - ctx, sqldb.WriteTxOpt(), - func(tx SQLQueries) error { + ctx, &opts, + func(tx SQLMig6Queries) error { return MigrateSessionStoreToSQL( ctx, kvStore.DB, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..a169dff9b 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -12,8 +12,10 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" ) @@ -21,6 +23,8 @@ import ( // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with session related tables. type SQLQueries interface { + sqldb.BaseQuerier + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error) @@ -49,14 +53,48 @@ type SQLQueries interface { GetAccount(ctx context.Context, id int64) (sqlc.Account, error) } +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used to +// interact with session related tables. 
+type SQLMig6Queries interface { + sqldb.BaseQuerier + + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetSessionByID(ctx context.Context, id int64) (sqlcmig6.Session, error) + GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlcmig6.Session, error) + GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) + GetSessionByAlias(ctx context.Context, legacyID []byte) (sqlcmig6.Session, error) + GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (sqlcmig6.Session, error) + GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionFeatureConfig, error) + GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionMacaroonCaveat, error) + GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) + GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionMacaroonPermission, error) + GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionPrivacyFlag, error) + InsertSessionFeatureConfig(ctx context.Context, arg sqlcmig6.InsertSessionFeatureConfigParams) error + SetSessionRevokedAt(ctx context.Context, arg sqlcmig6.SetSessionRevokedAtParams) error + InsertSessionMacaroonCaveat(ctx context.Context, arg sqlcmig6.InsertSessionMacaroonCaveatParams) error + InsertSessionMacaroonPermission(ctx context.Context, arg sqlcmig6.InsertSessionMacaroonPermissionParams) error + InsertSessionPrivacyFlag(ctx context.Context, arg sqlcmig6.InsertSessionPrivacyFlagParams) error + InsertSession(ctx context.Context, arg sqlcmig6.InsertSessionParams) (int64, error) + ListSessions(ctx context.Context) ([]sqlcmig6.Session, error) + ListSessionsByType(ctx context.Context, sessionType int16) ([]sqlcmig6.Session, error) + ListSessionsByState(ctx context.Context, state int16) ([]sqlcmig6.Session, error) + SetSessionRemotePublicKey(ctx context.Context, arg sqlcmig6.SetSessionRemotePublicKeyParams) error + SetSessionGroupID(ctx context.Context, arg sqlcmig6.SetSessionGroupIDParams) error + UpdateSessionState(ctx context.Context, arg sqlcmig6.UpdateSessionStateParams) error + DeleteSessionsWithState(ctx context.Context, state int16) error + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) +} + var _ Store = (*SQLStore)(nil) -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in a single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -66,19 +104,57 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend.
-func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) }, ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -281,7 +357,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type, } return nil - }) + }, sqldb.NoOpReset) if err != nil { mappedSQLErr := db.MapSQLError(err) var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation @@ -325,7 +401,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session, } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -358,7 +434,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) ( } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -417,7 +493,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error { State: int16(dest), }, ) - }) + }, sqldb.NoOpReset) } // DeleteReservedSessions deletes all sessions that are in the StateReserved @@ -428,7 +504,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { var writeTxOpts db.QueriesTxOptions return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { return db.DeleteSessionsWithState(ctx, int16(StateReserved)) - }) + }, sqldb.NoOpReset) } // GetSessionByLocalPub fetches the session with the given local pub key. 
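One property of the executor types above worth spelling out: embedding *sqldb.TransactionExecutor[T] supplies ExecTx, while the embedded queries supply direct, non-transactional access, so the executor should satisfy BatchedSQLQueries. A compile-time assertion in the style this codebase already uses would document that; a suggested sketch, assuming TransactionExecutor implements sqldb/v2's BatchedTx:

```go
// A compile-time assertion to ensure that SQLQueriesExecutor implements
// the BatchedSQLQueries interface.
var _ BatchedSQLQueries = (*SQLQueriesExecutor[SQLQueries])(nil)
```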
@@ -458,7 +534,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -491,7 +567,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -521,7 +597,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID, RemotePublicKey: remoteKey, }, ) - }) + }, sqldb.NoOpReset) } // getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local @@ -576,7 +652,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sess, err } @@ -617,7 +693,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { legacyGroupID, err = IDFromBytes(legacyGroupIDB) return err - }) + }, sqldb.NoOpReset) if err != nil { return ID{}, err } @@ -666,7 +742,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } diff --git a/session/test_sql.go b/session/test_sql.go index a83186069..5623c8207 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -22,8 +23,12 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, // createStore is a helper function that creates a new SQLStore and ensures // that it is closed during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 0ceb0e046..c9dbc5934 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,11 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing.
func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +29,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, db.MakeTestMigrationStreams(), ) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/terminal.go b/terminal.go index 7e4d552c7..5f8fa2efb 100644 --- a/terminal.go +++ b/terminal.go @@ -447,7 +447,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("could not create network directory: %v", err) } - g.stores, err = NewStores(g.cfg, clock.NewDefaultClock()) + g.stores, err = NewStores(ctx, g.cfg, clock.NewDefaultClock()) if err != nil { return fmt.Errorf("could not create stores: %v", err) }
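Taken together, the construction path introduced by this diff, as the updated test helpers exercise it from inside the session package (MakeTestMigrationStreams and NewForType are defined elsewhere in this change):

```go
// Open a sqldb/v2-managed SQLite test DB with the project's migration
// streams, build backend-typed sqlc queries, and wrap them in a store.
baseDB := sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB
queries := sqlc.NewForType(baseDB, baseDB.BackendType)
store := NewSQLStore(baseDB, queries, clock.NewDefaultClock())
```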