Skip to content

Commit 9fd03ae

Browse files
committed
pkg/hostagent: Use HostAgent.DialContextToGuestIP() if the IP is accessible directly.
Signed-off-by: Norio Nomura <[email protected]>
1 parent 6e0e59a commit 9fd03ae

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

pkg/hostagent/hostagent.go

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -514,23 +514,28 @@ func (a *HostAgent) startRoutinesAndWait(ctx context.Context, errCh <-chan error
514514
return a.driver.Stop(ctx)
515515
}
516516

517-
func (a *HostAgent) Info(_ context.Context) (*hostagentapi.Info, error) {
517+
// GuestIP returns the guest IP address, or nil if not accessible by direct IP.
518+
func (a *HostAgent) GuestIP() net.IP {
518519
a.guestIPMu.RLock()
519520
defer a.guestIPMu.RUnlock()
521+
return a.guestIP
522+
}
523+
524+
func (a *HostAgent) Info(_ context.Context) (*hostagentapi.Info, error) {
525+
guestIP := a.GuestIP()
520526
info := &hostagentapi.Info{
521-
GuestIP: a.guestIP,
527+
GuestIP: guestIP,
522528
SSHLocalPort: a.sshLocalPort,
523529
}
524530
return info, nil
525531
}
526532

527533
func (a *HostAgent) sshAddressPort() (sshAddress string, sshPort int) {
528-
a.guestIPMu.RLock()
529-
defer a.guestIPMu.RUnlock()
530534
sshAddress = a.instSSHAddress
531535
sshPort = a.sshLocalPort
532-
if a.guestIP != nil {
533-
sshAddress = a.guestIP.String()
536+
guestIP := a.GuestIP()
537+
if guestIP != nil {
538+
sshAddress = guestIP.String()
534539
sshPort = 22
535540
logrus.Debugf("Using the guest IP address %q directly", sshAddress)
536541
}
@@ -891,9 +896,17 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag
891896
if useSSHFwd {
892897
a.portForwarder.OnEvent(ctx, ev)
893898
} else {
894-
dialContext := a.DialContextToGuestIP()
895-
if dialContext == nil {
896-
dialContext = portfwd.DialContextToGRPCTunnel(client)
899+
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
900+
dialContext := a.DialContextToGuestIP()
901+
if dialContext == nil {
902+
return portfwd.DialContextToGRPCTunnel(client)(ctx, network, address)
903+
}
904+
if host, _, err := net.SplitHostPort(address); err != nil {
905+
return nil, err
906+
} else if ip := net.ParseIP(host); ip.IsUnspecified() || ip.Equal(a.GuestIP()) {
907+
return a.DialContextToGuestIP()(ctx, network, address)
908+
}
909+
return portfwd.DialContextToGRPCTunnel(client)(ctx, network, address)
897910
}
898911
a.grpcPortForwarder.OnEvent(ctx, dialContext, ev)
899912
}
@@ -911,9 +924,8 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag
911924
// DialContextToGuestIP returns a DialContext function that connects to the guest IP directly.
912925
// If the guest IP is not known, it returns nil.
913926
func (a *HostAgent) DialContextToGuestIP() func(ctx context.Context, network, address string) (net.Conn, error) {
914-
a.guestIPMu.RLock()
915-
defer a.guestIPMu.RUnlock()
916-
if a.guestIP == nil {
927+
guestIP := a.GuestIP()
928+
if guestIP == nil {
917929
return nil
918930
}
919931
return func(ctx context.Context, network, address string) (net.Conn, error) {
@@ -924,7 +936,7 @@ func (a *HostAgent) DialContextToGuestIP() func(ctx context.Context, network, ad
924936
}
925937
// Host part of address is ignored, because it already has been checked by forwarding rules
926938
// and we want to connect to the guest IP directly.
927-
return d.DialContext(ctx, network, net.JoinHostPort(a.guestIP.String(), port))
939+
return d.DialContext(ctx, network, net.JoinHostPort(guestIP.String(), port))
928940
}
929941
}
930942

0 commit comments

Comments
 (0)