Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ import (
// wildcards (path variables).
type Handle func(http.ResponseWriter, *http.Request, Params)

// Middleware is a function that accepts a httprouter.Handle function as input
// and returns another httprouter.Handle function. These are used to
// wrap the httprouter.Handler function to implement cross cutting functionalities like
// authentication, logging etc..
//
// httprouter.Router exposes the function Use() to add Middleware functions to it.
type Middleware func(Handle) Handle

// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
Key string
Expand Down Expand Up @@ -141,6 +149,8 @@ type Router struct {
paramsPool sync.Pool
maxParams uint16

middlewares []Middleware

// If enabled, adds the matched route path onto the http.Request context
// before invoking the handler.
// The matched route path is only added to handlers of routes that were
Expand Down Expand Up @@ -302,6 +312,8 @@ func (r *Router) Handle(method, path string, handle Handle) {
panic("handle must not be nil")
}

handle = r.wrapMiddlewaresAroundHandler(handle)

if r.SaveMatchedRoutePath {
varsCount++
handle = r.saveMatchedRoutePath(path, handle)
Expand Down Expand Up @@ -538,3 +550,28 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.NotFound(w, req)
}
}

// Function to add middleware of type httprouter.Middleware to the Router.
// Middlewares are to be added to the router before defining the routes.
//
// The middlewares will then wrap the request handlers and be run in the order
// they were added to the router.
func (r *Router) Use (mw Middleware) {
// Lazy initialization of the middlewares slice
if r.middlewares == nil {
r.middlewares = make([]Middleware, 0, 1)
}
r.middlewares = append(r.middlewares, mw)
}

// This function sandwiches the specified handler between the middlewares.
// The middlewares wrapping will happen in the reverse order of
// the middleware addition so that the execution of the middlewares will
// keep their order.
func (r *Router) wrapMiddlewaresAroundHandler (h Handle) Handle {
middlewareSize := len(r.middlewares)
for i := middlewareSize - 1; i >= 0; i-- {
h = r.middlewares[i](h)
}
return h
}
211 changes: 211 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,214 @@ func TestRouterServeFiles(t *testing.T) {
t.Error("serving file failed")
}
}

func TestNoMiddleware(t *testing.T) {
router := New()
router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) {})

w := new(mockResponseWriter)
req, _ := http.NewRequest(http.MethodGet, "/test", nil)

defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("failed with panic")
}
}()

router.ServeHTTP(w, req)

if router.middlewares != nil {
t.Fatal("middleware array initialized even before adding a middleware")
}
}

func TestMiddlewareWrapsTheRequestHandler(t *testing.T) {
router := New()
middlewareInvoked := false
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
middlewareInvoked = true
h(w, r, p)
}
})
requestHandlerInvoked := false
router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
if !middlewareInvoked {
panic("middleware did not get invoked")
}
requestHandlerInvoked = true
})

w := new(mockResponseWriter)
req, _ := http.NewRequest(http.MethodGet, "/test", nil)

defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("failed with panic")
}
}()

router.ServeHTTP(w, req)

if !middlewareInvoked {
t.Fatal("middleware did not get invoked")
}
if !requestHandlerInvoked {
t.Fatal("request handler did not get invoked")
}
}

func TestMiddlewareAndMultiplePaths(t *testing.T) {
router := New()
middlewareInvokeCount := 0
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
middlewareInvokeCount++
h(w, r, p)
}
})

router.Handle(http.MethodGet, "/test1", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
if middlewareInvokeCount != 1 {
panic("middleware did not get invoked for test1")
}
})
router.Handle(http.MethodGet, "/test2", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
if middlewareInvokeCount != 2 {
panic("middleware did not get invoked for test2")
}
})

defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("failed with panic")
}
}()

w1 := new(mockResponseWriter)
req1, _ := http.NewRequest(http.MethodGet, "/test1", nil)
router.ServeHTTP(w1, req1)

w2 := new(mockResponseWriter)
req2, _ := http.NewRequest(http.MethodGet, "/test2", nil)
router.ServeHTTP(w2, req2)
}

func TestMiddlewareInvokeOrder(t *testing.T) {
router := New()
middlewareInvokeCount := 0
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
middlewareInvokeCount++
if middlewareInvokeCount != 1 {
panic("first middleware was not invoked first")
}
h(w, r, p)
}
})
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
middlewareInvokeCount++
if middlewareInvokeCount != 2 {
panic("second middleware was not invoked second")
}
h(w, r, p)
}
})

router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
if middlewareInvokeCount != 2 {
panic("both middlewares did not get invoked")
}
})

defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("failed with panic")
}
}()

w := new(mockResponseWriter)
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
router.ServeHTTP(w, req)
}

func TestMiddlewareInvokeOrderAfterHandlerDone(t *testing.T) {
router := New()
middlewareInvokeCount := 0
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
h(w, r, p)
middlewareInvokeCount++
if middlewareInvokeCount != 2 {
panic("first middleware was not invoked in reverse")
}
}
})
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
h(w, r, p)
middlewareInvokeCount++
if middlewareInvokeCount != 1 {
panic("second middleware was not invoked in reverse")
}
}
})

router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
})

defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("failed with panic")
}
}()

w := new(mockResponseWriter)
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
router.ServeHTTP(w, req)
}

func TestMiddlewaresWithMatchedRoutePath(t *testing.T) {
router := New()
router.SaveMatchedRoutePath = true
route := "/test/:name"
middlewareInvokeCount := 0
router.Use(func(h Handle) Handle {
return func(w http.ResponseWriter, r *http.Request, p Params) {
matchedRoute := p.MatchedRoutePath()
if route != matchedRoute {
t.Fatalf("Inside middleware: Wrong matched route: want %s, got %s", route, matchedRoute)
}
middlewareInvokeCount++
h(w, r, p)
}
})

handlerInvoked := false
handle := func(_ http.ResponseWriter, req *http.Request, ps Params) {
matchedRoute := ps.MatchedRoutePath()
if route != matchedRoute {
t.Fatalf("Inside handle: Wrong matched route: want %s, got %s", route, matchedRoute)
}
if middlewareInvokeCount != 1 {
t.Fatal("middleware did not get invoked")
}
handlerInvoked = true
}
router.Handle(http.MethodGet, route, handle)

defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("failed with panic")
}
}()

w := new(mockResponseWriter)
req, _ := http.NewRequest(http.MethodGet, "/test/abc", nil)
router.ServeHTTP(w, req)

if !handlerInvoked {
t.Fatal("handler did not get invoked")
}
}