diff --git a/ldclient_gocontext.go b/ldclient_gocontext.go new file mode 100644 index 00000000..fdc7fcf6 --- /dev/null +++ b/ldclient_gocontext.go @@ -0,0 +1,61 @@ +package ldclient + +import "context" + +type scopedClientKey struct{} + +// GoContextWithScopedClient adds a scoped client to the Go context. This can be +// used to pass a scoped client to a function or goroutine that might not +// otherwise have access to it: +// +// scopedClient := ld.NewScopedClient(client, ldUserContext) +// ctx := ld.GoContextWithScopedClient(context.Background(), scopedClient) +// otherFunction(ctx) +// +// This function is not stable, and not subject to any backwards compatibility +// guarantees or semantic versioning. It is not suitable for production usage. Do +// not use it. You have been warned. +func GoContextWithScopedClient(ctx context.Context, client *LDScopedClient) context.Context { + return context.WithValue(ctx, scopedClientKey{}, client) +} + +// GetScopedClient retrieves a scoped client from the Go context that was set +// with GoContextWithScopedClient, if present. If not present, returns nil and +// false. +// +// func logicWithFeatureFlag(ctx context.Context) { +// scopedClient, ok := ld.GetScopedClient(ctx) +// isFeatureEnabled := false // default value if scoped client is not available +// if ok { +// isFeatureEnabled, err = scopedClient.BoolVariation("my-flag", false) +// // handle err as appropriate... +// } +// } +// +// This function is not stable, and not subject to any backwards compatibility +// guarantees or semantic versioning. It is not suitable for production usage. Do +// not use it. You have been warned. +func GetScopedClient(ctx context.Context) (*LDScopedClient, bool) { + client, ok := ctx.Value(scopedClientKey{}).(*LDScopedClient) + return client, ok +} + +// MustGetScopedClient retrieves a scoped client from the Go context that was set +// with GoContextWithScopedClient, or panics if not present. +// +// func logicWithFeatureFlag(ctx context.Context) { +// scopedClient := ld.MustGetScopedClient(ctx) +// isFeatureEnabled, err := scopedClient.BoolVariation("my-flag", false) +// // handle err as appropriate... +// } +// +// This function is not stable, and not subject to any backwards compatibility +// guarantees or semantic versioning. It is not suitable for production usage. Do +// not use it. You have been warned. +func MustGetScopedClient(ctx context.Context) *LDScopedClient { + client, ok := GetScopedClient(ctx) + if !ok { + panic("No scoped client found in context") + } + return client +} diff --git a/ldclient_gocontext_test.go b/ldclient_gocontext_test.go new file mode 100644 index 00000000..ec0de003 --- /dev/null +++ b/ldclient_gocontext_test.go @@ -0,0 +1,40 @@ +package ldclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetScopedClient(t *testing.T) { + t.Run("returns client from context", func(t *testing.T) { + origCtx := context.Background() + sc := &LDScopedClient{} + + newCtx := GoContextWithScopedClient(origCtx, sc) + retrieved, ok := GetScopedClient(newCtx) + + assert.True(t, ok, "expected to find scoped client in context") + assert.Equal(t, sc, retrieved, "retrieved client should match original") + }) + + t.Run("returns nil when not present", func(t *testing.T) { + retrieved, ok := GetScopedClient(context.Background()) + assert.False(t, ok, "should not find scoped client in empty context") + assert.Nil(t, retrieved, "retrieved client should be nil when not present") + }) +} + +func TestMustGetScopedClient(t *testing.T) { + sc := &LDScopedClient{} + ctxWith := GoContextWithScopedClient(context.Background(), sc) + + // Should return the client without panicking when present + assert.Equal(t, sc, MustGetScopedClient(ctxWith)) + + // Should panic when the client is not present + assert.Panics(t, func() { + MustGetScopedClient(context.Background()) + }) +}