1
1
package openapi3middleware
2
2
3
3
import (
4
+ "context"
4
5
"encoding/json"
5
6
"errors"
6
7
"fmt"
@@ -9,6 +10,8 @@ import (
9
10
"github.com/getkin/kin-openapi/openapi3"
10
11
"github.com/getkin/kin-openapi/openapi3filter"
11
12
"github.com/getkin/kin-openapi/routers"
13
+ "go.opentelemetry.io/otel"
14
+ "go.opentelemetry.io/otel/trace"
12
15
)
13
16
14
17
type middleware = func (next http.Handler ) http.Handler
@@ -19,6 +22,7 @@ type MiddlewareOptions struct {
19
22
ReportFindRouteError func (w http.ResponseWriter , r * http.Request , err error )
20
23
ReportRequestValidationError func (w http.ResponseWriter , r * http.Request , err error )
21
24
ReportResponseValidationError func (w http.ResponseWriter , r * http.Request , err error )
25
+ TracerProvider trace.TracerProvider
22
26
}
23
27
24
28
func (o MiddlewareOptions ) reportFindRouteError (w http.ResponseWriter , r * http.Request , err error ) {
@@ -60,13 +64,18 @@ func WithResponseValidation(options MiddlewareOptions) middleware {
60
64
return func (next http.Handler ) http.Handler {
61
65
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
62
66
ctx := r .Context ()
67
+ ctx , span := getTracer (ctx , options ).Start (ctx , "ResponseValidation" )
68
+ defer span .End ()
63
69
irw := newBufferingResponseWriter (w )
64
- next .ServeHTTP (irw , r )
70
+ next .ServeHTTP (irw , r . WithContext ( ctx ) )
65
71
ri , err := buildRequestValidationInputFromRequest (options .Router , r , options .ValidationOptions )
66
72
if frErr := new (findRouteErr ); errors .As (err , & frErr ) {
67
- options .reportFindRouteError (w , r , frErr .Unwrap ())
73
+ actualErr := frErr .Unwrap ()
74
+ span .RecordError (actualErr )
75
+ options .reportFindRouteError (w , r , actualErr )
68
76
return
69
77
} else if err != nil {
78
+ span .RecordError (err )
70
79
respondErrorJSON (w , http .StatusInternalServerError , err )
71
80
return
72
81
}
@@ -81,6 +90,7 @@ func WithResponseValidation(options MiddlewareOptions) middleware {
81
90
bodyBytes := irw .buf .Bytes ()
82
91
input .SetBodyBytes (bodyBytes )
83
92
if err := openapi3filter .ValidateResponse (ctx , input ); err != nil {
93
+ span .RecordError (err )
84
94
options .reportRespError (w , r , err )
85
95
return
86
96
}
@@ -94,20 +104,26 @@ func WithResponseValidation(options MiddlewareOptions) middleware {
94
104
func WithRequestValidation (options MiddlewareOptions ) middleware {
95
105
return func (next http.Handler ) http.Handler {
96
106
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
107
+ ctx := r .Context ()
108
+ ctx , span := getTracer (ctx , options ).Start (ctx , "RequestValidation" )
109
+ defer span .End ()
97
110
input , err := buildRequestValidationInputFromRequest (options .Router , r , options .ValidationOptions )
98
111
if frErr := new (findRouteErr ); errors .As (err , & frErr ) {
99
- options .reportFindRouteError (w , r , frErr .Unwrap ())
112
+ actualErr := frErr .Unwrap ()
113
+ span .RecordError (actualErr )
114
+ options .reportFindRouteError (w , r , actualErr )
100
115
return
101
116
} else if err != nil {
117
+ span .RecordError (err )
102
118
respondErrorJSON (w , http .StatusInternalServerError , err )
103
119
return
104
120
}
105
- ctx := r .Context ()
106
121
if err := openapi3filter .ValidateRequest (ctx , input ); err != nil {
122
+ span .RecordError (err )
107
123
options .reportReqError (w , r , err )
108
124
return
109
125
}
110
- next .ServeHTTP (w , r )
126
+ next .ServeHTTP (w , r . WithContext ( ctx ) )
111
127
})
112
128
}
113
129
}
@@ -218,3 +234,17 @@ func respondJSON(w http.ResponseWriter, statusCode int, payload interface{}) err
218
234
w .WriteHeader (statusCode )
219
235
return json .NewEncoder (w ).Encode (payload )
220
236
}
237
+
238
+ const tracerName = "github.com/aereal/go-openapi3-validation-middleware"
239
+
240
+ func getTracer (ctx context.Context , opts MiddlewareOptions ) trace.Tracer {
241
+ tp := opts .TracerProvider
242
+ if tp == nil {
243
+ if span := trace .SpanFromContext (ctx ); span .SpanContext ().IsValid () {
244
+ tp = span .TracerProvider ()
245
+ } else {
246
+ tp = otel .GetTracerProvider ()
247
+ }
248
+ }
249
+ return tp .Tracer (tracerName )
250
+ }
0 commit comments