Skip to content

Commit 0fa3f99

Browse files
authored
Merge pull request #20 from jlhawn/copy_response
Copy response to client on failed handshake
2 parents 5fdfb40 + 53b8c5c commit 0fa3f99

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

websocketproxy.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package websocketproxy
33

44
import (
55
"fmt"
6+
"io"
67
"log"
78
"net"
89
"net/http"
@@ -133,7 +134,17 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
133134
// http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01
134135
connBackend, resp, err := dialer.Dial(backendURL.String(), requestHeader)
135136
if err != nil {
136-
log.Printf("websocketproxy: couldn't dial to remote backend url %s\n", err)
137+
log.Printf("websocketproxy: couldn't dial to remote backend url %s", err)
138+
if resp != nil {
139+
// If the WebSocket handshake fails, ErrBadHandshake is returned
140+
// along with a non-nil *http.Response so that callers can handle
141+
// redirects, authentication, etcetera.
142+
if err := copyResponse(rw, resp); err != nil {
143+
log.Printf("websocketproxy: couldn't write response after failed remote backend handshake: %s", err)
144+
}
145+
} else {
146+
http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
147+
}
137148
return
138149
}
139150
defer connBackend.Close()
@@ -156,7 +167,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
156167
// Also pass the header that we gathered from the Dial handshake.
157168
connPub, err := upgrader.Upgrade(rw, req, upgradeHeader)
158169
if err != nil {
159-
log.Printf("websocketproxy: couldn't upgrade %s\n", err)
170+
log.Printf("websocketproxy: couldn't upgrade %s", err)
160171
return
161172
}
162173
defer connPub.Close()
@@ -200,3 +211,20 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
200211
log.Printf(message, err)
201212
}
202213
}
214+
215+
func copyHeader(dst, src http.Header) {
216+
for k, vv := range src {
217+
for _, v := range vv {
218+
dst.Add(k, v)
219+
}
220+
}
221+
}
222+
223+
func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
224+
copyHeader(rw.Header(), resp.Header)
225+
rw.WriteHeader(resp.StatusCode)
226+
defer resp.Body.Close()
227+
228+
_, err := io.Copy(rw, resp.Body)
229+
return err
230+
}

0 commit comments

Comments
 (0)