diff --git a/handlers.go b/handlers.go index ba5ae7a..3d31561 100644 --- a/handlers.go +++ b/handlers.go @@ -2,7 +2,7 @@ package zenrpc import ( "encoding/json" - "io/ioutil" + "io" "net/http" "strings" "time" @@ -16,7 +16,7 @@ type Printer interface { // ServeHTTP process JSON-RPC 2.0 requests via HTTP. // http://www.simple-is-better.org/json-rpc/transport_http.html -func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // check for CORS GET & POST requests if s.options.AllowCORS { w.Header().Set("Access-Control-Allow-Origin", "*") @@ -54,14 +54,14 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // ok, method is POST and content-type is application/json, process body - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) var data interface{} if err != nil { s.printf("read request body failed with err=%v", err) data = NewResponseError(nil, ParseError, "", nil) } else { - data = s.process(newRequestContext(r.Context(), r), b) + data = s.process(newRequestResponseContext(r.Context(), r, w), b) } // if responses is empty -> all requests are notifications -> exit immediately @@ -86,7 +86,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // ServeWS processes JSON-RPC 2.0 requests via Gorilla WebSocket. // https://github.com/gorilla/websocket/blob/master/examples/echo/ -func (s Server) ServeWS(w http.ResponseWriter, r *http.Request) { +func (s *Server) ServeWS(w http.ResponseWriter, r *http.Request) { c, err := s.options.Upgrader.Upgrade(w, r, nil) if err != nil { s.printf("upgrade connection failed with err=%v", err) @@ -107,7 +107,7 @@ func (s Server) ServeWS(w http.ResponseWriter, r *http.Request) { break } - data, err := s.Do(newRequestContext(r.Context(), r), message) + data, err := s.Do(newRequestResponseContext(r.Context(), r, w), message) if err != nil { s.printf("marshal json response failed with err=%v", err) c.WriteControl(websocket.CloseInternalServerErr, nil, time.Time{}) diff --git a/server.go b/server.go index c199773..555cdae 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ import ( "unicode" "github.com/gorilla/websocket" + "github.com/semrush/zenrpc/v2/smd" ) @@ -26,6 +27,9 @@ const ( // context key for http.Request object. requestKey contextKey = "request" + // context key for http.ResponseWriter implementation. + responseWriterKey contextKey = "responseWriter" + // context key for namespace. namespaceKey contextKey = "namespace" @@ -165,7 +169,7 @@ func (s *Server) process(ctx context.Context, message json.RawMessage) interface } // processBatch process batch requests with context. -func (s Server) processBatch(ctx context.Context, requests []Request) []Response { +func (s *Server) processBatch(ctx context.Context, requests []Request) []Response { reqLen := len(requests) // running requests in batch asynchronously @@ -206,7 +210,7 @@ func (s Server) processBatch(ctx context.Context, requests []Request) []Response } // processRequest processes a single request in service invoker. -func (s Server) processRequest(ctx context.Context, req Request) Response { +func (s *Server) processRequest(ctx context.Context, req Request) Response { // checks for json-rpc version and method if req.Version != Version || req.Method == "" { return NewResponseError(req.ID, InvalidRequest, "", nil) @@ -248,18 +252,18 @@ func (s Server) processRequest(ctx context.Context, req Request) Response { } // Do process JSON-RPC 2.0 request, invokes correct method for namespace and returns JSON-RPC 2.0 Response or marshaller error. -func (s Server) Do(ctx context.Context, req []byte) ([]byte, error) { +func (s *Server) Do(ctx context.Context, req []byte) ([]byte, error) { return json.Marshal(s.process(ctx, req)) } -func (s Server) printf(format string, v ...interface{}) { +func (s *Server) printf(format string, v ...interface{}) { if s.logger != nil { s.logger.Printf(format, v...) } } // SMD returns Service Mapping Description object with all registered methods. -func (s Server) SMD() smd.Schema { +func (s *Server) SMD() smd.Schema { sch := smd.Schema{ Transport: "POST", Envelope: "JSON-RPC-2.0", @@ -346,9 +350,9 @@ func ConvertToObject(keys []string, params json.RawMessage) (json.RawMessage, er return buf.Bytes(), nil } -// newRequestContext creates new context with http.Request. -func newRequestContext(ctx context.Context, req *http.Request) context.Context { - return context.WithValue(ctx, requestKey, req) +// newRequestResponseContext creates new context with http.Request and http.ResponseWriter. +func newRequestResponseContext(ctx context.Context, req *http.Request, resp http.ResponseWriter) context.Context { + return context.WithValue(context.WithValue(ctx, responseWriterKey, resp), requestKey, req) } // RequestFromContext returns http.Request from context. @@ -357,6 +361,12 @@ func RequestFromContext(ctx context.Context) (*http.Request, bool) { return r, ok } +// ResponseHeadersFromContext returns headers map to be sent with HTTP response of passed context. +func ResponseHeadersFromContext(ctx context.Context) (http.Header, bool) { + r, ok := ctx.Value(responseWriterKey).(http.ResponseWriter) + return r.Header(), ok +} + // newNamespaceContext creates new context with current method namespace. func newNamespaceContext(ctx context.Context, namespace string) context.Context { return context.WithValue(ctx, namespaceKey, namespace)