diff --git a/docs/reference/flagd-cli/flagd_start.md b/docs/reference/flagd-cli/flagd_start.md index 0a319a4a1..1a853f956 100644 --- a/docs/reference/flagd-cli/flagd_start.md +++ b/docs/reference/flagd-cli/flagd_start.md @@ -14,6 +14,7 @@ flagd start [flags] -H, --context-from-header stringToString add key-value pairs to map header values to context values, where key is Header name, value is context key (default []) -X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default []) -C, --cors-origin strings CORS allowed origins, * will allow all origins + --disable-sync-metadata Disables the getMetadata endpoint of the sync service. Defaults to false, but will default to true in later versions. -h, --help help for start -z, --log-format string Set the logging format, e.g. console or json (default "console") -m, --management-port int32 Port for management operations (default 8014) diff --git a/flagd/cmd/start.go b/flagd/cmd/start.go index 5997df752..83745dd5a 100644 --- a/flagd/cmd/start.go +++ b/flagd/cmd/start.go @@ -36,6 +36,7 @@ const ( syncPortFlagName = "sync-port" syncSocketPathFlagName = "sync-socket-path" uriFlagName = "uri" + disableSyncMetadata = "disable-sync-metadata" contextValueFlagName = "context-value" headerToContextKeyFlagName = "context-from-header" streamDeadlineFlagName = "stream-deadline" @@ -89,6 +90,7 @@ func init() { flags.StringToStringP(headerToContextKeyFlagName, "H", map[string]string{}, "add key-value pairs to map "+ "header values to context values, where key is Header name, value is context key") flags.Duration(streamDeadlineFlagName, 0, "Set a server-side deadline for flagd sync and event streams (default 0, means no deadline).") + flags.Bool(disableSyncMetadata, false, "Disables the getMetadata endpoint of the sync service. Defaults to false, but will default to true in later versions.") bindFlags(flags) } @@ -114,6 +116,7 @@ func bindFlags(flags *pflag.FlagSet) { _ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName)) _ = viper.BindPFlag(headerToContextKeyFlagName, flags.Lookup(headerToContextKeyFlagName)) _ = viper.BindPFlag(streamDeadlineFlagName, flags.Lookup(streamDeadlineFlagName)) + _ = viper.BindPFlag(disableSyncMetadata, flags.Lookup(disableSyncMetadata)) } // startCmd represents the start command @@ -186,6 +189,7 @@ var startCmd = &cobra.Command{ SyncServicePort: viper.GetUint16(syncPortFlagName), SyncServiceSocketPath: viper.GetString(syncSocketPathFlagName), StreamDeadline: viper.GetDuration(streamDeadlineFlagName), + DisableSyncMetadata: viper.GetBool(disableSyncMetadata), SyncProviders: syncProviders, ContextValues: contextValuesToMap, HeaderToContextKeyMappings: headerToContextKeyMappings, diff --git a/flagd/pkg/runtime/from_config.go b/flagd/pkg/runtime/from_config.go index 3e7574429..4a59f1941 100644 --- a/flagd/pkg/runtime/from_config.go +++ b/flagd/pkg/runtime/from_config.go @@ -39,6 +39,7 @@ type Config struct { SyncServicePort uint16 SyncServiceSocketPath string StreamDeadline time.Duration + DisableSyncMetadata bool SyncProviders []sync.SourceConfig CORS []string @@ -116,15 +117,16 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime, // flag sync service flagSyncService, err := flagsync.NewSyncService(flagsync.SvcConfigurations{ - Logger: logger.WithFields(zap.String("component", "FlagSyncService")), - Port: config.SyncServicePort, - Sources: sources, - Store: s, - ContextValues: config.ContextValues, - KeyPath: config.ServiceKeyPath, - CertPath: config.ServiceCertPath, - SocketPath: config.SyncServiceSocketPath, - StreamDeadline: config.StreamDeadline, + Logger: logger.WithFields(zap.String("component", "FlagSyncService")), + Port: config.SyncServicePort, + Sources: sources, + Store: s, + ContextValues: config.ContextValues, + KeyPath: config.ServiceKeyPath, + CertPath: config.ServiceCertPath, + SocketPath: config.SyncServiceSocketPath, + StreamDeadline: config.StreamDeadline, + DisableSyncMetadata: config.DisableSyncMetadata, }) if err != nil { return nil, fmt.Errorf("error creating sync service: %w", err) diff --git a/flagd/pkg/service/flag-sync/handler.go b/flagd/pkg/service/flag-sync/handler.go index 2795cf16f..a8fd01aa2 100644 --- a/flagd/pkg/service/flag-sync/handler.go +++ b/flagd/pkg/service/flag-sync/handler.go @@ -4,9 +4,9 @@ import ( "context" "errors" "fmt" - "maps" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "maps" "time" "buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc" @@ -17,10 +17,11 @@ import ( // syncHandler implements the sync contract type syncHandler struct { - mux *Multiplexer - log *logger.Logger - contextValues map[string]any - deadline time.Duration + mux *Multiplexer + log *logger.Logger + contextValues map[string]any + deadline time.Duration + disableSyncMetadata bool } func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error { @@ -44,7 +45,6 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F for { select { case payload := <-muxPayload: - metadataSrc := make(map[string]any) maps.Copy(metadataSrc, s.contextValues) @@ -58,10 +58,7 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F return fmt.Errorf("error constructing metadata response") } - err = server.Send(&syncv1.SyncFlagsResponse{ - FlagConfiguration: payload.flags, - SyncContext: metadata, - }) + err = server.Send(&syncv1.SyncFlagsResponse{FlagConfiguration: payload.flags, SyncContext: metadata}) if err != nil { s.log.Debug(fmt.Sprintf("error sending stream response: %v", err)) return fmt.Errorf("error sending stream response: %w", err) @@ -97,6 +94,9 @@ func (s syncHandler) FetchAllFlags(_ context.Context, req *syncv1.FetchAllFlagsR func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) ( *syncv1.GetMetadataResponse, error, ) { + if s.disableSyncMetadata { + return nil, status.Error(codes.Unimplemented, "metadata endpoint disabled") + } metadataSrc := make(map[string]any) for k, v := range s.contextValues { metadataSrc[k] = v diff --git a/flagd/pkg/service/flag-sync/handler_test.go b/flagd/pkg/service/flag-sync/handler_test.go index d3a6cf68e..c6a8b0f8a 100644 --- a/flagd/pkg/service/flag-sync/handler_test.go +++ b/flagd/pkg/service/flag-sync/handler_test.go @@ -54,54 +54,55 @@ func TestSyncHandler_SyncFlags(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Shared handler for testing both GetMetadata & SyncFlags methods - flagStore := store.NewFlags() - mp, err := NewMux(flagStore, tt.sources) - require.NoError(t, err) - - handler := syncHandler{ - mux: mp, - contextValues: tt.contextValues, - log: logger.NewLogger(nil, false), - } - - // Test getting metadata from `GetMetadata` (deprecated) - // remove when `GetMetadata` is full removed and deprecated - metaResp, err := handler.GetMetadata(context.Background(), &syncv1.GetMetadataRequest{}) - require.NoError(t, err) - respMetadata := metaResp.GetMetadata().AsMap() - assert.Equal(t, tt.wantMetadata, respMetadata) - - // Test metadata from sync_context - stream := &mockSyncFlagsServer{ - ctx: context.Background(), - mu: sync.Mutex{}, - respReady: make(chan struct{}, 1), - } - - go func() { - err := handler.SyncFlags(&syncv1.SyncFlagsRequest{}, stream) - assert.NoError(t, err) - }() - - select { - case <-stream.respReady: - syncResp := stream.GetLastResponse() - assert.NotNil(t, syncResp) - - syncMetadata := syncResp.GetSyncContext().AsMap() - assert.Equal(t, tt.wantMetadata, syncMetadata) - - // Check the two metadatas are equal + for _, disableSyncMetadata := range []bool{true, false} { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Shared handler for testing both GetMetadata & SyncFlags methods + flagStore := store.NewFlags() + mp, err := NewMux(flagStore, tt.sources) + require.NoError(t, err) + + handler := syncHandler{ + mux: mp, + contextValues: tt.contextValues, + log: logger.NewLogger(nil, false), + disableSyncMetadata: disableSyncMetadata, + } + + // Test getting metadata from `GetMetadata` (deprecated) // remove when `GetMetadata` is full removed and deprecated - assert.Equal(t, respMetadata, syncMetadata) - case <-time.After(time.Second): - t.Fatal("timeout waiting for response") - } - - }) + metaResp, err := handler.GetMetadata(context.Background(), &syncv1.GetMetadataRequest{}) + if !disableSyncMetadata { + require.NoError(t, err) + respMetadata := metaResp.GetMetadata().AsMap() + assert.Equal(t, tt.wantMetadata, respMetadata) + } else { + assert.NotNil(t, err) + } + + // Test metadata from sync_context + stream := &mockSyncFlagsServer{ + ctx: context.Background(), + mu: sync.Mutex{}, + respReady: make(chan struct{}, 1), + } + + go func() { + err := handler.SyncFlags(&syncv1.SyncFlagsRequest{}, stream) + assert.NoError(t, err) + }() + + select { + case <-stream.respReady: + syncResp := stream.GetLastResponse() + assert.NotNil(t, syncResp) + syncMetadata := syncResp.GetSyncContext().AsMap() + assert.Equal(t, tt.wantMetadata, syncMetadata) + case <-time.After(time.Second): + t.Fatal("timeout waiting for response") + } + }) + } } } diff --git a/flagd/pkg/service/flag-sync/sync_service.go b/flagd/pkg/service/flag-sync/sync_service.go index df9548ceb..09977d86c 100644 --- a/flagd/pkg/service/flag-sync/sync_service.go +++ b/flagd/pkg/service/flag-sync/sync_service.go @@ -25,15 +25,16 @@ type ISyncService interface { } type SvcConfigurations struct { - Logger *logger.Logger - Port uint16 - Sources []string - Store *store.State - ContextValues map[string]any - CertPath string - KeyPath string - SocketPath string - StreamDeadline time.Duration + Logger *logger.Logger + Port uint16 + Sources []string + Store *store.State + ContextValues map[string]any + CertPath string + KeyPath string + SocketPath string + StreamDeadline time.Duration + DisableSyncMetadata bool } type Service struct { @@ -82,10 +83,11 @@ func NewSyncService(cfg SvcConfigurations) (*Service, error) { } syncv1grpc.RegisterFlagSyncServiceServer(server, &syncHandler{ - mux: mux, - log: l, - contextValues: cfg.ContextValues, - deadline: cfg.StreamDeadline, + mux: mux, + log: l, + contextValues: cfg.ContextValues, + deadline: cfg.StreamDeadline, + disableSyncMetadata: cfg.DisableSyncMetadata, }) var lis net.Listener diff --git a/flagd/pkg/service/flag-sync/sync_service_test.go b/flagd/pkg/service/flag-sync/sync_service_test.go index e9fe9e353..25eab86fb 100644 --- a/flagd/pkg/service/flag-sync/sync_service_test.go +++ b/flagd/pkg/service/flag-sync/sync_service_test.go @@ -36,143 +36,156 @@ func TestSyncServiceEndToEnd(t *testing.T) { {title: "with unix socket connection", certPath: "", keyPath: "", clientCertPath: "", socketPath: "/tmp/flagd", tls: false, wantStartErr: false}, } - for _, tc := range testCases { - t.Run(fmt.Sprintf("Testing Sync Service %s", tc.title), func(t *testing.T) { - // given - port := 18016 - flagStore, sources := getSimpleFlagStore() - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - service, doneChan, err := createAndStartSyncService(port, sources, flagStore, tc.certPath, tc.keyPath, tc.socketPath, ctx, 0) - - if tc.wantStartErr { - if err == nil { - t.Fatal("expected error creating the service!") + for _, disableSyncMetadata := range []bool{true, false} { + for _, tc := range testCases { + t.Run(fmt.Sprintf("Testing Sync Service %s", tc.title), func(t *testing.T) { + // given + port := 18016 + flagStore, sources := getSimpleFlagStore() + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + service, doneChan, err := createAndStartSyncService( + port, + sources, + flagStore, + tc.certPath, + tc.keyPath, + tc.socketPath, + ctx, + 0, + disableSyncMetadata, + ) + + if tc.wantStartErr { + if err == nil { + t.Fatal("expected error creating the service!") + } + return + } else if err != nil { + t.Fatal("unexpected error creating the service: %w", err) + return } - return - } else if err != nil { - t.Fatal("unexpected error creating the service: %w", err) - return - } - // when - derive a client for sync service - serviceClient := getSyncClient(t, tc.clientCertPath, tc.socketPath, tc.tls, port, ctx) - - // then + // when - derive a client for sync service + serviceClient := getSyncClient(t, tc.clientCertPath, tc.socketPath, tc.tls, port, ctx) - // sync flags request - flags, err := serviceClient.SyncFlags(ctx, &v1.SyncFlagsRequest{}) - if err != nil { - t.Fatal(fmt.Printf("error from sync request: %v", err)) - return - } + // then - syncRsp, err := flags.Recv() - if err != nil { - t.Fatal(fmt.Printf("stream error: %v", err)) - return - } + // sync flags request + flags, err := serviceClient.SyncFlags(ctx, &v1.SyncFlagsRequest{}) + if err != nil { + t.Fatal(fmt.Printf("error from sync request: %v", err)) + return + } - if len(syncRsp.GetFlagConfiguration()) == 0 { - t.Error("expected non empty sync response, but got empty") - } + syncRsp, err := flags.Recv() + if err != nil { + t.Fatal(fmt.Printf("stream error: %v", err)) + return + } - // checks sync context actually set - syncContext := syncRsp.GetSyncContext() - if syncContext == nil { - t.Fatal("expected sync_context in SyncFlagsResponse, but got nil") - } + if len(syncRsp.GetFlagConfiguration()) == 0 { + t.Error("expected non empty sync response, but got empty") + } - syncAsMap := syncContext.AsMap() - if syncAsMap["sources"] == nil { - t.Fatalf("expected sources in sync_context, but got nil") - } + // checks sync context actually set + syncContext := syncRsp.GetSyncContext() + if syncContext == nil { + t.Fatal("expected sync_context in SyncFlagsResponse, but got nil") + } - sourcesStr := syncAsMap["sources"].(string) - sourcesArray := strings.Split(sourcesStr, ",") - sort.Strings(sourcesArray) + syncAsMap := syncContext.AsMap() + if syncAsMap["sources"] == nil { + t.Fatalf("expected sources in sync_context, but got nil") + } - expectedSources := []string{"A", "B", "C"} - if !reflect.DeepEqual(sourcesArray, expectedSources) { - t.Fatalf("sources entry in sync_context does not match expected: got %v, want %v", sourcesArray, expectedSources) - } + sourcesStr := syncAsMap["sources"].(string) + sourcesArray := strings.Split(sourcesStr, ",") + sort.Strings(sourcesArray) - // validate emits - dataReceived := make(chan interface{}) - go func() { - _, err := flags.Recv() - if err != nil { - return + expectedSources := []string{"A", "B", "C"} + if !reflect.DeepEqual(sourcesArray, expectedSources) { + t.Fatalf("sources entry in sync_context does not match expected: got %v, want %v", sourcesArray, expectedSources) } - dataReceived <- nil - }() - - // Emit as a resync - service.Emit(true, "A") + // validate emits + dataReceived := make(chan interface{}) + go func() { + _, err := flags.Recv() + if err != nil { + return + } - select { - case <-dataReceived: - t.Fatal("expected no data as this is a resync") - case <-time.After(1 * time.Second): - break - } + dataReceived <- nil + }() - // Emit as a resync - service.Emit(false, "A") + // Emit as a resync + service.Emit(true, "A") - select { - case <-dataReceived: - break - case <-time.After(1 * time.Second): - t.Fatal("expected data but timeout waiting for sync") - } + select { + case <-dataReceived: + t.Fatal("expected no data as this is a resync") + case <-time.After(1 * time.Second): + break + } - // fetch all flags - allRsp, err := serviceClient.FetchAllFlags(ctx, &v1.FetchAllFlagsRequest{}) - if err != nil { - t.Fatal(fmt.Printf("fetch all error: %v", err)) - return - } + // Emit as a resync + service.Emit(false, "A") - if allRsp.GetFlagConfiguration() != syncRsp.GetFlagConfiguration() { - t.Errorf("expected both sync and fetch all responses to be same, but got %s from sync & %s from fetch all", - syncRsp.GetFlagConfiguration(), allRsp.GetFlagConfiguration()) - } + select { + case <-dataReceived: + break + case <-time.After(1 * time.Second): + t.Fatal("expected data but timeout waiting for sync") + } - // metadata request - metadataRsp, err := serviceClient.GetMetadata(ctx, &v1.GetMetadataRequest{}) - if err != nil { - t.Fatal(fmt.Printf("metadata error: %v", err)) - return - } + // fetch all flags + allRsp, err := serviceClient.FetchAllFlags(ctx, &v1.FetchAllFlagsRequest{}) + if err != nil { + t.Fatal(fmt.Printf("fetch all error: %v", err)) + return + } - asMap := metadataRsp.GetMetadata().AsMap() + if allRsp.GetFlagConfiguration() != syncRsp.GetFlagConfiguration() { + t.Errorf("expected both sync and fetch all responses to be same, but got %s from sync & %s from fetch all", + syncRsp.GetFlagConfiguration(), allRsp.GetFlagConfiguration()) + } - // expect `sources` to be present - if asMap["sources"] == nil { - t.Fatal("expected sources entry in the metadata, but got nil") - } + // metadata request + metadataRsp, err := serviceClient.GetMetadata(ctx, &v1.GetMetadataRequest{}) - if asMap["sources"] != "A,B,C" { - t.Fatal("incorrect sources entry in metadata") - } + if disableSyncMetadata { + if err == nil { + t.Fatal(fmt.Printf("getMetadata disabled, error should not be nil")) + return + } + } else { + asMap := metadataRsp.GetMetadata().AsMap() + // expect `sources` to be present + if asMap["sources"] == nil { + t.Fatal("expected sources entry in the metadata, but got nil") + } + if asMap["sources"] != "A,B,C" { + t.Fatal("incorrect sources entry in metadata") + } + } - // validate shutdown from context cancellation - go func() { - cancelFunc() - }() + // validate shutdown from context cancellation + go func() { + cancelFunc() + }() - select { - case <-doneChan: - // exit successful - return - case <-time.After(2 * time.Second): - t.Fatal("service did not exist within sufficient timeframe") - } - }) + select { + case <-doneChan: + // exit successful + return + case <-time.After(2 * time.Second): + t.Fatal("service did not exist within sufficient timeframe") + } + }) + } } } @@ -198,7 +211,7 @@ func TestSyncServiceDeadlineEndToEnd(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - _, _, err := createAndStartSyncService(port, sources, flagStore, certPath, keyPath, socketPath, ctx, tc.deadline) + _, _, err := createAndStartSyncService(port, sources, flagStore, certPath, keyPath, socketPath, ctx, tc.deadline, false) if err != nil { t.Fatal("error creating sync service") } @@ -256,16 +269,27 @@ func TestSyncServiceDeadlineEndToEnd(t *testing.T) { } } -func createAndStartSyncService(port int, sources []string, store *store.State, certPath string, keyPath string, socketPath string, ctx context.Context, deadline time.Duration) (*Service, chan interface{}, error) { +func createAndStartSyncService( + port int, + sources []string, + store *store.State, + certPath string, + keyPath string, + socketPath string, + ctx context.Context, + deadline time.Duration, + disableSyncMetadata bool, +) (*Service, chan interface{}, error) { service, err := NewSyncService(SvcConfigurations{ - Logger: logger.NewLogger(nil, false), - Port: uint16(port), - Sources: sources, - Store: store, - CertPath: certPath, - KeyPath: keyPath, - SocketPath: socketPath, - StreamDeadline: deadline, + Logger: logger.NewLogger(nil, false), + Port: uint16(port), + Sources: sources, + Store: store, + CertPath: certPath, + KeyPath: keyPath, + SocketPath: socketPath, + StreamDeadline: deadline, + DisableSyncMetadata: disableSyncMetadata, }) if err != nil { return nil, nil, err