diff --git a/pkg/cvo/metrics.go b/pkg/cvo/metrics.go index ac2d32fcb..a0f81593c 100644 --- a/pkg/cvo/metrics.go +++ b/pkg/cvo/metrics.go @@ -12,14 +12,19 @@ import ( "net/http" "os" "path/filepath" + "strings" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + authenticationv1 "k8s.io/api/authentication/v1" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/util/sets" + authenticationclientsetv1 "k8s.io/client-go/kubernetes/typed/authentication/v1" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" @@ -127,15 +132,75 @@ type asyncResult struct { error error } -func createHttpServer() *http.Server { +func createHttpServer(ctx context.Context, client *authenticationclientsetv1.AuthenticationV1Client) *http.Server { + auth := authHandler{downstream: promhttp.Handler(), ctx: ctx, client: client.TokenReviews()} handler := http.NewServeMux() - handler.Handle("/metrics", promhttp.Handler()) + handler.Handle("/metrics", &auth) server := &http.Server{ Handler: handler, } return server } +type tokenReviewInterface interface { + Create(ctx context.Context, tokenReview *authenticationv1.TokenReview, opts metav1.CreateOptions) (*authenticationv1.TokenReview, error) +} + +type authHandler struct { + downstream http.Handler + ctx context.Context + client tokenReviewInterface +} + +func (a *authHandler) authorize(token string) (bool, error) { + tr := &authenticationv1.TokenReview{ + Spec: authenticationv1.TokenReviewSpec{ + Token: token, + }, + } + result, err := a.client.Create(a.ctx, tr, metav1.CreateOptions{}) + if err != nil { + return false, fmt.Errorf("failed to check token: %w", err) + } + isAuthenticated := result.Status.Authenticated + isPrometheus := result.Status.User.Username == "system:serviceaccount:openshift-monitoring:prometheus-k8s" + if !isAuthenticated { + klog.V(4).Info("The token cannot be authenticated.") + } else if !isPrometheus { + klog.V(4).Infof("Access the metrics from the unexpected user %s is denied.", result.Status.User.Username) + } + return isAuthenticated && isPrometheus, nil +} + +func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "failed to get the Authorization header", http.StatusUnauthorized) + return + } + token := strings.TrimPrefix(authHeader, "Bearer ") + if token == "" { + http.Error(w, "empty Bearer token", http.StatusUnauthorized) + return + } + if token == authHeader { + http.Error(w, "failed to get the Bearer token", http.StatusUnauthorized) + return + } + + authorized, err := a.authorize(token) + if err != nil { + klog.Warningf("Failed to authorize token: %v", err) + http.Error(w, "failed to authorize due to an internal error", http.StatusInternalServerError) + return + } + if !authorized { + http.Error(w, "failed to authorize", http.StatusUnauthorized) + return + } + a.downstream.ServeHTTP(w, r) +} + func shutdownHttpServer(parentCtx context.Context, svr *http.Server) { ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second) defer cancel() @@ -181,7 +246,7 @@ func handleServerResult(result asyncResult, lastLoopError error) error { // Also detects changes to metrics certificate files upon which // the metrics HTTP server is shutdown and recreated with a new // TLS configuration. -func RunMetrics(runContext context.Context, shutdownContext context.Context, listenAddress, certFile, keyFile string) error { +func RunMetrics(runContext context.Context, shutdownContext context.Context, listenAddress, certFile, keyFile string, restConfig *rest.Config) error { var tlsConfig *tls.Config if listenAddress != "" { var err error @@ -192,7 +257,13 @@ func RunMetrics(runContext context.Context, shutdownContext context.Context, lis } else { return errors.New("TLS configuration is required to serve metrics") } - server := createHttpServer() + + client, err := authenticationclientsetv1.NewForConfig(restConfig) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + + server := createHttpServer(runContext, client) resultChannel := make(chan asyncResult, 1) resultChannelCount := 1 @@ -246,7 +317,7 @@ func RunMetrics(runContext context.Context, shutdownContext context.Context, lis case result := <-resultChannel: // crashed before a shutdown was requested or metrics server recreated if restartServer { klog.Info("Creating metrics server with updated TLS configuration.") - server = createHttpServer() + server = createHttpServer(runContext, client) go startListening(server, tlsConfig, listenAddress, resultChannel) restartServer = false continue diff --git a/pkg/cvo/metrics_test.go b/pkg/cvo/metrics_test.go index e813e3c02..085b92bdd 100644 --- a/pkg/cvo/metrics_test.go +++ b/pkg/cvo/metrics_test.go @@ -1,15 +1,22 @@ package cvo import ( + "context" "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" "sort" "strings" "testing" "time" "github.com/davecgh/go-spew/spew" + "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" + authenticationv1 "k8s.io/api/authentication/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/tools/record" @@ -1009,3 +1016,154 @@ func metricParts(t *testing.T, metric prometheus.Metric, labels ...string) strin } return strings.Join(parts, " ") } + +type fakeClient struct { +} + +func (c *fakeClient) Create(_ context.Context, tokenReview *authenticationv1.TokenReview, _ metav1.CreateOptions) (*authenticationv1.TokenReview, error) { + if tokenReview != nil { + ret := tokenReview.DeepCopy() + if tokenReview.Spec.Token == "good" { + ret.Status.Authenticated = true + ret.Status.User.Username = "system:serviceaccount:openshift-monitoring:prometheus-k8s" + } + if tokenReview.Spec.Token == "authenticated" { + ret.Status.Authenticated = true + } + if tokenReview.Spec.Token == "error" { + return nil, errors.New("fake error") + } + return ret, nil + } + return nil, errors.New("nil input") +} + +type okHandler struct { +} + +func (h *okHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, "ok") +} + +func Test_authHandler(t *testing.T) { + tests := []struct { + name string + handler *authHandler + method string + body io.Reader + headerKey string + headerValue string + expectedStatusCode int + expectedBody string + }{ + { + name: "good", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + headerKey: "Authorization", + headerValue: "Bearer good", + expectedStatusCode: http.StatusOK, + expectedBody: "ok", + }, + { + name: "empty bearer token", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + headerKey: "Authorization", + headerValue: "Bearer ", + expectedStatusCode: 401, + expectedBody: "empty Bearer token\n", + }, + { + name: "authenticated", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + headerKey: "Authorization", + headerValue: "Bearer authenticated", + expectedStatusCode: 401, + expectedBody: "failed to authorize\n", + }, + { + name: "bad", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + headerKey: "Authorization", + headerValue: "Bearer bad", + expectedStatusCode: 401, + expectedBody: "failed to authorize\n", + }, + { + name: "failed to get the Authorization header", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + expectedStatusCode: 401, + expectedBody: "failed to get the Authorization header\n", + }, + { + name: "failed to get the Bearer token", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + headerKey: "Authorization", + headerValue: "xxx bad", + expectedStatusCode: 401, + expectedBody: "failed to get the Bearer token\n", + }, + { + name: "error", + handler: &authHandler{ + ctx: context.TODO(), + downstream: &okHandler{}, + client: &fakeClient{}, + }, + method: "GET", + headerKey: "Authorization", + headerValue: "Bearer error", + expectedStatusCode: 500, + expectedBody: "failed to authorize due to an internal error\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + + req, err := http.NewRequest(tt.method, "url-not-important", tt.body) + if err != nil { + t.Fatal(err) + } + req.Header.Set(tt.headerKey, tt.headerValue) + + tt.handler.ServeHTTP(rr, req) + if diff := cmp.Diff(tt.expectedStatusCode, rr.Code); diff != "" { + t.Errorf("%s: status differs from expected:\n%s", tt.name, diff) + } + + if diff := cmp.Diff(tt.expectedBody, rr.Body.String()); diff != "" { + t.Errorf("%s: body differs from expected:\n%s", tt.name, diff) + } + }) + } +} diff --git a/pkg/start/start.go b/pkg/start/start.go index 958b827b4..bf38f68fa 100644 --- a/pkg/start/start.go +++ b/pkg/start/start.go @@ -357,7 +357,7 @@ func (o *Options) run(ctx context.Context, controllerCtx *Context, lock resource resultChannelCount++ go func() { defer utilruntime.HandleCrash() - err := cvo.RunMetrics(postMainContext, shutdownContext, o.ListenAddr, o.ServingCertFile, o.ServingKeyFile) + err := cvo.RunMetrics(postMainContext, shutdownContext, o.ListenAddr, o.ServingCertFile, o.ServingKeyFile, restConfig) resultChannel <- asyncResult{name: "metrics server", error: err} }() }