Skip to content

Commit 60652e5

Browse files
committed
Add RetryWithContext() and respect cancellation while sleeping
This is a breaking change for developers using custom strategies. However, there shouldn't be any impact on code using the strategies included in this package. Because the time.Sleep() call is now abstracted, strategies are tested without actually sleeping, and the strategies don't need to be aware of contexts. Context is passed through to the action in case the action is defined separately from the retry.RetryWithContext() call, is reused at multiple points, etc.
1 parent 272ad12 commit 60652e5

File tree

5 files changed

+216
-94
lines changed

5 files changed

+216
-94
lines changed

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,13 @@ logFile.Chdir() // Do something with the file
5555
### HTTP request with strategies and backoff
5656

5757
```go
58-
var response *http.Response
58+
action := func(ctx context.Context, attempt uint) error {
59+
var response *http.Response
5960

60-
action := func(attempt uint) error {
61-
var err error
62-
63-
response, err = http.Get("https://api.github.com/repos/Rican7/retry")
61+
req, err := NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/Rican7/retry", nil)
62+
if err == nil {
63+
response, err = c.Do(req)
64+
}
6465

6566
if nil == err && nil != response && response.StatusCode > 200 {
6667
err = fmt.Errorf("failed to fetch (attempt #%d) with status code: %d", attempt, response.StatusCode)
@@ -69,7 +70,8 @@ action := func(attempt uint) error {
6970
return err
7071
}
7172

72-
err := retry.Retry(
73+
err := retry.RetryWithContext(
74+
context.TODO(),
7375
action,
7476
strategy.Limit(5),
7577
strategy.Backoff(backoff.Fibonacci(10*time.Millisecond)),

retry.go

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,73 @@
44
// Copyright © 2016 Trevor N. Suarez (Rican7)
55
package retry
66

7-
import "github.com/Rican7/retry/strategy"
7+
import (
8+
"context"
9+
"time"
10+
11+
"github.com/Rican7/retry/strategy"
12+
)
813

914
// Action defines a callable function that package retry can handle.
1015
type Action func(attempt uint) error
1116

17+
// ActionWithContext defines a callable function that package retry can handle.
18+
type ActionWithContext func(ctx context.Context, attempt uint) error
19+
1220
// Retry takes an action and performs it, repetitively, until successful.
1321
//
1422
// Optionally, strategies may be passed that assess whether or not an attempt
1523
// should be made.
1624
func Retry(action Action, strategies ...strategy.Strategy) error {
25+
return RetryWithContext(context.Background(), func(ctx context.Context, attempt uint) error { return action(attempt) }, strategies...)
26+
}
27+
28+
// RetryWithContext takes an action and performs it, repetitively, until successful
29+
// or the context is done.
30+
//
31+
// Optionally, strategies may be passed that assess whether or not an attempt
32+
// should be made.
33+
//
34+
// Context errors take precedence over action errors so this commonplace test:
35+
//
36+
// err := retry.RetryWithContext(...)
37+
// if err != nil { return err }
38+
//
39+
// will pass cancellation errors up the call chain.
40+
func RetryWithContext(ctx context.Context, action ActionWithContext, strategies ...strategy.Strategy) error {
1741
var err error
1842

19-
for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, strategies...); attempt++ {
20-
err = action(attempt)
43+
for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, sleepFunc(ctx), strategies...) && nil == ctx.Err(); attempt++ {
44+
err = action(ctx, attempt)
45+
}
46+
47+
if ctx.Err() != nil {
48+
return ctx.Err()
2149
}
2250

2351
return err
2452
}
2553

2654
// shouldAttempt evaluates the provided strategies with the given attempt to
2755
// determine if the Retry loop should make another attempt.
28-
func shouldAttempt(attempt uint, strategies ...strategy.Strategy) bool {
56+
func shouldAttempt(attempt uint, sleep func(time.Duration), strategies ...strategy.Strategy) bool {
2957
shouldAttempt := true
3058

3159
for i := 0; shouldAttempt && i < len(strategies); i++ {
32-
shouldAttempt = shouldAttempt && strategies[i](attempt)
60+
shouldAttempt = shouldAttempt && strategies[i](attempt, sleep)
3361
}
3462

3563
return shouldAttempt
3664
}
65+
66+
// sleepFunc returns a function with the same signature as time.Sleep()
67+
// that blocks for the given duration, but will return sooner if the context is
68+
// cancelled or its deadline passes.
69+
func sleepFunc(ctx context.Context) func(time.Duration) {
70+
return func(d time.Duration) {
71+
select {
72+
case <-ctx.Done():
73+
case <-time.After(d):
74+
}
75+
}
76+
}

retry_test.go

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
package retry
22

33
import (
4+
"context"
45
"errors"
56
"testing"
7+
"time"
68
)
79

10+
// timeMarginOfError represents the acceptable amount of time that may pass for
11+
// a time-based (sleep) unit before considering invalid.
12+
const timeMarginOfError = time.Millisecond
13+
814
func TestRetry(t *testing.T) {
915
action := func(attempt uint) error {
1016
return nil
@@ -47,8 +53,99 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) {
4753
}
4854
}
4955

56+
func TestRetryWithContextChecksContextAfterLastAttempt(t *testing.T) {
57+
ctx, cancel := context.WithCancel(context.Background())
58+
59+
strategy := func(attempt uint, sleep func(time.Duration)) bool {
60+
if attempt == 0 {
61+
return true
62+
}
63+
64+
cancel()
65+
return false
66+
}
67+
68+
action := func(ctx context.Context, attempt uint) error {
69+
return errors.New("erroring")
70+
}
71+
72+
err := RetryWithContext(ctx, action, strategy)
73+
74+
if context.Canceled != err {
75+
t.Error("expected a context error")
76+
}
77+
}
78+
79+
func TestRetryWithContextCancelStopsAttempts(t *testing.T) {
80+
var numCalls int
81+
82+
ctx, cancel := context.WithCancel(context.Background())
83+
84+
action := func(ctx context.Context, attempt uint) error {
85+
numCalls++
86+
87+
if numCalls == 1 {
88+
cancel()
89+
return ctx.Err()
90+
}
91+
92+
return nil
93+
}
94+
95+
err := RetryWithContext(ctx, action)
96+
97+
if 1 != numCalls {
98+
t.Errorf("expected the action to be tried once, not %d times", numCalls)
99+
}
100+
101+
if context.Canceled != err {
102+
t.Error("expected a context error")
103+
}
104+
}
105+
106+
func TestRetryWithContextSleepIsInterrupted(t *testing.T) {
107+
const sleepDuration = 100 * timeMarginOfError
108+
fullySleptBy := time.Now().Add(sleepDuration)
109+
110+
strategy := func(attempt uint, sleep func(time.Duration)) bool {
111+
if attempt > 0 {
112+
sleep(sleepDuration)
113+
}
114+
return attempt <= 1
115+
}
116+
117+
var numCalls int
118+
119+
action := func(ctx context.Context, attempt uint) error {
120+
numCalls++
121+
return errors.New("erroring")
122+
}
123+
124+
stopAfter := 10 * timeMarginOfError
125+
deadline := time.Now().Add(stopAfter)
126+
ctx, _ := context.WithDeadline(context.Background(), deadline)
127+
128+
err := RetryWithContext(ctx, action, strategy)
129+
130+
if time.Now().Before(deadline) {
131+
t.Errorf("expected to stop after %s", stopAfter)
132+
}
133+
134+
if time.Now().After(fullySleptBy) {
135+
t.Errorf("expected to stop before %s", sleepDuration)
136+
}
137+
138+
if 1 != numCalls {
139+
t.Errorf("expected the action to be tried once, not %d times", numCalls)
140+
}
141+
142+
if context.DeadlineExceeded != err {
143+
t.Error("expected a context error")
144+
}
145+
}
146+
50147
func TestShouldAttempt(t *testing.T) {
51-
shouldAttempt := shouldAttempt(1)
148+
shouldAttempt := shouldAttempt(1, time.Sleep)
52149

53150
if !shouldAttempt {
54151
t.Error("expected to return true")
@@ -58,63 +155,63 @@ func TestShouldAttempt(t *testing.T) {
58155
func TestShouldAttemptWithStrategy(t *testing.T) {
59156
const attemptNumberShouldReturnFalse = 7
60157

61-
strategy := func(attempt uint) bool {
158+
strategy := func(attempt uint, sleep func(time.Duration)) bool {
62159
return (attemptNumberShouldReturnFalse != attempt)
63160
}
64161

65-
should := shouldAttempt(1, strategy)
162+
should := shouldAttempt(1, time.Sleep, strategy)
66163

67164
if !should {
68165
t.Error("expected to return true")
69166
}
70167

71-
should = shouldAttempt(1+attemptNumberShouldReturnFalse, strategy)
168+
should = shouldAttempt(1+attemptNumberShouldReturnFalse, time.Sleep, strategy)
72169

73170
if !should {
74171
t.Error("expected to return true")
75172
}
76173

77-
should = shouldAttempt(attemptNumberShouldReturnFalse, strategy)
174+
should = shouldAttempt(attemptNumberShouldReturnFalse, time.Sleep, strategy)
78175

79176
if should {
80177
t.Error("expected to return false")
81178
}
82179
}
83180

84181
func TestShouldAttemptWithMultipleStrategies(t *testing.T) {
85-
trueStrategy := func(attempt uint) bool {
182+
trueStrategy := func(attempt uint, sleep func(time.Duration)) bool {
86183
return true
87184
}
88185

89-
falseStrategy := func(attempt uint) bool {
186+
falseStrategy := func(attempt uint, sleep func(time.Duration)) bool {
90187
return false
91188
}
92189

93-
should := shouldAttempt(1, trueStrategy)
190+
should := shouldAttempt(1, time.Sleep, trueStrategy)
94191

95192
if !should {
96193
t.Error("expected to return true")
97194
}
98195

99-
should = shouldAttempt(1, falseStrategy)
196+
should = shouldAttempt(1, time.Sleep, falseStrategy)
100197

101198
if should {
102199
t.Error("expected to return false")
103200
}
104201

105-
should = shouldAttempt(1, trueStrategy, trueStrategy, trueStrategy)
202+
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, trueStrategy)
106203

107204
if !should {
108205
t.Error("expected to return true")
109206
}
110207

111-
should = shouldAttempt(1, falseStrategy, falseStrategy, falseStrategy)
208+
should = shouldAttempt(1, time.Sleep, falseStrategy, falseStrategy, falseStrategy)
112209

113210
if should {
114211
t.Error("expected to return false")
115212
}
116213

117-
should = shouldAttempt(1, trueStrategy, trueStrategy, falseStrategy)
214+
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, falseStrategy)
118215

119216
if should {
120217
t.Error("expected to return false")

strategy/strategy.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,22 @@ import (
1818
// The strategy will be passed an "attempt" number on each successive retry
1919
// iteration, starting with a `0` value before the first attempt is actually
2020
// made. This allows for a pre-action delay, etc.
21-
type Strategy func(attempt uint) bool
21+
type Strategy func(attempt uint, sleep func(time.Duration)) bool
2222

2323
// Limit creates a Strategy that limits the number of attempts that Retry will
2424
// make.
2525
func Limit(attemptLimit uint) Strategy {
26-
return func(attempt uint) bool {
27-
return (attempt <= attemptLimit)
26+
return func(attempt uint, sleep func(time.Duration)) bool {
27+
return attempt <= attemptLimit
2828
}
2929
}
3030

3131
// Delay creates a Strategy that waits the given duration before the first
3232
// attempt is made.
3333
func Delay(duration time.Duration) Strategy {
34-
return func(attempt uint) bool {
34+
return func(attempt uint, sleep func(time.Duration)) bool {
3535
if 0 == attempt {
36-
time.Sleep(duration)
36+
sleep(duration)
3737
}
3838

3939
return true
@@ -44,15 +44,15 @@ func Delay(duration time.Duration) Strategy {
4444
// the first. If the number of attempts is greater than the number of durations
4545
// provided, then the strategy uses the last duration provided.
4646
func Wait(durations ...time.Duration) Strategy {
47-
return func(attempt uint) bool {
47+
return func(attempt uint, sleep func(time.Duration)) bool {
4848
if 0 < attempt && 0 < len(durations) {
4949
durationIndex := int(attempt - 1)
5050

5151
if len(durations) <= durationIndex {
5252
durationIndex = len(durations) - 1
5353
}
5454

55-
time.Sleep(durations[durationIndex])
55+
sleep(durations[durationIndex])
5656
}
5757

5858
return true
@@ -68,9 +68,9 @@ func Backoff(algorithm backoff.Algorithm) Strategy {
6868
// BackoffWithJitter creates a Strategy that waits before each attempt, with a
6969
// duration as defined by the given backoff.Algorithm and jitter.Transformation.
7070
func BackoffWithJitter(algorithm backoff.Algorithm, transformation jitter.Transformation) Strategy {
71-
return func(attempt uint) bool {
71+
return func(attempt uint, sleep func(time.Duration)) bool {
7272
if 0 < attempt {
73-
time.Sleep(transformation(algorithm(attempt)))
73+
sleep(transformation(algorithm(attempt)))
7474
}
7575

7676
return true

0 commit comments

Comments
 (0)