diff --git a/client.go b/client.go index 779cd5e..23b1f1e 100644 --- a/client.go +++ b/client.go @@ -149,6 +149,9 @@ type Config struct { // hung connections. DisableEPSV bool + // Overrides the net.Dial method to allow for customised network connections. + DialFunc func(network string, address string) (net.Conn, error) + // For testing convenience. stubResponses map[string]stubResponse } @@ -369,15 +372,20 @@ func (c *Client) openConn(idx int, host string) (pconn *persistentConn, err erro var conn net.Conn - if c.config.TLSConfig != nil && c.config.TLSMode == TLSImplicit { - pconn.debug("opening TLS control connection to %s", host) - dialer := &net.Dialer{ - Timeout: c.config.Timeout, - } - conn, err = tls.DialWithDialer(dialer, "tcp", host, pconn.config.TLSConfig) + if c.config.DialFunc != nil { + pconn.debug("opening control connection to %s via DialFunc", host) + conn, err = c.config.DialFunc("tcp", host) } else { - pconn.debug("opening control connection to %s", host) - conn, err = net.DialTimeout("tcp", host, c.config.Timeout) + if c.config.TLSConfig != nil && c.config.TLSMode == TLSImplicit { + pconn.debug("opening TLS control connection to %s", host) + dialer := &net.Dialer{ + Timeout: c.config.Timeout, + } + conn, err = tls.DialWithDialer(dialer, "tcp", host, pconn.config.TLSConfig) + } else { + pconn.debug("opening control connection to %s", host) + conn, err = net.DialTimeout("tcp", host, c.config.Timeout) + } } var ( diff --git a/persistent_connection.go b/persistent_connection.go index 15f5d74..9638cc8 100644 --- a/persistent_connection.go +++ b/persistent_connection.go @@ -418,8 +418,16 @@ func (pconn *persistentConn) prepareDataConn() (func() (net.Conn, error), error) return nil, err } - pconn.debug("opening data connection to %s", host) - dc, netErr := net.DialTimeout("tcp", host, pconn.config.Timeout) + var dc net.Conn + var netErr error + + if pconn.config.DialFunc != nil { + pconn.debug("opening data connection to %s via DialFunc", host) + dc, netErr = pconn.config.DialFunc("tcp", host) + } else { + pconn.debug("opening data connection to %s", host) + dc, netErr = net.DialTimeout("tcp", host, pconn.config.Timeout) + } if netErr != nil { var isTemporary bool