From 4f2bdc6b43755402c3364e895b64acaa7f7849a1 Mon Sep 17 00:00:00 2001 From: Ajilal Date: Sat, 5 Nov 2022 22:56:14 +0530 Subject: [PATCH 1/2] Middleware implementation part-1: functions created --- router.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/router.go b/router.go index 1eab403d..72776da7 100644 --- a/router.go +++ b/router.go @@ -88,6 +88,14 @@ import ( // wildcards (path variables). type Handle func(http.ResponseWriter, *http.Request, Params) +// Middleware is a function that accepts a httprouter.Handle function as input +// and returns another httprouter.Handle function. These are used to +// wrap the httprouter.Handler function to implement cross cutting functionalities like +// authentication, logging etc.. +// +// httprouter.Router exposes the function Use() to add Middleware functions to it. +type Middleware func(Handle) Handle + // Param is a single URL parameter, consisting of a key and a value. type Param struct { Key string @@ -141,6 +149,12 @@ type Router struct { paramsPool sync.Pool maxParams uint16 + // Specifies the maximum number of middlewares. Setting this value will + // give better performance while adding the middlewares to the Router. + // Default value 0 + MaxMiddlewares uint8 + middlewares []Middleware + // If enabled, adds the matched route path onto the http.Request context // before invoking the handler. // The matched route path is only added to handlers of routes that were @@ -302,6 +316,8 @@ func (r *Router) Handle(method, path string, handle Handle) { panic("handle must not be nil") } + handle = r.wrapMiddlewaresAroundHandler(handle) + if r.SaveMatchedRoutePath { varsCount++ handle = r.saveMatchedRoutePath(path, handle) @@ -538,3 +554,31 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.NotFound(w, req) } } + +// Function to add middlewares of type httprouter.Handle to the Router. +// Middlewares can be added before and after definging the routes. +// This function is not concurrency safe. +// +// The execution of the handlers will be in the following order. +// +// Middlewares added to Router before the route handler -> +// Route handler -> Middlewares added to the Router after the route handler +func (r *Router) Use(mw Middleware) { + // Lazy initialization of the middleware handles slice + if r.middlewares == nil { + r.middlewares = make([]Middleware, 0, r.MaxMiddlewares) + } + r.middlewares = append(r.middlewares, mw) +} + +// This function sandwiches the specified handler between the middlewares +// handlers. The middlewares wrapping will happen in the reverse order of +// the middleware addition so that the execution of the middlewares will +// keep their order. +func (r *Router) wrapMiddlewaresAroundHandler (h Handle) Handle { + middlewareSize := len(r.middlewares) + for i := middlewareSize - 1; i >= 0; i-- { + h = r.middlewares[i](h) + } + return h +} From 17c5ffde4637b66666427fa46e1cc91c4b0066fb Mon Sep 17 00:00:00 2001 From: Ajilal Date: Sun, 6 Nov 2022 18:20:09 +0530 Subject: [PATCH 2/2] Middleware : added test cases and removed MaxMiddlewares field from Router --- router.go | 27 +++---- router_test.go | 211 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 17 deletions(-) diff --git a/router.go b/router.go index 72776da7..513e6901 100644 --- a/router.go +++ b/router.go @@ -149,11 +149,7 @@ type Router struct { paramsPool sync.Pool maxParams uint16 - // Specifies the maximum number of middlewares. Setting this value will - // give better performance while adding the middlewares to the Router. - // Default value 0 - MaxMiddlewares uint8 - middlewares []Middleware + middlewares []Middleware // If enabled, adds the matched route path onto the http.Request context // before invoking the handler. @@ -555,24 +551,21 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } -// Function to add middlewares of type httprouter.Handle to the Router. -// Middlewares can be added before and after definging the routes. -// This function is not concurrency safe. +// Function to add middleware of type httprouter.Middleware to the Router. +// Middlewares are to be added to the router before defining the routes. // -// The execution of the handlers will be in the following order. -// -// Middlewares added to Router before the route handler -> -// Route handler -> Middlewares added to the Router after the route handler -func (r *Router) Use(mw Middleware) { - // Lazy initialization of the middleware handles slice +// The middlewares will then wrap the request handlers and be run in the order +// they were added to the router. +func (r *Router) Use (mw Middleware) { + // Lazy initialization of the middlewares slice if r.middlewares == nil { - r.middlewares = make([]Middleware, 0, r.MaxMiddlewares) + r.middlewares = make([]Middleware, 0, 1) } r.middlewares = append(r.middlewares, mw) } -// This function sandwiches the specified handler between the middlewares -// handlers. The middlewares wrapping will happen in the reverse order of +// This function sandwiches the specified handler between the middlewares. +// The middlewares wrapping will happen in the reverse order of // the middleware addition so that the execution of the middlewares will // keep their order. func (r *Router) wrapMiddlewaresAroundHandler (h Handle) Handle { diff --git a/router_test.go b/router_test.go index ae7d2435..e59653c5 100644 --- a/router_test.go +++ b/router_test.go @@ -697,3 +697,214 @@ func TestRouterServeFiles(t *testing.T) { t.Error("serving file failed") } } + +func TestNoMiddleware(t *testing.T) { + router := New() + router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) {}) + + w := new(mockResponseWriter) + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + + defer func() { + if rcv := recover(); rcv != nil { + t.Fatal("failed with panic") + } + }() + + router.ServeHTTP(w, req) + + if router.middlewares != nil { + t.Fatal("middleware array initialized even before adding a middleware") + } +} + +func TestMiddlewareWrapsTheRequestHandler(t *testing.T) { + router := New() + middlewareInvoked := false + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + middlewareInvoked = true + h(w, r, p) + } + }) + requestHandlerInvoked := false + router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) { + if !middlewareInvoked { + panic("middleware did not get invoked") + } + requestHandlerInvoked = true + }) + + w := new(mockResponseWriter) + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + + defer func() { + if rcv := recover(); rcv != nil { + t.Fatal("failed with panic") + } + }() + + router.ServeHTTP(w, req) + + if !middlewareInvoked { + t.Fatal("middleware did not get invoked") + } + if !requestHandlerInvoked { + t.Fatal("request handler did not get invoked") + } +} + +func TestMiddlewareAndMultiplePaths(t *testing.T) { + router := New() + middlewareInvokeCount := 0 + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + middlewareInvokeCount++ + h(w, r, p) + } + }) + + router.Handle(http.MethodGet, "/test1", func(_ http.ResponseWriter, _ *http.Request, _ Params) { + if middlewareInvokeCount != 1 { + panic("middleware did not get invoked for test1") + } + }) + router.Handle(http.MethodGet, "/test2", func(_ http.ResponseWriter, _ *http.Request, _ Params) { + if middlewareInvokeCount != 2 { + panic("middleware did not get invoked for test2") + } + }) + + defer func() { + if rcv := recover(); rcv != nil { + t.Fatal("failed with panic") + } + }() + + w1 := new(mockResponseWriter) + req1, _ := http.NewRequest(http.MethodGet, "/test1", nil) + router.ServeHTTP(w1, req1) + + w2 := new(mockResponseWriter) + req2, _ := http.NewRequest(http.MethodGet, "/test2", nil) + router.ServeHTTP(w2, req2) +} + +func TestMiddlewareInvokeOrder(t *testing.T) { + router := New() + middlewareInvokeCount := 0 + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + middlewareInvokeCount++ + if middlewareInvokeCount != 1 { + panic("first middleware was not invoked first") + } + h(w, r, p) + } + }) + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + middlewareInvokeCount++ + if middlewareInvokeCount != 2 { + panic("second middleware was not invoked second") + } + h(w, r, p) + } + }) + + router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) { + if middlewareInvokeCount != 2 { + panic("both middlewares did not get invoked") + } + }) + + defer func() { + if rcv := recover(); rcv != nil { + t.Fatal("failed with panic") + } + }() + + w := new(mockResponseWriter) + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + router.ServeHTTP(w, req) +} + +func TestMiddlewareInvokeOrderAfterHandlerDone(t *testing.T) { + router := New() + middlewareInvokeCount := 0 + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + h(w, r, p) + middlewareInvokeCount++ + if middlewareInvokeCount != 2 { + panic("first middleware was not invoked in reverse") + } + } + }) + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + h(w, r, p) + middlewareInvokeCount++ + if middlewareInvokeCount != 1 { + panic("second middleware was not invoked in reverse") + } + } + }) + + router.Handle(http.MethodGet, "/test", func(_ http.ResponseWriter, _ *http.Request, _ Params) { + }) + + defer func() { + if rcv := recover(); rcv != nil { + t.Fatal("failed with panic") + } + }() + + w := new(mockResponseWriter) + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + router.ServeHTTP(w, req) +} + +func TestMiddlewaresWithMatchedRoutePath(t *testing.T) { + router := New() + router.SaveMatchedRoutePath = true + route := "/test/:name" + middlewareInvokeCount := 0 + router.Use(func(h Handle) Handle { + return func(w http.ResponseWriter, r *http.Request, p Params) { + matchedRoute := p.MatchedRoutePath() + if route != matchedRoute { + t.Fatalf("Inside middleware: Wrong matched route: want %s, got %s", route, matchedRoute) + } + middlewareInvokeCount++ + h(w, r, p) + } + }) + + handlerInvoked := false + handle := func(_ http.ResponseWriter, req *http.Request, ps Params) { + matchedRoute := ps.MatchedRoutePath() + if route != matchedRoute { + t.Fatalf("Inside handle: Wrong matched route: want %s, got %s", route, matchedRoute) + } + if middlewareInvokeCount != 1 { + t.Fatal("middleware did not get invoked") + } + handlerInvoked = true + } + router.Handle(http.MethodGet, route, handle) + + defer func() { + if rcv := recover(); rcv != nil { + t.Fatal("failed with panic") + } + }() + + w := new(mockResponseWriter) + req, _ := http.NewRequest(http.MethodGet, "/test/abc", nil) + router.ServeHTTP(w, req) + + if !handlerInvoked { + t.Fatal("handler did not get invoked") + } +}