diff --git a/ldclient/impl/datasourcev2/polling.py b/ldclient/impl/datasourcev2/polling.py index 8a350c82..8f867097 100644 --- a/ldclient/impl/datasourcev2/polling.py +++ b/ldclient/impl/datasourcev2/polling.py @@ -14,7 +14,7 @@ import urllib3 from ldclient.config import Config -from ldclient.impl.datasystem import BasisResult, Update +from ldclient.impl.datasystem import BasisResult, SelectorStore, Update from ldclient.impl.datasystem.protocolv2 import ( Basis, ChangeSet, @@ -96,13 +96,13 @@ def name(self) -> str: """Returns the name of the initializer.""" return "PollingDataSourceV2" - def fetch(self) -> BasisResult: + def fetch(self, ss: SelectorStore) -> BasisResult: """ Fetch returns a Basis, or an error if the Basis could not be retrieved. """ - return self._poll() + return self._poll(ss) - def sync(self) -> Generator[Update, None, None]: + def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: """ sync begins the synchronization process for the data source, yielding Update objects until the connection is closed or an unrecoverable error @@ -111,7 +111,7 @@ def sync(self) -> Generator[Update, None, None]: log.info("Starting PollingDataSourceV2 synchronizer") self._stop.clear() while self._stop.is_set() is False: - result = self._requester.fetch(None) + result = self._requester.fetch(ss.selector()) if isinstance(result, _Fail): if isinstance(result.exception, UnsuccessfulResponseException): error_info = DataSourceErrorInfo( @@ -170,10 +170,9 @@ def stop(self): self._task.stop() self._stop.set() - def _poll(self) -> BasisResult: + def _poll(self, ss: SelectorStore) -> BasisResult: try: - # TODO(fdv2): Need to pass the selector through - result = self._requester.fetch(None) + result = self._requester.fetch(ss.selector()) if isinstance(result, _Fail): if isinstance(result.exception, UnsuccessfulResponseException): diff --git a/ldclient/impl/datasourcev2/streaming.py b/ldclient/impl/datasourcev2/streaming.py index 75e44552..0f6590dc 100644 --- a/ldclient/impl/datasourcev2/streaming.py +++ b/ldclient/impl/datasourcev2/streaming.py @@ -19,7 +19,7 @@ from ld_eventsource.errors import HTTPStatusError from ldclient.config import Config -from ldclient.impl.datasystem import Synchronizer, Update +from ldclient.impl.datasystem import SelectorStore, Synchronizer, Update from ldclient.impl.datasystem.protocolv2 import ( ChangeSetBuilder, DeleteObject, @@ -54,12 +54,10 @@ STREAMING_ENDPOINT = "/sdk/stream" -SseClientBuilder = Callable[[Config], SSEClient] +SseClientBuilder = Callable[[Config, SelectorStore], SSEClient] -# TODO(sdk-1391): Pass a selector-retrieving function through so it can -# re-connect with the last known status. -def create_sse_client(config: Config) -> SSEClient: +def create_sse_client(config: Config, ss: SelectorStore) -> SSEClient: """ " create_sse_client creates an SSEClient instance configured to connect to the LaunchDarkly streaming endpoint. @@ -76,12 +74,17 @@ def create_sse_client(config: Config) -> SSEClient: override_read_timeout=STREAM_READ_TIMEOUT, ) + def query_params() -> dict[str, str]: + selector = ss.selector() + return {"basis": selector.state} if selector.is_defined() else {} + return SSEClient( connect=ConnectStrategy.http( url=uri, headers=http_factory.base_headers, pool=stream_http_factory.create_pool_manager(1, uri), urllib3_request_options={"timeout": stream_http_factory.timeout}, + query_params=query_params ), # we'll make error-handling decisions when we see a Fault error_strategy=ErrorStrategy.always_continue(), @@ -118,13 +121,13 @@ def name(self) -> str: """ return "streaming" - def sync(self) -> Generator[Update, None, None]: + def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: """ sync should begin the synchronization process for the data source, yielding Update objects until the connection is closed or an unrecoverable error occurs. """ - self._sse = self._sse_client_builder(self._config) + self._sse = self._sse_client_builder(self._config, ss) if self._sse is None: log.error("Failed to create SSE client for streaming updates.") return diff --git a/ldclient/impl/datasystem/__init__.py b/ldclient/impl/datasystem/__init__.py index cc6fbba5..57131c87 100644 --- a/ldclient/impl/datasystem/__init__.py +++ b/ldclient/impl/datasystem/__init__.py @@ -7,9 +7,9 @@ from dataclasses import dataclass from enum import Enum from threading import Event -from typing import Generator, Optional, Protocol +from typing import Callable, Generator, Optional, Protocol -from ldclient.impl.datasystem.protocolv2 import Basis, ChangeSet +from ldclient.impl.datasystem.protocolv2 import Basis, ChangeSet, Selector from ldclient.impl.util import _Result from ldclient.interfaces import ( DataSourceErrorInfo, @@ -142,6 +142,21 @@ def target_availability(self) -> DataAvailability: raise NotImplementedError +class SelectorStore(Protocol): + """ + SelectorStore represents a component capable of providing Selectors + for data retrieval. + """ + + @abstractmethod + def selector(self) -> Selector: + """ + get_selector should return a Selector object that defines the criteria + for data retrieval. + """ + raise NotImplementedError + + BasisResult = _Result[Basis, str] @@ -165,10 +180,12 @@ def name(self) -> str: raise NotImplementedError @abstractmethod - def fetch(self) -> BasisResult: + def fetch(self, ss: SelectorStore) -> BasisResult: """ fetch should retrieve the initial data set for the data source, returning a Basis object on success, or an error message on failure. + + :param ss: A SelectorStore that provides the Selector to use as a basis for data retrieval. """ raise NotImplementedError @@ -205,11 +222,13 @@ def name(self) -> str: raise NotImplementedError @abstractmethod - def sync(self) -> Generator[Update, None, None]: + def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: """ sync should begin the synchronization process for the data source, yielding Update objects until the connection is closed or an unrecoverable error occurs. + + :param ss: A SelectorStore that provides the Selector to use as a basis for data retrieval. """ raise NotImplementedError diff --git a/ldclient/impl/datasystem/fdv2.py b/ldclient/impl/datasystem/fdv2.py index 01824203..8dd8e5c7 100644 --- a/ldclient/impl/datasystem/fdv2.py +++ b/ldclient/impl/datasystem/fdv2.py @@ -299,7 +299,7 @@ def _run_initializers(self, set_on_ready: Event): initializer = initializer_builder(self._config) log.info("Attempting to initialize via %s", initializer.name) - basis_result = initializer.fetch() + basis_result = initializer.fetch(self._store) if isinstance(basis_result, _Fail): log.warning("Initializer %s failed: %s", initializer.name, basis_result.error) @@ -426,7 +426,7 @@ def _consume_synchronizer_results( :return: Tuple of (should_remove_sync, fallback_to_fdv1) """ try: - for update in synchronizer.sync(): + for update in synchronizer.sync(self._store): log.info("Synchronizer %s update: %s", synchronizer.name, update.state) if self._stop_event.is_set(): return False, False diff --git a/ldclient/impl/integrations/test_datav2/test_data_sourcev2.py b/ldclient/impl/integrations/test_datav2/test_data_sourcev2.py index bf3397c3..6d8edacc 100644 --- a/ldclient/impl/integrations/test_datav2/test_data_sourcev2.py +++ b/ldclient/impl/integrations/test_datav2/test_data_sourcev2.py @@ -2,7 +2,7 @@ from queue import Empty, Queue from typing import Generator -from ldclient.impl.datasystem import BasisResult, Update +from ldclient.impl.datasystem import BasisResult, SelectorStore, Update from ldclient.impl.datasystem.protocolv2 import ( Basis, ChangeSetBuilder, @@ -16,6 +16,7 @@ DataSourceErrorKind, DataSourceState ) +from ldclient.testing.mock_components import MockSelectorStore class _TestDataSourceV2: @@ -47,7 +48,7 @@ def name(self) -> str: """Return the name of this data source.""" return "TestDataV2" - def fetch(self) -> BasisResult: + def fetch(self, ss: SelectorStore) -> BasisResult: """ Implementation of the Initializer.fetch method. @@ -90,7 +91,7 @@ def fetch(self) -> BasisResult: except Exception as e: return _Fail(f"Error fetching test data: {str(e)}") - def sync(self) -> Generator[Update, None, None]: + def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: """ Implementation of the Synchronizer.sync method. @@ -98,7 +99,7 @@ def sync(self) -> Generator[Update, None, None]: """ # First yield initial data - initial_result = self.fetch() + initial_result = self.fetch(ss) if isinstance(initial_result, _Fail): yield Update( state=DataSourceState.OFF, @@ -143,8 +144,8 @@ def sync(self) -> Generator[Update, None, None]: ) break - def close(self): - """Close the data source and clean up resources.""" + def stop(self): + """Stop the data source and clean up resources""" with self._lock: if self._closed: return diff --git a/ldclient/testing/impl/datasourcev2/test_polling_initializer.py b/ldclient/testing/impl/datasourcev2/test_polling_initializer.py index 0a7079d6..5e5e084f 100644 --- a/ldclient/testing/impl/datasourcev2/test_polling_initializer.py +++ b/ldclient/testing/impl/datasourcev2/test_polling_initializer.py @@ -11,6 +11,7 @@ ) from ldclient.impl.datasystem.protocolv2 import ChangeSetBuilder, IntentCode from ldclient.impl.util import UnsuccessfulResponseException, _Fail, _Success +from ldclient.testing.mock_components import MockSelectorStore class MockExceptionThrowingPollingRequester: # pylint: disable=too-few-public-methods @@ -37,7 +38,7 @@ def test_error_is_returned_on_failure(): mock_requester = MockPollingRequester(_Fail(error="failure message")) ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Fail) assert result.error == "failure message" @@ -50,7 +51,7 @@ def test_error_is_recoverable(): ) ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Fail) assert result.error is not None @@ -64,7 +65,7 @@ def test_error_is_unrecoverable(): ) ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Fail) assert result.error is not None @@ -78,7 +79,7 @@ def test_handles_transfer_none(): ) ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Success) assert result.value is not None @@ -92,7 +93,7 @@ def test_handles_uncaught_exception(): mock_requester = MockExceptionThrowingPollingRequester() ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Fail) assert result.error is not None @@ -111,7 +112,7 @@ def test_handles_transfer_full(): mock_requester = MockPollingRequester(_Success(value=(change_set_result.value, {}))) ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Success) assert result.value is not None @@ -129,7 +130,7 @@ def test_handles_transfer_changes(): mock_requester = MockPollingRequester(_Success(value=(change_set_result.value, {}))) ds = PollingDataSource(poll_interval=1.0, requester=mock_requester) - result = ds.fetch() + result = ds.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Success) assert result.value is not None diff --git a/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py b/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py index 92391368..3410a1e6 100644 --- a/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py +++ b/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py @@ -22,6 +22,7 @@ ) from ldclient.impl.util import UnsuccessfulResponseException, _Fail, _Success from ldclient.interfaces import DataSourceErrorKind, DataSourceState +from ldclient.testing.mock_components import MockSelectorStore class ListBasedRequester: @@ -103,7 +104,7 @@ def test_handles_no_changes(): poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) ) - valid = next(synchronizer.sync()) + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert valid.state == DataSourceState.VALID assert valid.error is None @@ -124,7 +125,7 @@ def test_handles_empty_changeset(): synchronizer = PollingDataSource( poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) ) - valid = next(synchronizer.sync()) + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert valid.state == DataSourceState.VALID assert valid.error is None @@ -152,7 +153,7 @@ def test_handles_put_objects(): synchronizer = PollingDataSource( poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) ) - valid = next(synchronizer.sync()) + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert valid.state == DataSourceState.VALID assert valid.error is None @@ -183,7 +184,7 @@ def test_handles_delete_objects(): synchronizer = PollingDataSource( poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) ) - valid = next(synchronizer.sync()) + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert valid.state == DataSourceState.VALID assert valid.error is None @@ -216,7 +217,7 @@ def test_generic_error_interrupts_and_recovers(): results=iter([_Fail(error="error for test"), polling_result]) ), ) - sync = synchronizer.sync() + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) interrupted = next(sync) valid = next(sync) @@ -250,7 +251,7 @@ def test_recoverable_error_continues(): poll_interval=0.01, requester=ListBasedRequester(results=iter([_failure, polling_result])), ) - sync = synchronizer.sync() + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) interrupted = next(sync) valid = next(sync) @@ -288,7 +289,7 @@ def test_unrecoverable_error_shuts_down(): poll_interval=0.01, requester=ListBasedRequester(results=iter([_failure, polling_result])), ) - sync = synchronizer.sync() + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) off = next(sync) assert off.state == DataSourceState.OFF assert off.error is not None diff --git a/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py b/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py index d78aac6c..f749bff8 100644 --- a/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py +++ b/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py @@ -16,6 +16,7 @@ SseClientBuilder, StreamingDataSource ) +from ldclient.impl.datasystem import SelectorStore from ldclient.impl.datasystem.protocolv2 import ( ChangeType, DeleteObject, @@ -30,12 +31,13 @@ ServerIntent ) from ldclient.interfaces import DataSourceErrorKind, DataSourceState +from ldclient.testing.mock_components import MockSelectorStore def list_sse_client( events: Iterable[Action], # pylint: disable=redefined-outer-name ) -> SseClientBuilder: - def builder(_: Config) -> SSEClient: + def builder(config: Config, ss: SelectorStore) -> SSEClient: return ListBasedSseClient(events) return builder @@ -83,7 +85,7 @@ class UnknownTypeOfEvent(Action): synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = list_sse_client([UnknownTypeOfEvent(), unknown_named_event]) - assert len(list(synchronizer.sync())) == 0 + assert len(list(synchronizer.sync(MockSelectorStore(Selector.no_selector())))) == 0 def test_ignores_faults_without_errors(): @@ -91,7 +93,7 @@ def test_ignores_faults_without_errors(): synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = list_sse_client([errorless_fault]) - assert len(list(synchronizer.sync())) == 0 + assert len(list(synchronizer.sync(MockSelectorStore(Selector.no_selector())))) == 0 @pytest.fixture @@ -169,7 +171,7 @@ def test_handles_no_changes(): synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = list_sse_client([intent_event]) - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -189,7 +191,7 @@ def test_handles_empty_changeset(events): # pylint: disable=redefined-outer-nam synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -216,7 +218,7 @@ def test_handles_put_objects(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -248,7 +250,7 @@ def test_handles_delete_objects(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -279,7 +281,7 @@ def test_swallows_goodbye(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -306,7 +308,7 @@ def test_swallows_heartbeat(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -335,7 +337,7 @@ def test_error_resets(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.VALID @@ -359,7 +361,7 @@ def test_handles_out_of_order(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.INTERRUPTED @@ -390,7 +392,7 @@ def test_invalid_json_decoding(events): # pylint: disable=redefined-outer-name synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 2 assert updates[0].state == DataSourceState.INTERRUPTED @@ -423,7 +425,7 @@ def test_stops_on_unrecoverable_status_code( synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 1 assert updates[0].state == DataSourceState.OFF @@ -453,7 +455,7 @@ def test_continues_on_recoverable_status_code( ) synchronizer = StreamingDataSource(Config(sdk_key="key")) synchronizer._sse_client_builder = builder - updates = list(synchronizer.sync()) + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) assert len(updates) == 3 assert updates[0].state == DataSourceState.INTERRUPTED diff --git a/ldclient/testing/integrations/test_test_data_sourcev2.py b/ldclient/testing/integrations/test_test_data_sourcev2.py index 0660ffae..e0ff825d 100644 --- a/ldclient/testing/integrations/test_test_data_sourcev2.py +++ b/ldclient/testing/integrations/test_test_data_sourcev2.py @@ -8,11 +8,13 @@ from ldclient.impl.datasystem.protocolv2 import ( ChangeType, IntentCode, - ObjectKind + ObjectKind, + Selector ) from ldclient.impl.util import _Fail, _Success from ldclient.integrations.test_datav2 import FlagBuilderV2, TestDataV2 from ldclient.interfaces import DataSourceState +from ldclient.testing.mock_components import MockSelectorStore # Test Data + Data Source V2 @@ -22,7 +24,7 @@ def test_creates_valid_initializer(): td = TestDataV2.data_source() initializer = td.build_initializer(Config(sdk_key="dummy")) - result = initializer.fetch() + result = initializer.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Success) basis = result.value @@ -42,7 +44,7 @@ def test_creates_valid_synchronizer(): def collect_updates(): nonlocal update_count - for update in synchronizer.sync(): + for update in synchronizer.sync(MockSelectorStore(Selector.no_selector())): updates.append(update) update_count += 1 @@ -51,7 +53,7 @@ def collect_updates(): assert update.state == DataSourceState.VALID assert update.change_set is not None assert update.change_set.intent_code == IntentCode.TRANSFER_FULL - synchronizer.close() + synchronizer.stop() break # Start the synchronizer in a thread with timeout to prevent hanging @@ -63,7 +65,7 @@ def collect_updates(): # Ensure thread completed successfully if sync_thread.is_alive(): - synchronizer.close() + synchronizer.stop() sync_thread.join() pytest.fail("Synchronizer test timed out after 5 seconds") @@ -240,7 +242,7 @@ def test_initializer_fetches_flag_data(): td.update(td.flag('some-flag').variation_for_all(True)) initializer = td.build_initializer(Config(sdk_key="dummy")) - result = initializer.fetch() + result = initializer.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Success) basis = result.value @@ -261,7 +263,7 @@ def test_synchronizer_yields_initial_data(): synchronizer = td.build_synchronizer(Config(sdk_key="dummy")) - update_iter = iter(synchronizer.sync()) + update_iter = iter(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) initial_update = next(update_iter) assert initial_update.state == DataSourceState.VALID @@ -272,7 +274,7 @@ def test_synchronizer_yields_initial_data(): change = initial_update.change_set.changes[0] assert change.key == 'initial-flag' - synchronizer.close() + synchronizer.stop() def test_synchronizer_receives_updates(): @@ -285,12 +287,12 @@ def test_synchronizer_receives_updates(): def collect_updates(): nonlocal update_count - for update in synchronizer.sync(): + for update in synchronizer.sync(MockSelectorStore(Selector.no_selector())): updates.append(update) update_count += 1 if update_count >= 2: - synchronizer.close() + synchronizer.stop() break # Start the synchronizer in a thread @@ -329,17 +331,17 @@ def test_multiple_synchronizers_receive_updates(): updates2 = [] def collect_updates_1(): - for update in sync1.sync(): + for update in sync1.sync(MockSelectorStore(Selector.no_selector())): updates1.append(update) if len(updates1) >= 2: - sync1.close() + sync1.stop() break def collect_updates_2(): - for update in sync2.sync(): + for update in sync2.sync(MockSelectorStore(Selector.no_selector())): updates2.append(update) if len(updates2) >= 2: - sync2.close() + sync2.stop() break # Start both synchronizers @@ -373,9 +375,9 @@ def test_closed_synchronizer_stops_yielding(): updates = [] # Get initial update then close - for update in synchronizer.sync(): + for update in synchronizer.sync(MockSelectorStore(Selector.no_selector())): updates.append(update) - synchronizer.close() + synchronizer.stop() break assert len(updates) == 1 @@ -385,7 +387,7 @@ def test_closed_synchronizer_stops_yielding(): # Try to get more updates - should get an error state indicating closure additional_updates = [] - for update in synchronizer.sync(): + for update in synchronizer.sync(MockSelectorStore(Selector.no_selector())): additional_updates.append(update) break @@ -401,11 +403,12 @@ def test_initializer_can_sync(): td.update(td.flag('test-flag').variation_for_all(True)) initializer = td.build_initializer(Config(sdk_key="dummy")) - sync_gen = initializer.sync() + sync_gen = initializer.sync(MockSelectorStore(Selector.no_selector())) # Should get initial update with data initial_update = next(sync_gen) assert initial_update.state == DataSourceState.VALID + assert initial_update.change_set is not None assert initial_update.change_set.intent_code == IntentCode.TRANSFER_FULL assert len(initial_update.change_set.changes) == 1 assert initial_update.change_set.changes[0].key == 'test-flag' @@ -442,8 +445,8 @@ def test_error_handling_in_fetch(): initializer = td.build_initializer(Config(sdk_key="dummy")) # Close the initializer to trigger error condition - initializer.close() + initializer.stop() - result = initializer.fetch() + result = initializer.fetch(MockSelectorStore(Selector.no_selector())) assert isinstance(result, _Fail) assert "TestDataV2 source has been closed" in result.error diff --git a/ldclient/testing/mock_components.py b/ldclient/testing/mock_components.py index 44d3f78a..f1b20235 100644 --- a/ldclient/testing/mock_components.py +++ b/ldclient/testing/mock_components.py @@ -1,6 +1,7 @@ import time from typing import Callable +from ldclient.impl.datasystem.protocolv2 import Selector from ldclient.interfaces import BigSegmentStore, BigSegmentStoreMetadata @@ -42,3 +43,11 @@ def membership_queries(self) -> list: def __fail(self): raise Exception("deliberate error") + + +class MockSelectorStore(): + def __init__(self, selector: Selector): + self._selector = selector + + def selector(self) -> Selector: + return self._selector