Skip to content

Commit 2c55b7d

Browse files
committed
Add RetryWithContext() and respect cancellation while sleeping
This is a breaking change for developers using custom strategies. However, there shouldn't be any impact on code using the strategies included in this package. Because the time.Sleep() call is now abstracted, strategies are tested without actually sleeping, and the strategies don't need to be aware of contexts. Context is passed through to the action in case the action is defined separately from the retry.RetryWithContext() call, is reused at multiple points, etc.
1 parent 272ad12 commit 2c55b7d

File tree

5 files changed

+173
-94
lines changed

5 files changed

+173
-94
lines changed

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,13 @@ logFile.Chdir() // Do something with the file
5555
### HTTP request with strategies and backoff
5656

5757
```go
58-
var response *http.Response
58+
action := func(ctx context.Context, attempt uint) error {
59+
var response *http.Response
5960

60-
action := func(attempt uint) error {
61-
var err error
62-
63-
response, err = http.Get("https://api.github.com/repos/Rican7/retry")
61+
req, err := NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/Rican7/retry", nil)
62+
if err == nil {
63+
response, err = c.Do(req)
64+
}
6465

6566
if nil == err && nil != response && response.StatusCode > 200 {
6667
err = fmt.Errorf("failed to fetch (attempt #%d) with status code: %d", attempt, response.StatusCode)
@@ -69,7 +70,8 @@ action := func(attempt uint) error {
6970
return err
7071
}
7172

72-
err := retry.Retry(
73+
err := retry.RetryWithContext(
74+
context.TODO(),
7375
action,
7476
strategy.Limit(5),
7577
strategy.Backoff(backoff.Fibonacci(10*time.Millisecond)),

retry.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,66 @@
44
// Copyright © 2016 Trevor N. Suarez (Rican7)
55
package retry
66

7-
import "github.com/Rican7/retry/strategy"
7+
import (
8+
"context"
9+
"time"
10+
11+
"github.com/Rican7/retry/strategy"
12+
)
813

914
// Action defines a callable function that package retry can handle.
1015
type Action func(attempt uint) error
1116

17+
// ActionWithContext defines a callable function that package retry can handle.
18+
type ActionWithContext func(ctx context.Context, attempt uint) error
19+
1220
// Retry takes an action and performs it, repetitively, until successful.
1321
//
1422
// Optionally, strategies may be passed that assess whether or not an attempt
1523
// should be made.
1624
func Retry(action Action, strategies ...strategy.Strategy) error {
25+
return RetryWithContext(context.Background(), func(ctx context.Context, attempt uint) error { return action(attempt) }, strategies...)
26+
}
27+
28+
// RetryWithContext takes an action and performs it, repetitively, until successful
29+
// or the context is done.
30+
//
31+
// Optionally, strategies may be passed that assess whether or not an attempt
32+
// should be made.
33+
func RetryWithContext(ctx context.Context, action ActionWithContext, strategies ...strategy.Strategy) error {
1734
var err error
1835

19-
for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, strategies...); attempt++ {
20-
err = action(attempt)
36+
if ctx.Err() != nil {
37+
return ctx.Err()
38+
}
39+
40+
for attempt := uint(0); (0 == attempt || nil != err && nil == ctx.Err()) && shouldAttempt(attempt, sleepFunc(ctx), strategies...); attempt++ {
41+
err = action(ctx, attempt)
2142
}
2243

2344
return err
2445
}
2546

2647
// shouldAttempt evaluates the provided strategies with the given attempt to
2748
// determine if the Retry loop should make another attempt.
28-
func shouldAttempt(attempt uint, strategies ...strategy.Strategy) bool {
49+
func shouldAttempt(attempt uint, sleep func(time.Duration), strategies ...strategy.Strategy) bool {
2950
shouldAttempt := true
3051

3152
for i := 0; shouldAttempt && i < len(strategies); i++ {
32-
shouldAttempt = shouldAttempt && strategies[i](attempt)
53+
shouldAttempt = shouldAttempt && strategies[i](attempt, sleep)
3354
}
3455

3556
return shouldAttempt
3657
}
58+
59+
// sleepFunc returns a function with the same signature as time.Sleep()
60+
// that blocks for the given duration, but will return sooner if the context is
61+
// cancelled or its deadline passes.
62+
func sleepFunc(ctx context.Context) func(time.Duration) {
63+
return func(d time.Duration) {
64+
select {
65+
case <-ctx.Done():
66+
case <-time.After(d):
67+
}
68+
}
69+
}

retry_test.go

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
package retry
22

33
import (
4+
"context"
45
"errors"
56
"testing"
7+
"time"
68
)
79

10+
// timeMarginOfError represents the acceptable amount of time that may pass for
11+
// a time-based (sleep) unit before considering invalid.
12+
const timeMarginOfError = time.Millisecond
13+
814
func TestRetry(t *testing.T) {
915
action := func(attempt uint) error {
1016
return nil
@@ -47,8 +53,63 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) {
4753
}
4854
}
4955

56+
func TestRetryWithContextAlreadyCancelled(t *testing.T) {
57+
action := func(ctx context.Context, attempt uint) error {
58+
return errors.New("erroring")
59+
}
60+
61+
ctx, cancel := context.WithCancel(context.Background())
62+
cancel()
63+
64+
err := RetryWithContext(ctx, action)
65+
66+
if ctx.Err() != err {
67+
t.Error("expected a context error")
68+
}
69+
}
70+
71+
func TestRetryWithContextSleepIsInterrupted(t *testing.T) {
72+
const sleepDuration = 100 * timeMarginOfError
73+
noSleepDeadline := time.Now().Add(sleepDuration)
74+
75+
strategy := func(attempt uint, sleep func(time.Duration)) bool {
76+
sleep(sleepDuration)
77+
return true
78+
}
79+
80+
var numCalls int
81+
expectedErr := errors.New("erroring")
82+
83+
action := func(ctx context.Context, attempt uint) error {
84+
numCalls++
85+
return expectedErr
86+
}
87+
88+
stopAfter := 10 * timeMarginOfError
89+
deadline := time.Now().Add(stopAfter)
90+
ctx, _ := context.WithDeadline(context.Background(), deadline)
91+
92+
err := RetryWithContext(ctx, action, strategy)
93+
94+
if time.Now().Before(deadline) {
95+
t.Errorf("expected to stop after %s", stopAfter)
96+
}
97+
98+
if time.Now().After(noSleepDeadline) {
99+
t.Errorf("expected to stop before %s", sleepDuration)
100+
}
101+
102+
if 1 != numCalls {
103+
t.Errorf("expected the action to be tried once, not %d times", numCalls)
104+
}
105+
106+
if expectedErr != err {
107+
t.Error("expected to receive the error returned by the action")
108+
}
109+
}
110+
50111
func TestShouldAttempt(t *testing.T) {
51-
shouldAttempt := shouldAttempt(1)
112+
shouldAttempt := shouldAttempt(1, time.Sleep)
52113

53114
if !shouldAttempt {
54115
t.Error("expected to return true")
@@ -58,63 +119,63 @@ func TestShouldAttempt(t *testing.T) {
58119
func TestShouldAttemptWithStrategy(t *testing.T) {
59120
const attemptNumberShouldReturnFalse = 7
60121

61-
strategy := func(attempt uint) bool {
122+
strategy := func(attempt uint, sleep func(time.Duration)) bool {
62123
return (attemptNumberShouldReturnFalse != attempt)
63124
}
64125

65-
should := shouldAttempt(1, strategy)
126+
should := shouldAttempt(1, time.Sleep, strategy)
66127

67128
if !should {
68129
t.Error("expected to return true")
69130
}
70131

71-
should = shouldAttempt(1+attemptNumberShouldReturnFalse, strategy)
132+
should = shouldAttempt(1+attemptNumberShouldReturnFalse, time.Sleep, strategy)
72133

73134
if !should {
74135
t.Error("expected to return true")
75136
}
76137

77-
should = shouldAttempt(attemptNumberShouldReturnFalse, strategy)
138+
should = shouldAttempt(attemptNumberShouldReturnFalse, time.Sleep, strategy)
78139

79140
if should {
80141
t.Error("expected to return false")
81142
}
82143
}
83144

84145
func TestShouldAttemptWithMultipleStrategies(t *testing.T) {
85-
trueStrategy := func(attempt uint) bool {
146+
trueStrategy := func(attempt uint, sleep func(time.Duration)) bool {
86147
return true
87148
}
88149

89-
falseStrategy := func(attempt uint) bool {
150+
falseStrategy := func(attempt uint, sleep func(time.Duration)) bool {
90151
return false
91152
}
92153

93-
should := shouldAttempt(1, trueStrategy)
154+
should := shouldAttempt(1, time.Sleep, trueStrategy)
94155

95156
if !should {
96157
t.Error("expected to return true")
97158
}
98159

99-
should = shouldAttempt(1, falseStrategy)
160+
should = shouldAttempt(1, time.Sleep, falseStrategy)
100161

101162
if should {
102163
t.Error("expected to return false")
103164
}
104165

105-
should = shouldAttempt(1, trueStrategy, trueStrategy, trueStrategy)
166+
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, trueStrategy)
106167

107168
if !should {
108169
t.Error("expected to return true")
109170
}
110171

111-
should = shouldAttempt(1, falseStrategy, falseStrategy, falseStrategy)
172+
should = shouldAttempt(1, time.Sleep, falseStrategy, falseStrategy, falseStrategy)
112173

113174
if should {
114175
t.Error("expected to return false")
115176
}
116177

117-
should = shouldAttempt(1, trueStrategy, trueStrategy, falseStrategy)
178+
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, falseStrategy)
118179

119180
if should {
120181
t.Error("expected to return false")

strategy/strategy.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,22 @@ import (
1818
// The strategy will be passed an "attempt" number on each successive retry
1919
// iteration, starting with a `0` value before the first attempt is actually
2020
// made. This allows for a pre-action delay, etc.
21-
type Strategy func(attempt uint) bool
21+
type Strategy func(attempt uint, sleep func(time.Duration)) bool
2222

2323
// Limit creates a Strategy that limits the number of attempts that Retry will
2424
// make.
2525
func Limit(attemptLimit uint) Strategy {
26-
return func(attempt uint) bool {
27-
return (attempt <= attemptLimit)
26+
return func(attempt uint, sleep func(time.Duration)) bool {
27+
return attempt <= attemptLimit
2828
}
2929
}
3030

3131
// Delay creates a Strategy that waits the given duration before the first
3232
// attempt is made.
3333
func Delay(duration time.Duration) Strategy {
34-
return func(attempt uint) bool {
34+
return func(attempt uint, sleep func(time.Duration)) bool {
3535
if 0 == attempt {
36-
time.Sleep(duration)
36+
sleep(duration)
3737
}
3838

3939
return true
@@ -44,15 +44,15 @@ func Delay(duration time.Duration) Strategy {
4444
// the first. If the number of attempts is greater than the number of durations
4545
// provided, then the strategy uses the last duration provided.
4646
func Wait(durations ...time.Duration) Strategy {
47-
return func(attempt uint) bool {
47+
return func(attempt uint, sleep func(time.Duration)) bool {
4848
if 0 < attempt && 0 < len(durations) {
4949
durationIndex := int(attempt - 1)
5050

5151
if len(durations) <= durationIndex {
5252
durationIndex = len(durations) - 1
5353
}
5454

55-
time.Sleep(durations[durationIndex])
55+
sleep(durations[durationIndex])
5656
}
5757

5858
return true
@@ -68,9 +68,9 @@ func Backoff(algorithm backoff.Algorithm) Strategy {
6868
// BackoffWithJitter creates a Strategy that waits before each attempt, with a
6969
// duration as defined by the given backoff.Algorithm and jitter.Transformation.
7070
func BackoffWithJitter(algorithm backoff.Algorithm, transformation jitter.Transformation) Strategy {
71-
return func(attempt uint) bool {
71+
return func(attempt uint, sleep func(time.Duration)) bool {
7272
if 0 < attempt {
73-
time.Sleep(transformation(algorithm(attempt)))
73+
sleep(transformation(algorithm(attempt)))
7474
}
7575

7676
return true

0 commit comments

Comments
 (0)