diff --git a/gzip.go b/gzip.go index 251afc0..2b747dc 100644 --- a/gzip.go +++ b/gzip.go @@ -80,7 +80,12 @@ func Gziper(options ...Options) macaron.Handler { if err != nil { panic(err.Error()) } - defer gz.Close() + + defer func() { + if _, ok := ctx.Resp.(gzipResponseWriter); ok { + gz.Close() + } + }() gzw := gzipResponseWriter{gz, ctx.Resp} ctx.Resp = gzw @@ -94,8 +99,10 @@ func Gziper(options ...Options) macaron.Handler { ctx.Next() - // delete content length after we know we have been written to - gzw.Header().Del("Content-Length") + if _, ok := ctx.Resp.(gzipResponseWriter); ok { + // delete content length after we know we have been written to + gzw.Header().Del("Content-Length") + } } } @@ -118,3 +125,17 @@ func (grw gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } return hijacker.Hijack() } + +// DisableGzip rolls back all the changes made by gzipResponseWriter, that way the content won't be gzipped +func DisableGzip(ctx *macaron.Context) { + if grw, ok := ctx.Resp.(gzipResponseWriter); ok { + origrw := grw.ResponseWriter + ctx.MapTo(origrw, (*http.ResponseWriter)(nil)) + if _, ok := ctx.Render.(*macaron.DummyRender); !ok { + ctx.Render.SetResponseWriter(origrw) + } + ctx.Resp = origrw + ctx.Resp.Header().Del(_HEADER_CONTENT_ENCODING) + ctx.Resp.Header().Del(_HEADER_VARY) + } +} diff --git a/gzip_test.go b/gzip_test.go index 8acee28..fb5e4fe 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -108,3 +108,36 @@ func Test_ResponseWriter_Hijack(t *testing.T) { So(hijackable.Hijacked, ShouldBeTrue) }) } + +func Test_DisableGzip(t *testing.T) { + Convey("Disable compression for a request with a middleware", t, func() { + m := macaron.New() + m.Use(Gziper()) + data := "aaaaaaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbbb" + m.Get("/compressed", func() string { return data }) + m.Get("/uncompressed", DisableGzip, func() string { return data }) + + // Test compressed + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/compressed", nil) + req.Header.Set(_HEADER_ACCEPT_ENCODING, "gzip") + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + + ce := resp.Header().Get(_HEADER_CONTENT_ENCODING) + So(strings.EqualFold(ce, "gzip"), ShouldBeTrue) + So(strings.EqualFold(resp.Body.String(), data), ShouldBeFalse) + + // Test uncompressed + resp = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/uncompressed", nil) + req.Header.Set(_HEADER_ACCEPT_ENCODING, "gzip") + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + + ce = resp.Header().Get(_HEADER_CONTENT_ENCODING) + So(strings.EqualFold(ce, "gzip"), ShouldBeFalse) + So(strings.EqualFold(resp.Body.String(), data), ShouldBeTrue) + + }) +}