@@ -514,23 +514,28 @@ func (a *HostAgent) startRoutinesAndWait(ctx context.Context, errCh <-chan error
514514 return a .driver .Stop (ctx )
515515}
516516
517- func (a * HostAgent ) Info (_ context.Context ) (* hostagentapi.Info , error ) {
517+ // GuestIP returns the guest IP address, or nil if not accessible by direct IP.
518+ func (a * HostAgent ) GuestIP () net.IP {
518519 a .guestIPMu .RLock ()
519520 defer a .guestIPMu .RUnlock ()
521+ return a .guestIP
522+ }
523+
524+ func (a * HostAgent ) Info (_ context.Context ) (* hostagentapi.Info , error ) {
525+ guestIP := a .GuestIP ()
520526 info := & hostagentapi.Info {
521- GuestIP : a . guestIP ,
527+ GuestIP : guestIP ,
522528 SSHLocalPort : a .sshLocalPort ,
523529 }
524530 return info , nil
525531}
526532
527533func (a * HostAgent ) sshAddressPort () (sshAddress string , sshPort int ) {
528- a .guestIPMu .RLock ()
529- defer a .guestIPMu .RUnlock ()
530534 sshAddress = a .instSSHAddress
531535 sshPort = a .sshLocalPort
532- if a .guestIP != nil {
533- sshAddress = a .guestIP .String ()
536+ guestIP := a .GuestIP ()
537+ if guestIP != nil {
538+ sshAddress = guestIP .String ()
534539 sshPort = 22
535540 logrus .Debugf ("Using the guest IP address %q directly" , sshAddress )
536541 }
@@ -891,9 +896,17 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag
891896 if useSSHFwd {
892897 a .portForwarder .OnEvent (ctx , ev )
893898 } else {
894- dialContext := a .DialContextToGuestIP ()
895- if dialContext == nil {
896- dialContext = portfwd .DialContextToGRPCTunnel (client )
899+ dialContext := func (ctx context.Context , network , address string ) (net.Conn , error ) {
900+ dialContext := a .DialContextToGuestIP ()
901+ if dialContext == nil {
902+ return portfwd .DialContextToGRPCTunnel (client )(ctx , network , address )
903+ }
904+ if host , _ , err := net .SplitHostPort (address ); err != nil {
905+ return nil , err
906+ } else if ip := net .ParseIP (host ); ip .IsUnspecified () || ip .Equal (a .GuestIP ()) {
907+ return a .DialContextToGuestIP ()(ctx , network , address )
908+ }
909+ return portfwd .DialContextToGRPCTunnel (client )(ctx , network , address )
897910 }
898911 a .grpcPortForwarder .OnEvent (ctx , dialContext , ev )
899912 }
@@ -911,9 +924,8 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag
911924// DialContextToGuestIP returns a DialContext function that connects to the guest IP directly.
912925// If the guest IP is not known, it returns nil.
913926func (a * HostAgent ) DialContextToGuestIP () func (ctx context.Context , network , address string ) (net.Conn , error ) {
914- a .guestIPMu .RLock ()
915- defer a .guestIPMu .RUnlock ()
916- if a .guestIP == nil {
927+ guestIP := a .GuestIP ()
928+ if guestIP == nil {
917929 return nil
918930 }
919931 return func (ctx context.Context , network , address string ) (net.Conn , error ) {
@@ -924,7 +936,7 @@ func (a *HostAgent) DialContextToGuestIP() func(ctx context.Context, network, ad
924936 }
925937 // Host part of address is ignored, because it already has been checked by forwarding rules
926938 // and we want to connect to the guest IP directly.
927- return d .DialContext (ctx , network , net .JoinHostPort (a . guestIP .String (), port ))
939+ return d .DialContext (ctx , network , net .JoinHostPort (guestIP .String (), port ))
928940 }
929941}
930942
0 commit comments