diff --git a/crdb/common.go b/crdb/common.go index 991f8c6..ec19b0d 100644 --- a/crdb/common.go +++ b/crdb/common.go @@ -14,7 +14,10 @@ package crdb -import "context" +import ( + "context" + "time" +) // Tx abstracts the operations needed by ExecuteInTx so that different // frameworks (e.g. go's sql package, pgx, gorm) can be used with ExecuteInTx. @@ -60,8 +63,10 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return err } - maxRetries := numRetriesFromContext(ctx) - retryCount := 0 + // establish the retry policy + retryPolicy := getRetryPolicy(ctx) + // set up the retry policy state + retryFunc := retryPolicy.NewRetry() for { releaseFailed := false err = fn() @@ -82,13 +87,35 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return err } - if rollbackErr := tx.Exec(ctx, "ROLLBACK TO SAVEPOINT cockroach_restart"); rollbackErr != nil { - return newTxnRestartError(rollbackErr, err) + // We have a retryable error. Check the retry policy. 
+ delay, retryErr := retryFunc(err) + if delay > 0 && retryErr == nil { + // We don't want to hold locks while waiting for a backoff, so restart the entire transaction + if restartErr := tx.Exec(ctx, "ROLLBACK"); restartErr != nil { + return newTxnRestartError(restartErr, err) + } + if restartErr := tx.Exec(ctx, "BEGIN"); restartErr != nil { + return newTxnRestartError(restartErr, err) + } + if restartErr := tx.Exec(ctx, "SAVEPOINT cockroach_restart"); restartErr != nil { + return newTxnRestartError(restartErr, err) + } + } else { + if rollbackErr := tx.Exec(ctx, "ROLLBACK TO SAVEPOINT cockroach_restart"); rollbackErr != nil { + return newTxnRestartError(rollbackErr, err) + } } - retryCount++ - if maxRetries > 0 && retryCount > maxRetries { - return newMaxRetriesExceededError(err, maxRetries) + if retryErr != nil { + return retryErr + } + + if delay > 0 { + select { + case <-time.After(delay): + case <-ctx.Done(): + return ctx.Err() + } } } } diff --git a/crdb/retry.go b/crdb/retry.go new file mode 100644 index 0000000..9678ab6 --- /dev/null +++ b/crdb/retry.go @@ -0,0 +1,110 @@ +// Copyright 2025 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package crdb + +import ( + "time" +) + +// RetryFunc owns the state for a transaction retry operation. Usually, this is +// just the retry count. RetryFunc is not assumed to be safe for concurrent use. 
+type RetryFunc func(err error) (time.Duration, error) + +// RetryPolicy constructs a new instance of a RetryFunc for each transaction +// it is used with. Instances of RetryPolicy can likely be immutable and +// should be safe for concurrent calls to NewRetry. +type RetryPolicy interface { + NewRetry() RetryFunc +} + +type LimitBackoffRetryPolicy struct { + RetryLimit int + Delay time.Duration +} + +func (l *LimitBackoffRetryPolicy) NewRetry() RetryFunc { + tryCount := 0 + return func(err error) (time.Duration, error) { + tryCount++ + if tryCount > l.RetryLimit { + return 0, newMaxRetriesExceededError(err, l.RetryLimit) + } + return l.Delay, nil + } +} + +// ExpBackoffRetryPolicy implements RetryPolicy using an exponential backoff with optional +// saturation. +type ExpBackoffRetryPolicy struct { + RetryLimit int + BaseDelay time.Duration + MaxDelay time.Duration +} + +// NewRetry implements RetryPolicy +func (l *ExpBackoffRetryPolicy) NewRetry() RetryFunc { + tryCount := 0 + return func(err error) (time.Duration, error) { + tryCount++ + if tryCount > l.RetryLimit { + return 0, newMaxRetriesExceededError(err, l.RetryLimit) + } + delay := l.BaseDelay << (tryCount - 1) + if l.MaxDelay > 0 && delay > l.MaxDelay { + return l.MaxDelay, nil + } + if delay < l.BaseDelay { + // We've overflowed. + if l.MaxDelay > 0 { + return l.MaxDelay, nil + } + // There's no max delay. Giving up is probably better in + // practice than using a 290-year MAX_INT delay. + return 0, newMaxRetriesExceededError(err, tryCount) + } + return delay, nil + } +} + +// Vargo converts a go-retry style Delay provider into a RetryPolicy +func Vargo(fn func() VargoBackoff) RetryPolicy { + return &vargoAdapter{ + DelegateFactory: fn, + } +} + +// VargoBackoff allow us to adapt sethvargo/go-retry Backoff policies +// without also creating a transitive dependency on that library. 
+type VargoBackoff interface { + Next() (next time.Duration, stop bool) +} + +// vargoAdapter adapts backoff policies in the style of sethvargo/go-retry +type vargoAdapter struct { + DelegateFactory func() VargoBackoff +} + +func (b *vargoAdapter) NewRetry() RetryFunc { + delegate := b.DelegateFactory() + count := 0 + return func(err error) (time.Duration, error) { + count++ + d, stop := delegate.Next() + if stop { + return 0, newMaxRetriesExceededError(err, count) + } + return d, nil + } +} diff --git a/crdb/retry_test.go b/crdb/retry_test.go new file mode 100644 index 0000000..32e940a --- /dev/null +++ b/crdb/retry_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. 
+ +package crdb + +import ( + "testing" + "time" +) + +func assertDelays(t *testing.T, policy RetryPolicy, expectedDelays []time.Duration) { + actualDelays := make([]time.Duration, 0, len(expectedDelays)) + rf := policy.NewRetry() + for { + delay, err := rf(nil) + if err != nil { + break + } + + actualDelays = append(actualDelays, delay) + if len(actualDelays) > len(expectedDelays) { + t.Fatalf("too many retries: expected %d", len(expectedDelays)) + } + } + if len(actualDelays) != len(expectedDelays) { + t.Errorf("wrong number of retries: expected %d, got %d", len(expectedDelays), len(actualDelays)) + } + for i, delay := range actualDelays { + expected := expectedDelays[i] + if delay != expected { + t.Errorf("wrong delay at index %d: expected %d, got %d", i, expected, delay) + } + } +} + +func TestLimitBackoffRetryPolicy(t *testing.T) { + policy := &LimitBackoffRetryPolicy{ + RetryLimit: 3, + Delay: 1 * time.Second, + } + assertDelays(t, policy, []time.Duration{ + 1 * time.Second, + 1 * time.Second, + 1 * time.Second, + }) +} + +func TestExpBackoffRetryPolicy(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 5, + BaseDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + } + assertDelays(t, policy, []time.Duration{ + 1 * time.Second, + 2 * time.Second, + 4 * time.Second, + 5 * time.Second, + 5 * time.Second, + }) +} diff --git a/crdb/tx.go b/crdb/tx.go index 6e5f2d6..3bc275f 100644 --- a/crdb/tx.go +++ b/crdb/tx.go @@ -20,6 +20,7 @@ import ( "context" "database/sql" "errors" + "time" ) // Execute runs fn and retries it as needed. It is used to add retry handling to @@ -48,20 +49,20 @@ import ( // following snippet, the original retryable error will be masked by the call to // fmt.Errorf, and the transaction will not be automatically retried. 
// -// crdb.Execute(func () error { -// rows, err := db.QueryContext(ctx, "SELECT ...") -// if err != nil { -// return fmt.Errorf("scanning row: %s", err) -// } -// defer rows.Close() -// for rows.Next() { -// // ... -// } -// if err := rows.Err(); err != nil { -// return fmt.Errorf("scanning row: %s", err) -// } -// return nil -// }) +// crdb.Execute(func () error { +// rows, err := db.QueryContext(ctx, "SELECT ...") +// if err != nil { +// return fmt.Errorf("scanning row: %s", err) +// } +// defer rows.Close() +// for rows.Next() { +// // ... +// } +// if err := rows.Err(); err != nil { +// return fmt.Errorf("scanning row: %s", err) +// } +// return nil +// }) // // Instead, add context by returning an error that implements either: // - a `Cause() error` method, in the manner of github.com/pkg/errors, or @@ -74,23 +75,22 @@ import ( // 1.13's special `%w` formatter with fmt.Errorf(), for example // fmt.Errorf("scanning row: %w", err). // -// import "github.com/pkg/errors" -// -// crdb.Execute(func () error { -// rows, err := db.QueryContext(ctx, "SELECT ...") -// if err != nil { -// return errors.Wrap(err, "scanning row") -// } -// defer rows.Close() -// for rows.Next() { -// // ... -// } -// if err := rows.Err(); err != nil { -// return errors.Wrap(err, "scanning row") -// } -// return nil -// }) +// import "github.com/pkg/errors" // +// crdb.Execute(func () error { +// rows, err := db.QueryContext(ctx, "SELECT ...") +// if err != nil { +// return errors.Wrap(err, "scanning row") +// } +// defer rows.Close() +// for rows.Next() { +// // ... +// } +// if err := rows.Err(); err != nil { +// return errors.Wrap(err, "scanning row") +// } +// return nil +// }) func Execute(fn func() error) (err error) { for { err = fn() @@ -105,7 +105,7 @@ func Execute(fn func() error) (err error) { // operations with configurable parameters. 
type ExecuteCtxFunc func(context.Context, ...interface{}) error -// ExecuteCtx runs fn and retries it as needed, respecting a maximum retry count +// ExecuteCtx runs fn and retries it as needed, respecting a retry policy // obtained from the context. It is used to add configurable retry handling to // the execution of a single statement. If a multi-statement transaction is // being run, use ExecuteTx instead. @@ -116,6 +116,8 @@ type ExecuteCtxFunc func(context.Context, ...interface{}) error // returns a max retries exceeded error wrapping the last retryable error // encountered. // +// Arbitrary retry policies can be configured using WithRetryPolicy(ctx, p). +// // The fn parameter accepts variadic arguments which are passed through on each // retry attempt, allowing for flexible parameterization of the retried operation. // @@ -143,8 +145,11 @@ type ExecuteCtxFunc func(context.Context, ...interface{}) error // return nil // }, userID) func ExecuteCtx(ctx context.Context, fn ExecuteCtxFunc, args ...interface{}) (err error) { - maxRetries := numRetriesFromContext(ctx) - for n := 0; n <= maxRetries; n++ { + // establish the retry policy + retryPolicy := getRetryPolicy(ctx) + // set up the retry policy state + retryFunc := retryPolicy.NewRetry() + for { if err = ctx.Err(); err != nil { return err } @@ -153,29 +158,49 @@ func ExecuteCtx(ctx context.Context, fn ExecuteCtxFunc, args ...interface{}) (er if err == nil || !errIsRetryable(err) { return err } + delay, retryErr := retryFunc(err) + if retryErr != nil { + return retryErr + } + if delay > 0 { + select { + case <-time.After(delay): + case <-ctx.Done(): + return ctx.Err() + } + } } - - return newMaxRetriesExceededError(err, maxRetries) } type txConfigKey struct{} // WithMaxRetries configures context so that ExecuteTx retries tx specified // number of times when encountering retryable errors. -// Setting retries to 0 will retry indefinitely. 
+// Setting retries to 0 will not retry: the transaction will be tried only once. func WithMaxRetries(ctx context.Context, retries int) context.Context { - return context.WithValue(ctx, txConfigKey{}, retries) + p := &LimitBackoffRetryPolicy{retries, 0} + return WithRetryPolicy(ctx, p) } -const defaultRetries = 50 +// WithRetryPolicy uses an arbitrary retry policy to perform retries. +func WithRetryPolicy(ctx context.Context, policy RetryPolicy) context.Context { + return context.WithValue(ctx, txConfigKey{}, policy) +} -func numRetriesFromContext(ctx context.Context) int { +// getRetryPolicy retrieves the RetryPolicy from the context or the default +func getRetryPolicy(ctx context.Context) RetryPolicy { + retryPolicy := defaultRetryPolicy if v := ctx.Value(txConfigKey{}); v != nil { - if retries, ok := v.(int); ok && retries >= 0 { - return retries - } + retryPolicy = v.(RetryPolicy) } - return defaultRetries + + return retryPolicy +} + +const defaultRetries = 50 + +var defaultRetryPolicy RetryPolicy = &LimitBackoffRetryPolicy{ + RetryLimit: defaultRetries, } // ExecuteTx runs fn inside a transaction and retries it as needed. On @@ -201,12 +226,12 @@ func numRetriesFromContext(ctx context.Context) int { // following snippet, the original retryable error will be masked by the call to // fmt.Errorf, and the transaction will not be automatically retried. 
// -// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { -// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { -// return fmt.Errorf("updating record: %s", err) -// } -// return nil -// }) +// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { +// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { +// return fmt.Errorf("updating record: %s", err) +// } +// return nil +// }) // // Instead, add context by returning an error that implements either: // - a `Cause() error` method, in the manner of github.com/pkg/errors, or @@ -219,15 +244,14 @@ func numRetriesFromContext(ctx context.Context) int { // 1.13's special `%w` formatter with fmt.Errorf(), for example // fmt.Errorf("scanning row: %w", err). // -// import "github.com/pkg/errors" -// -// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { -// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { -// return errors.Wrap(err, "updating record") -// } -// return nil -// }) +// import "github.com/pkg/errors" // +// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { +// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { +// return errors.Wrap(err, "updating record") +// } +// return nil +// }) func ExecuteTx(ctx context.Context, db *sql.DB, opts *sql.TxOptions, fn func(*sql.Tx) error) error { // Start a transaction. tx, err := db.BeginTx(ctx, opts) @@ -254,7 +278,7 @@ func (tx stdlibTxnAdapter) Commit(context.Context) error { return tx.tx.Commit() } -// Commit is part of the tx interface. +// Rollback is part of the tx interface. func (tx stdlibTxnAdapter) Rollback(context.Context) error { return tx.tx.Rollback() } diff --git a/crdb/tx_test.go b/crdb/tx_test.go index 5713be1..1a25e91 100644 --- a/crdb/tx_test.go +++ b/crdb/tx_test.go @@ -101,13 +101,48 @@ func TestExecuteTx(t *testing.T) { // TestConfigureRetries verifies that the number of retries can be specified // via context. 
func TestConfigureRetries(t *testing.T) { - ctx := context.Background() - if numRetriesFromContext(ctx) != defaultRetries { - t.Fatal("expect default number of retries") - } + ctx := WithMaxRetries(context.Background(), 0) + requireRetries(t, ctx, 0) + + ctx = WithMaxRetries(context.Background(), 1) + requireRetries(t, ctx, 1) + + ctx = context.Background() + requireRetries(t, ctx, defaultRetries) + ctx = WithMaxRetries(context.Background(), 123+defaultRetries) - if numRetriesFromContext(ctx) != defaultRetries+123 { - t.Fatal("expected default+123 retires") + requireRetries(t, ctx, 123+defaultRetries) + + ctx = WithRetryPolicy(context.Background(), &ExpBackoffRetryPolicy{ + RetryLimit: 10, + BaseDelay: 10, + MaxDelay: 1000, + }) + requireRetries(t, ctx, 10) +} + +func requireRetries(t *testing.T, ctx context.Context, numRetries int) { + p := getRetryPolicy(ctx) + if p == nil { + t.Fatal("expected non-nil retry policy") + } + + rf := p.NewRetry() + tryCount := 0 + for { + // we try + tryCount++ + + // Then, decide whether we're out of retries. + // The first try is not a retry, so we should + _, err := rf(nil) + if err != nil { + retryCount := tryCount - 1 + if retryCount != numRetries { + t.Fatalf("expected %d retries, got %d", numRetries, retryCount) + } + return + } } } diff --git a/testserver/version/version.go b/testserver/version/version.go index bc84d7f..11231d4 100644 --- a/testserver/version/version.go +++ b/testserver/version/version.go @@ -58,7 +58,8 @@ func (v *Version) Metadata() string { } // String returns the string representation, in the format: -// "v1.2.3-beta+md" +// +// "v1.2.3-beta+md" func (v Version) String() string { var b bytes.Buffer fmt.Fprintf(&b, "v%d.%d.%d", v.major, v.minor, v.patch) @@ -84,7 +85,9 @@ var numericRE = regexp.MustCompile(`^(0|[1-9][0-9]*)$`) // Parse creates a version from a string. 
The string must be a valid semantic // version (as per https://semver.org/spec/v2.0.0.html) in the format: -// "vMINOR.MAJOR.PATCH[-PRERELEASE][+METADATA]". +// +// "vMAJOR.MINOR.PATCH[-PRERELEASE][+METADATA]". +// // MINOR, MAJOR, and PATCH are numeric values (without any leading 0s). // PRERELEASE and METADATA can contain ASCII characters and digits, hyphens and // dots.