@@ -163,14 +163,17 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
163163 return
164164 }
165165
166- ctx , cancel := context .WithCancel (context .Background ())
166+ ctx := context .Background ()
167+ dialCtx , dialCancel := context .WithCancel (ctx )
168+ readWriteCtx , readWriteCancel := context .WithCancel (ctx )
167169
168170 // We start a goroutine in order to be able to cancel the dialer mid-connection
169171 // on receiving a stop signal to stop the initiator.
170172 go func () {
171173 select {
172174 case <- i .stopChan :
173- cancel ()
175+ dialCancel ()
176+ readWriteCancel ()
174177 case <- ctx .Done ():
175178 return
176179 }
@@ -183,7 +186,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
183186 address := session .SocketConnectAddress [connectionAttempt % len (session .SocketConnectAddress )]
184187 session .log .OnEventf ("Connecting to: %v" , address )
185188
186- netConn , err := dialer .DialContext (ctx , "tcp" , address )
189+ netConn , err := dialer .DialContext (dialCtx , "tcp" , address )
187190 if err != nil {
188191 session .log .OnEventf ("Failed to connect: %v" , err )
189192 goto reconnect
@@ -207,24 +210,26 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
207210
208211 msgIn = make (chan fixIn )
209212 msgOut = make (chan []byte )
210- if err := session .connect (msgIn , msgOut ); err != nil {
211- session .log .OnEventf ("Failed to initiate: %v" , err )
212- goto reconnect
213- }
213+
214214
215- go readLoop (newParser (bufio .NewReader (netConn )), msgIn , session .log )
215+ go readLoop (readWriteCtx , newParser (bufio .NewReader (netConn )), msgIn , session .log )
216216 disconnected = make (chan interface {})
217217 go func () {
218- writeLoop (netConn , msgOut , session .log )
218+ writeLoop (readWriteCtx , netConn , msgOut , session .log )
219219 if err := netConn .Close (); err != nil {
220220 session .log .OnEvent (err .Error ())
221221 }
222222 close (disconnected )
223223 }()
224224
225+ if err := session .connect (msgIn , msgOut ); err != nil {
226+ session .log .OnEventf ("Failed to initiate: %v" , err )
227+ goto reconnect
228+ }
229+
225230 // This ensures we properly cleanup the goroutine and context used for
226231 // dial cancelation after successful connection.
227- cancel ()
232+ dialCancel ()
228233
229234 select {
230235 case <- disconnected :
@@ -233,7 +238,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
233238 }
234239
235240 reconnect:
236- cancel ()
241+ dialCancel ()
237242
238243 connectionAttempt ++
239244 session .log .OnEventf ("Reconnecting in %v" , session .ReconnectInterval )
0 commit comments