diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index ad81179ca..a6938496f 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -80,7 +80,7 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { klog.ErrorS(err, "no tunnels available") conn.Write([]byte(fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\ncurrently no tunnels available: %v", err))) - conn.Close() + // The hijacked connection will be closed by the closeOnce defer. return } closed := make(chan struct{}) @@ -100,39 +100,60 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { agentID: backend.GetAgentID(), } t.Server.PendingDial.Add(random, connection) + + // This defer acts as a safeguard to ensure we clean up the pending dial + // if the connection is never successfully established. + established := false + defer func() { + if !established { + if t.Server.PendingDial.Remove(random) != nil { + // This metric is observed only when the frontend closes the connection. + // Other failure reasons are observed elsewhere. + metrics.Metrics.ObserveDialFailure(metrics.DialFailureFrontendClose) + } + } + }() + if err := backend.Send(dialRequest); err != nil { klog.ErrorS(err, "failed to tunnel dial request", "host", r.Host, "dialID", connection.dialID, "agentID", connection.agentID) + metrics.Metrics.ObserveDialFailure(metrics.DialFailureBackendClose) // Send proper HTTP error response conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nFailed to tunnel dial request: %v\r\n", err))) - conn.Close() - return - } - ctxt := backend.Context() - if ctxt.Err() != nil { - klog.ErrorS(ctxt.Err(), "context reports failure") - conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nBackend context error: %v\r\n", ctxt.Err()))) - conn.Close() + // The deferred cleanup will run when we return here. return } - select { - case <-ctxt.Done(): - klog.V(5).Infoln("context reports done") - default: - } + ctxt := backend.Context() select { case <-connection.connected: // Waiting for response before we begin full communication. + // The connection is successful. Mark it as established so the deferred + // cleanup function knows not to remove it from PendingDial. + established = true + // Now that connection is established, send 200 OK to switch to tunnel mode _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) if err != nil { klog.ErrorS(err, "failed to send 200 connection established", "host", r.Host, "agentID", connection.agentID) - conn.Close() + // We return here, but since `established` is true, the deferred + // function will not remove the pending dial. The agent-side goroutine + // is responsible for the established connection now. return } klog.V(3).InfoS("Connection established, sent 200 OK", "host", r.Host, "agentID", connection.agentID, "connectionID", connection.connectID) - case <-closed: // Connection was closed before being established + case <-closed: // Connection was closed by the client before being established + klog.V(2).InfoS("Frontend connection closed before being established", "host", r.Host, "dialID", connection.dialID, "agentID", connection.agentID) + // The deferred cleanup will run when we return here. + return + + case <-ctxt.Done(): // Backend connection died before being established + klog.ErrorS(ctxt.Err(), "backend context closed before connection was established", "host", r.Host, "dialID", connection.dialID, "agentID", connection.agentID) + metrics.Metrics.ObserveDialFailure(metrics.DialFailureBackendClose) + // Send proper HTTP error response + conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nBackend context error: %v\r\n", ctxt.Err()))) + // The deferred cleanup will run when we return here. + return } defer func() { @@ -148,7 +169,7 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err = backend.Send(packet); err != nil { klog.V(2).InfoS("failed to send close request packet", "host", r.Host, "agentID", connection.agentID, "connectionID", connection.connectID) } - conn.Close() + // The top-level defer handles conn.Close() }() connID := connection.connectID