Skip to content

Commit c494b0d

Browse files
committed
Implement suggested comments
1 parent 73f5f0e commit c494b0d

File tree

3 files changed

+105
-36
lines changed

3 files changed

+105
-36
lines changed

optimizely/decision_service.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# limitations under the License.
1313

1414
from __future__ import annotations
15-
from typing import TYPE_CHECKING, NamedTuple, Optional, Sequence, List, TypedDict
15+
from typing import TYPE_CHECKING, NamedTuple, Optional, Sequence, List, TypedDict, Union
1616

17-
from optimizely.helpers.types import HoldoutDict
17+
from optimizely.helpers.types import HoldoutDict, VariationDict
1818

1919
from . import bucketer
2020
from . import entities
@@ -61,7 +61,7 @@ class VariationResult(TypedDict):
6161
cmab_uuid: Optional[str]
6262
error: bool
6363
reasons: List[str]
64-
variation: Optional[entities.Variation]
64+
variation: Optional[Union[entities.Variation, VariationDict]]
6565

6666

6767
class DecisionResult(TypedDict):
@@ -82,7 +82,7 @@ class Decision(NamedTuple):
8282
"""Named tuple containing selected experiment, variation, source and cmab_uuid.
8383
None if no experiment/variation was selected."""
8484
experiment: Optional[entities.Experiment]
85-
variation: Optional[entities.Variation]
85+
variation: Optional[Union[entities.Variation, VariationDict]]
8686
source: Optional[str]
8787
cmab_uuid: Optional[str]
8888

@@ -953,7 +953,7 @@ def get_variations_for_feature_list(
953953
if feature.experimentIds:
954954
for experiment_id in feature.experimentIds:
955955
experiment = project_config.get_experiment_from_id(experiment_id)
956-
decision_variation = None
956+
decision_variation: Optional[Union[entities.Variation, VariationDict]] = None
957957

958958
if experiment:
959959
optimizely_decision_context = OptimizelyUserContext.OptimizelyDecisionContext(

optimizely/optimizely.py

Lines changed: 76 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
from __future__ import annotations
1515

16-
from typing import TYPE_CHECKING, Any, Optional
16+
from typing import TYPE_CHECKING, Any, Optional, Union
17+
18+
from optimizely.helpers.types import VariationDict
1719

1820

1921
from . import decision_service
@@ -198,6 +200,67 @@ def __init__(
198200
self.decision_service = decision_service.DecisionService(self.logger, user_profile_service, self.cmab_service)
199201
self.user_profile_service = user_profile_service
200202

203+
def _get_variation_key(self, variation: Optional[Union[entities.Variation, VariationDict]]) -> Optional[str]:
204+
"""Helper to extract variation key from either dict (holdout) or Variation object.
205+
Args:
206+
variation: Either a dict (from holdout) or entities.Variation object
207+
Returns:
208+
The variation key as a string, or None if not available
209+
"""
210+
if variation is None:
211+
return None
212+
213+
try:
214+
# Try dict access first (for holdouts)
215+
if isinstance(variation, dict):
216+
return variation.get('key')
217+
# Otherwise assume it's a Variation entity object
218+
else:
219+
return variation.key
220+
except (AttributeError, KeyError, TypeError):
221+
self.logger.warning(f"Unable to extract variation key from {type(variation)}")
222+
return None
223+
224+
def _get_variation_id(self, variation: Optional[Union[entities.Variation, VariationDict]]) -> Optional[str]:
225+
"""Helper to extract variation id from either dict (holdout) or Variation object.
226+
Args:
227+
variation: Either a dict (from holdout) or entities.Variation object
228+
Returns:
229+
The variation id as a string, or None if not available
230+
"""
231+
if variation is None:
232+
return None
233+
234+
try:
235+
# Try dict access first (for holdouts)
236+
if isinstance(variation, dict):
237+
return variation.get('id')
238+
# Otherwise assume it's a Variation entity object
239+
else:
240+
return variation.id
241+
except (AttributeError, KeyError, TypeError):
242+
self.logger.warning(f"Unable to extract variation id from {type(variation)}")
243+
return None
244+
245+
def _get_feature_enabled(self, variation: Optional[Union[entities.Variation, VariationDict]]) -> bool:
246+
"""Helper to extract featureEnabled flag from either dict (holdout) or Variation object.
247+
Args:
248+
variation: Either a dict (from holdout) or entities.Variation object
249+
Returns:
250+
The featureEnabled value, defaults to False if not available
251+
"""
252+
if variation is None:
253+
return False
254+
255+
try:
256+
if isinstance(variation, dict):
257+
feature_enabled = variation.get('featureEnabled', False)
258+
return bool(feature_enabled) if feature_enabled is not None else False
259+
else:
260+
return variation.featureEnabled if hasattr(variation, 'featureEnabled') else False
261+
except (AttributeError, KeyError, TypeError):
262+
return False
263+
201264
def _validate_instantiation_options(self) -> None:
202265
""" Helper method to validate all instantiation parameters.
203266
@@ -267,7 +330,7 @@ def _validate_user_inputs(
267330

268331
def _send_impression_event(
269332
self, project_config: project_config.ProjectConfig, experiment: Optional[entities.Experiment],
270-
variation: Optional[entities.Variation], flag_key: str, rule_key: str, rule_type: str,
333+
variation: Optional[Union[entities.Variation, VariationDict]], flag_key: str, rule_key: str, rule_type: str,
271334
enabled: bool, user_id: str, attributes: Optional[UserAttributes], cmab_uuid: Optional[str] = None
272335
) -> None:
273336
""" Helper method to send impression event.
@@ -286,7 +349,7 @@ def _send_impression_event(
286349
if not experiment:
287350
experiment = entities.Experiment.get_default()
288351

289-
variation_id = variation.id if variation is not None else None
352+
variation_id = self._get_variation_id(variation) if variation is not None else None
290353
user_event = user_event_factory.UserEventFactory.create_impression_event(
291354
project_config, experiment, variation_id,
292355
flag_key, rule_key, rule_type,
@@ -372,7 +435,7 @@ def _get_feature_variable_for_type(
372435

373436
if decision.variation:
374437

375-
feature_enabled = decision.variation.featureEnabled
438+
feature_enabled = self._get_feature_enabled(decision.variation)
376439
if feature_enabled:
377440
variable_value = project_config.get_variable_value_for_variation(variable, decision.variation)
378441
self.logger.info(
@@ -393,7 +456,7 @@ def _get_feature_variable_for_type(
393456
if decision.source == enums.DecisionSources.FEATURE_TEST:
394457
source_info = {
395458
'experiment_key': decision.experiment.key if decision.experiment else None,
396-
'variation_key': decision.variation.key if decision.variation else None,
459+
'variation_key': self._get_variation_key(decision.variation),
397460
}
398461

399462
try:
@@ -461,7 +524,7 @@ def _get_all_feature_variables_for_type(
461524

462525
if decision.variation:
463526

464-
feature_enabled = decision.variation.featureEnabled
527+
feature_enabled = self._get_feature_enabled(decision.variation)
465528
if feature_enabled:
466529
self.logger.info(
467530
f'Feature "{feature_key}" is enabled for user "{user_id}".'
@@ -497,7 +560,7 @@ def _get_all_feature_variables_for_type(
497560
if decision.source == enums.DecisionSources.FEATURE_TEST:
498561
source_info = {
499562
'experiment_key': decision.experiment.key if decision.experiment else None,
500-
'variation_key': decision.variation.key if decision.variation else None,
563+
'variation_key': self._get_variation_key(decision.variation),
501564
}
502565

503566
self.notification_center.send_notifications(
@@ -670,7 +733,7 @@ def get_variation(
670733
variation = variation_result['variation']
671734
user_profile_tracker.save_user_profile()
672735
if variation:
673-
variation_key = variation.key
736+
variation_key = self._get_variation_key(variation)
674737

675738
if project_config.is_feature_experiment(experiment.id):
676739
decision_notification_type = enums.DecisionNotificationTypes.FEATURE_TEST
@@ -735,7 +798,7 @@ def is_feature_enabled(self, feature_key: str, user_id: str, attributes: Optiona
735798
is_source_rollout = decision.source == enums.DecisionSources.ROLLOUT
736799

737800
if decision.variation:
738-
if decision.variation.featureEnabled is True:
801+
if self._get_feature_enabled(decision.variation) is True:
739802
feature_enabled = True
740803

741804
if (is_source_rollout or not decision.variation) and project_config.get_send_flag_decisions_value():
@@ -748,7 +811,7 @@ def is_feature_enabled(self, feature_key: str, user_id: str, attributes: Optiona
748811
if is_source_experiment and decision.variation and decision.experiment:
749812
source_info = {
750813
'experiment_key': decision.experiment.key,
751-
'variation_key': decision.variation.key,
814+
'variation_key': self._get_variation_key(decision.variation),
752815
}
753816
self._send_impression_event(
754817
project_config, decision.experiment, decision.variation, feature.key, decision.experiment.key,
@@ -1182,7 +1245,7 @@ def _create_optimizely_decision(
11821245
user_id = user_context.user_id
11831246
feature_enabled = False
11841247
if flag_decision.variation is not None:
1185-
if flag_decision.variation.featureEnabled:
1248+
if self._get_feature_enabled(flag_decision.variation):
11861249
feature_enabled = True
11871250

11881251
self.logger.info(f'Feature {flag_key} is enabled for user {user_id} {feature_enabled}"')
@@ -1231,11 +1294,7 @@ def _create_optimizely_decision(
12311294
all_variables[variable_key] = actual_value
12321295

12331296
should_include_reasons = OptimizelyDecideOption.INCLUDE_REASONS in decide_options
1234-
variation_key = (
1235-
flag_decision.variation.key
1236-
if flag_decision is not None and flag_decision.variation is not None
1237-
else None
1238-
)
1297+
variation_key = self._get_variation_key(flag_decision.variation)
12391298

12401299
experiment_id = None
12411300
variation_id = None
@@ -1248,7 +1307,7 @@ def _create_optimizely_decision(
12481307

12491308
try:
12501309
if flag_decision.variation is not None:
1251-
variation_id = flag_decision.variation.id
1310+
variation_id = self._get_variation_id(flag_decision.variation)
12521311
except AttributeError:
12531312
self.logger.warning("flag_decision.variation has no attribute 'id'")
12541313

optimizely/project_config.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# limitations under the License.
1313
from __future__ import annotations
1414
import json
15-
from typing import TYPE_CHECKING, Optional, Type, TypeVar, cast, Any, Iterable, List
15+
from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union, cast, Any, Iterable, List
1616
from sys import version_info
1717

1818
from . import entities
@@ -21,6 +21,8 @@
2121
from .helpers import enums
2222
from .helpers import types
2323

24+
from optimizely.helpers.types import HoldoutDict, VariationDict
25+
2426
if version_info < (3, 8):
2527
from typing_extensions import Final
2628
else:
@@ -88,12 +90,12 @@ def __init__(self, datafile: str | bytes, logger: Logger, error_handler: Any):
8890
region_value = config.get('region')
8991
self.region: str = region_value or 'US'
9092

91-
self.holdouts: list[dict[str, Any]] = config.get('holdouts', [])
92-
self.holdout_id_map: dict[str, dict[str, Any]] = {}
93-
self.global_holdouts: dict[str, dict[str, Any]] = {}
94-
self.included_holdouts: dict[str, list[dict[str, Any]]] = {}
95-
self.excluded_holdouts: dict[str, list[dict[str, Any]]] = {}
96-
self.flag_holdouts_map: dict[str, list[dict[str, Any]]] = {}
93+
self.holdouts: list[HoldoutDict] = config.get('holdouts', [])
94+
self.holdout_id_map: dict[str, HoldoutDict] = {}
95+
self.global_holdouts: dict[str, HoldoutDict] = {}
96+
self.included_holdouts: dict[str, list[HoldoutDict]] = {}
97+
self.excluded_holdouts: dict[str, list[HoldoutDict]] = {}
98+
self.flag_holdouts_map: dict[str, list[HoldoutDict]] = {}
9799

98100
for holdout in self.holdouts:
99101
if holdout.get('status') != 'Running':
@@ -654,7 +656,7 @@ def get_rollout_from_id(self, rollout_id: str) -> Optional[entities.Layer]:
654656
return None
655657

656658
def get_variable_value_for_variation(
657-
self, variable: Optional[entities.Variable], variation: Optional[entities.Variation]
659+
self, variable: Optional[entities.Variable], variation: Optional[Union[entities.Variation, VariationDict]]
658660
) -> Optional[str]:
659661
""" Get the variable value for the given variation.
660662
@@ -668,12 +670,21 @@ def get_variable_value_for_variation(
668670

669671
if not variable or not variation:
670672
return None
671-
if variation.id not in self.variation_variable_usage_map:
672-
self.logger.error(f'Variation with ID "{variation.id}" is not in the datafile.')
673+
674+
# Extract variation ID from either Variation entity or dict
675+
if isinstance(variation, dict):
676+
variation_id = variation.get('id')
677+
if not variation_id:
678+
return None
679+
else:
680+
variation_id = variation.id
681+
682+
if variation_id not in self.variation_variable_usage_map:
683+
self.logger.error(f'Variation with ID "{variation_id}" is not in the datafile.')
673684
return None
674685

675686
# Get all variable usages for the given variation
676-
variable_usages = self.variation_variable_usage_map[variation.id]
687+
variable_usages = self.variation_variable_usage_map[variation_id]
677688

678689
# Find usage in given variation
679690
variable_usage = None
@@ -682,7 +693,6 @@ def get_variable_value_for_variation(
682693

683694
if variable_usage:
684695
variable_value = variable_usage.value
685-
686696
else:
687697
variable_value = variable.defaultValue
688698

@@ -824,7 +834,7 @@ def get_flag_variation(
824834

825835
return None
826836

827-
def get_holdouts_for_flag(self, flag_key: str) -> list[Any]:
837+
def get_holdouts_for_flag(self, flag_key: str) -> list[HoldoutDict]:
828838
""" Helper method to get holdouts from an applied feature flag.
829839
830840
Args:
@@ -838,7 +848,7 @@ def get_holdouts_for_flag(self, flag_key: str) -> list[Any]:
838848

839849
return self.flag_holdouts_map.get(flag_key, [])
840850

841-
def get_holdout(self, holdout_id: str) -> Optional[dict[str, Any]]:
851+
def get_holdout(self, holdout_id: str) -> Optional[HoldoutDict]:
842852
""" Helper method to get holdout from holdout ID.
843853
844854
Args:

0 commit comments

Comments
 (0)