11package retry
22
33import (
4+ "context"
45 "errors"
56 "testing"
7+ "time"
68)
79
10+ // timeMarginOfError represents the acceptable amount of time that may pass for
11+ // a time-based (sleep) unit before considering invalid.
12+ const timeMarginOfError = time .Millisecond
13+
814func TestRetry (t * testing.T ) {
915 action := func (attempt uint ) error {
1016 return nil
@@ -47,8 +53,99 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) {
4753 }
4854}
4955
56+ func TestRetryWithContextChecksContextAfterLastAttempt (t * testing.T ) {
57+ ctx , cancel := context .WithCancel (context .Background ())
58+
59+ strategy := func (attempt uint , sleep func (time.Duration )) bool {
60+ if attempt == 0 {
61+ return true
62+ }
63+
64+ cancel ()
65+ return false
66+ }
67+
68+ action := func (ctx context.Context , attempt uint ) error {
69+ return errors .New ("erroring" )
70+ }
71+
72+ err := RetryWithContext (ctx , action , strategy )
73+
74+ if context .Canceled != err {
75+ t .Error ("expected a context error" )
76+ }
77+ }
78+
79+ func TestRetryWithContextCancelStopsAttempts (t * testing.T ) {
80+ var numCalls int
81+
82+ ctx , cancel := context .WithCancel (context .Background ())
83+
84+ action := func (ctx context.Context , attempt uint ) error {
85+ numCalls ++
86+
87+ if numCalls == 1 {
88+ cancel ()
89+ return ctx .Err ()
90+ }
91+
92+ return nil
93+ }
94+
95+ err := RetryWithContext (ctx , action )
96+
97+ if 1 != numCalls {
98+ t .Errorf ("expected the action to be tried once, not %d times" , numCalls )
99+ }
100+
101+ if context .Canceled != err {
102+ t .Error ("expected a context error" )
103+ }
104+ }
105+
106+ func TestRetryWithContextSleepIsInterrupted (t * testing.T ) {
107+ const sleepDuration = 100 * timeMarginOfError
108+ fullySleptBy := time .Now ().Add (sleepDuration )
109+
110+ strategy := func (attempt uint , sleep func (time.Duration )) bool {
111+ if attempt > 0 {
112+ sleep (sleepDuration )
113+ }
114+ return attempt <= 1
115+ }
116+
117+ var numCalls int
118+
119+ action := func (ctx context.Context , attempt uint ) error {
120+ numCalls ++
121+ return errors .New ("erroring" )
122+ }
123+
124+ stopAfter := 10 * timeMarginOfError
125+ deadline := time .Now ().Add (stopAfter )
126+ ctx , _ := context .WithDeadline (context .Background (), deadline )
127+
128+ err := RetryWithContext (ctx , action , strategy )
129+
130+ if time .Now ().Before (deadline ) {
131+ t .Errorf ("expected to stop after %s" , stopAfter )
132+ }
133+
134+ if time .Now ().After (fullySleptBy ) {
135+ t .Errorf ("expected to stop before %s" , sleepDuration )
136+ }
137+
138+ if 1 != numCalls {
139+ t .Errorf ("expected the action to be tried once, not %d times" , numCalls )
140+ }
141+
142+ if context .DeadlineExceeded != err {
143+ t .Error ("expected a context error" )
144+ }
145+ }
146+
50147func TestShouldAttempt (t * testing.T ) {
51- shouldAttempt := shouldAttempt (1 )
148+ shouldAttempt := shouldAttempt (1 , time . Sleep )
52149
53150 if ! shouldAttempt {
54151 t .Error ("expected to return true" )
@@ -58,63 +155,63 @@ func TestShouldAttempt(t *testing.T) {
58155func TestShouldAttemptWithStrategy (t * testing.T ) {
59156 const attemptNumberShouldReturnFalse = 7
60157
61- strategy := func (attempt uint ) bool {
158+ strategy := func (attempt uint , sleep func (time. Duration ) ) bool {
62159 return (attemptNumberShouldReturnFalse != attempt )
63160 }
64161
65- should := shouldAttempt (1 , strategy )
162+ should := shouldAttempt (1 , time . Sleep , strategy )
66163
67164 if ! should {
68165 t .Error ("expected to return true" )
69166 }
70167
71- should = shouldAttempt (1 + attemptNumberShouldReturnFalse , strategy )
168+ should = shouldAttempt (1 + attemptNumberShouldReturnFalse , time . Sleep , strategy )
72169
73170 if ! should {
74171 t .Error ("expected to return true" )
75172 }
76173
77- should = shouldAttempt (attemptNumberShouldReturnFalse , strategy )
174+ should = shouldAttempt (attemptNumberShouldReturnFalse , time . Sleep , strategy )
78175
79176 if should {
80177 t .Error ("expected to return false" )
81178 }
82179}
83180
84181func TestShouldAttemptWithMultipleStrategies (t * testing.T ) {
85- trueStrategy := func (attempt uint ) bool {
182+ trueStrategy := func (attempt uint , sleep func (time. Duration ) ) bool {
86183 return true
87184 }
88185
89- falseStrategy := func (attempt uint ) bool {
186+ falseStrategy := func (attempt uint , sleep func (time. Duration ) ) bool {
90187 return false
91188 }
92189
93- should := shouldAttempt (1 , trueStrategy )
190+ should := shouldAttempt (1 , time . Sleep , trueStrategy )
94191
95192 if ! should {
96193 t .Error ("expected to return true" )
97194 }
98195
99- should = shouldAttempt (1 , falseStrategy )
196+ should = shouldAttempt (1 , time . Sleep , falseStrategy )
100197
101198 if should {
102199 t .Error ("expected to return false" )
103200 }
104201
105- should = shouldAttempt (1 , trueStrategy , trueStrategy , trueStrategy )
202+ should = shouldAttempt (1 , time . Sleep , trueStrategy , trueStrategy , trueStrategy )
106203
107204 if ! should {
108205 t .Error ("expected to return true" )
109206 }
110207
111- should = shouldAttempt (1 , falseStrategy , falseStrategy , falseStrategy )
208+ should = shouldAttempt (1 , time . Sleep , falseStrategy , falseStrategy , falseStrategy )
112209
113210 if should {
114211 t .Error ("expected to return false" )
115212 }
116213
117- should = shouldAttempt (1 , trueStrategy , trueStrategy , falseStrategy )
214+ should = shouldAttempt (1 , time . Sleep , trueStrategy , trueStrategy , falseStrategy )
118215
119216 if should {
120217 t .Error ("expected to return false" )
0 commit comments