@@ -3,51 +3,123 @@ package fn
33import (
44 "context"
55 "sync"
6+ "sync/atomic"
67)
78
89// GoroutineManager is used to launch goroutines until context expires or the
910// manager is stopped. The Stop method blocks until all started goroutines stop.
1011type GoroutineManager struct {
11- wg sync.WaitGroup
12- mu sync.Mutex
13- ctx context.Context
14- cancel func ()
12+ // id is used to generate unique ids for each goroutine.
13+ id atomic.Uint32
14+
15+ // cancelFns is a map of cancel functions that can be used to cancel the
16+ // context of a goroutine. The mutex must be held when accessing this
17+ // map. The key is the id of the goroutine.
18+ cancelFns map [uint32 ]context.CancelFunc
19+
20+ mu sync.Mutex
21+
22+ stopped sync.Once
23+ quit chan struct {}
24+ wg sync.WaitGroup
1525}
1626
1727// NewGoroutineManager constructs and returns a new instance of
1828// GoroutineManager.
19- func NewGoroutineManager (ctx context.Context ) * GoroutineManager {
20- ctx , cancel := context .WithCancel (ctx )
21-
29+ func NewGoroutineManager () * GoroutineManager {
2230 return & GoroutineManager {
23- ctx : ctx ,
24- cancel : cancel ,
31+ cancelFns : make (map [uint32 ]context.CancelFunc ),
32+ quit : make (chan struct {}),
33+ }
34+ }
35+
36+ // addCancelFn adds a context cancel function to the manager and returns an id
37+ // that can can be used to cancel the context later on when the goroutine is
38+ // done.
39+ func (g * GoroutineManager ) addCancelFn (cancel context.CancelFunc ) uint32 {
40+ g .mu .Lock ()
41+ defer g .mu .Unlock ()
42+
43+ id := g .id .Add (1 )
44+ g .cancelFns [id ] = cancel
45+
46+ return id
47+ }
48+
49+ // cancel cancels the context associated with the passed id.
50+ func (g * GoroutineManager ) cancel (id uint32 ) {
51+ g .mu .Lock ()
52+ defer g .mu .Unlock ()
53+
54+ g .cancelUnsafe (id )
55+ }
56+
57+ // cancelUnsafe cancels the context associated with the passed id without
58+ // acquiring the mutex.
59+ func (g * GoroutineManager ) cancelUnsafe (id uint32 ) {
60+ fn , ok := g .cancelFns [id ]
61+ if ! ok {
62+ return
2563 }
64+
65+ fn ()
66+
67+ delete (g .cancelFns , id )
2668}
2769
2870// Go tries to start a new goroutine and returns a boolean indicating its
29- // success. It fails iff the goroutine manager is stopping or its context passed
30- // to NewGoroutineManager has expired.
31- func (g * GoroutineManager ) Go (f func (ctx context.Context )) bool {
32- // Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race
33- // condition, since it is not clear should Wait() block or not. This
71+ // success. It returns true if the goroutine was successfully created and false
72+ // otherwise. A goroutine will fail to be created iff the goroutine manager is
73+ // stopping or the passed context has already expired. The passed call-back
74+ // function must exit if the passed context expires.
75+ func (g * GoroutineManager ) Go (ctx context.Context ,
76+ f func (ctx context.Context )) bool {
77+
78+ // Derive a cancellable context from the passed context and store its
79+ // cancel function in the manager. The context will be cancelled when
80+ // either the parent context is cancelled or the quit channel is closed
81+ // which will call the stored cancel function.
82+ ctx , cancel := context .WithCancel (ctx )
83+ id := g .addCancelFn (cancel )
84+
85+ // Calling wg.Add(1) and wg.Wait() when the wg's counter is 0 is a race
86+ // condition, since it is not clear if Wait() should block or not. This
3487 // kind of race condition is detected by Go runtime and results in a
35- // crash if running with `-race`. To prevent this, whole Go method is
36- // protected with a mutex. The call to wg.Wait() inside Stop() can still
37- // run in parallel with Go, but in that case g.ctx is in expired state,
38- // because cancel() was called in Stop, so Go returns before wg.Add(1)
39- // call.
88+ // crash if running with `-race`. To prevent this, we protect the calls
89+ // to wg.Add(1) and wg.Wait() with a mutex. If we block here because
90+ // Stop is running first, then Stop will close the quit channel which
91+ // will cause the context to be cancelled, and we will exit before
92+ // calling wg.Add(1). If we grab the mutex here before Stop does, then
93+ // Stop will block until after we call wg.Add(1).
4094 g .mu .Lock ()
4195 defer g .mu .Unlock ()
4296
43- if g .ctx .Err () != nil {
97+ // Before continuing to start the goroutine, we need to check if the
98+ // context has already expired. This could be the case if the parent
99+ // context has already expired or if Stop has been called.
100+ if ctx .Err () != nil {
101+ g .cancelUnsafe (id )
102+
103+ return false
104+ }
105+
106+ // Ensure that the goroutine is not started if the manager has stopped.
107+ select {
108+ case <- g .quit :
109+ g .cancelUnsafe (id )
110+
44111 return false
112+ default :
45113 }
46114
47115 g .wg .Add (1 )
48116 go func () {
49- defer g .wg .Done ()
50- f (g .ctx )
117+ defer func () {
118+ g .cancel (id )
119+ g .wg .Done ()
120+ }()
121+
122+ f (ctx )
51123 }()
52124
53125 return true
@@ -56,20 +128,30 @@ func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
56128// Stop prevents new goroutines from being added and waits for all running
57129// goroutines to finish.
58130func (g * GoroutineManager ) Stop () {
59- g .mu .Lock ()
60- g .cancel ()
61- g .mu .Unlock ()
62-
63- // Wait for all goroutines to finish. Note that this wg.Wait() call is
64- // safe, since it can't run in parallel with wg.Add(1) call in Go, since
65- // we just cancelled the context and even if Go call starts running here
66- // after acquiring the mutex, it would see that the context has expired
67- // and return false instead of calling wg.Add(1).
68- g .wg .Wait ()
131+ g .stopped .Do (func () {
132+ // Closing the quit channel will prevent any new goroutines from
133+ // starting.
134+ g .mu .Lock ()
135+ close (g .quit )
136+ for _ , cancel := range g .cancelFns {
137+ cancel ()
138+ }
139+ g .mu .Unlock ()
140+
141+ // Wait for all goroutines to finish. Note that this wg.Wait()
142+ // call is safe, since it can't run in parallel with wg.Add(1)
143+ // call in Go, since we just cancelled the context and even if
144+ // Go call starts running here after acquiring the mutex, it
145+ // would see that the context has expired and return false
146+ // instead of calling wg.Add(1).
147+ g .wg .Wait ()
148+ })
69149}
70150
71- // Done returns a channel which is closed when either the context passed to
72- // NewGoroutineManager expires or when Stop is called.
151+ // Done returns a channel which is closed once Stop has been called and the
152+ // quit channel closed. Note that the channel closing indicates that shutdown
153+ // of the GoroutineManager has started but not necessarily that the Stop method
154+ // has finished.
73155func (g * GoroutineManager ) Done () <- chan struct {} {
74- return g .ctx . Done ()
156+ return g .quit
75157}
0 commit comments