diff --git a/src/retriever/data_tiers/tier_1/elasticsearch/driver.py b/src/retriever/data_tiers/tier_1/elasticsearch/driver.py index 4c38bb98..ad43f493 100644 --- a/src/retriever/data_tiers/tier_1/elasticsearch/driver.py +++ b/src/retriever/data_tiers/tier_1/elasticsearch/driver.py @@ -250,10 +250,11 @@ async def legacy_get_operations( return operations, nodes @override - async def get_metadata(self) -> dict[str, Any] | None: + async def get_metadata(self, bypass_cache: bool = False) -> dict[str, Any] | None: return await get_t1_metadata( es_connection=self.es_connection, indices_alias=CONFIG.tier1.elasticsearch.index_name, + bypass_cache=bypass_cache, ) @override @@ -263,11 +264,11 @@ async def get_operations( # return await self.legacy_get_metadata() return await self.get_t1_operations() - async def get_t1_operations( - self, - ) -> tuple[list[Operation], dict[BiolinkEntity, OperationNode]]: - """Get tier1 operations based on metadata.""" - metadata_blob = await self.get_metadata() + async def get_valid_metadata( + self, bypass_cache: bool = False + ) -> tuple[dict[str, Any], list[str]]: + """Get valid metadata and raise exception if failed.""" + metadata_blob = await self.get_metadata(bypass_cache) if metadata_blob is None: raise ValueError( @@ -280,8 +281,24 @@ async def get_t1_operations( indices = await get_t1_indices(self.es_connection) - metadata_list = extract_metadata_entries_from_blob(metadata_blob, indices) + # ensure metadata matches indices + mismatched = any(metadata_blob.get(i) is None for i in indices) + + if mismatched: + if not bypass_cache: + log.error("Possibly stale data got from cache. Refetching remotely.") + return await self.get_valid_metadata(bypass_cache=True) + else: + raise ValueError("Invalid metadata retrieved.") + return metadata_blob, indices + + async def get_t1_operations( + self, + ) -> tuple[list[Operation], dict[BiolinkEntity, OperationNode]]: + """Get tier1 operations based on metadata.""" + metadata_blob, indices = await self.get_valid_metadata() + metadata_list = extract_metadata_entries_from_blob(metadata_blob, indices) operations, nodes = await generate_operations(metadata_list) return operations, nodes
diff --git a/src/retriever/data_tiers/tier_1/elasticsearch/meta.py b/src/retriever/data_tiers/tier_1/elasticsearch/meta.py index c2897cad..ac843344 100644 --- a/src/retriever/data_tiers/tier_1/elasticsearch/meta.py +++ b/src/retriever/data_tiers/tier_1/elasticsearch/meta.py @@ -115,10 +115,13 @@ async def retrieve_metadata_from_es( async def get_t1_metadata( - es_connection: AsyncElasticsearch | None, indices_alias: str, retries: int = 0 + es_connection: AsyncElasticsearch | None, + indices_alias: str, + bypass_cache: bool, + retries: int = 0, ) -> T1MetaData | None: """Caller to orchestrate retrieving t1 metadata.""" - meta_blob = await read_metadata_cache(CACHE_KEY) + meta_blob = None if bypass_cache else await read_metadata_cache(CACHE_KEY) if not meta_blob: try: if es_connection is None: @@ -133,7 +136,9 @@ async def get_t1_metadata( "Invalid Elasticsearch connection" ): return None - return await get_t1_metadata(es_connection, indices_alias, retries + 1) + return await get_t1_metadata( + es_connection, indices_alias, bypass_cache=True, retries=retries + 1 + ) log.success("DINGO Metadata retrieved!") return meta_blob
diff --git a/tests/data_tiers/tier_1/elasticsearch_tests/test_tier1_driver.py b/tests/data_tiers/tier_1/elasticsearch_tests/test_tier1_driver.py index 80999deb..1a6813b4 100644 --- a/tests/data_tiers/tier_1/elasticsearch_tests/test_tier1_driver.py +++ b/tests/data_tiers/tier_1/elasticsearch_tests/test_tier1_driver.py @@ -91,12 +91,12 @@ def mock_elasticsearch_config(monkeypatch: pytest.MonkeyPatch) -> Iterator[None] @pytest.mark.parametrize( "payload, expected", [ - (PAYLOAD_0, 1), + (PAYLOAD_0, 0), (PAYLOAD_1, 4), (PAYLOAD_2, 32), ( [PAYLOAD_0, PAYLOAD_1, PAYLOAD_2], - [1, 4, 32] + [0, 4, 32] ) ], ids=[ @@ -117,6 +117,11 @@ async def test_elasticsearch_driver(payload: ESPayload | list[ESPayload], expect hits: list[ESEdge] | list[ESEdge] = await driver.run_query(payload) def assert_single_result(res, expected_result_num: int): + if res is None: + if expected_result_num != 0: + raise AssertionError(f"Expected empty result, got {type(res)}") + else: + return if not isinstance(res, list): raise AssertionError(f"Expected results to be list, got {type(res)}") if not len(res) == expected_result_num: @@ -201,7 +206,7 @@ async def test_metadata_retrieval(): "qgraph, expected_hits", [ (DINGO_QGRAPH, 8), - (ID_BYPASS_PAYLOAD, 6395), # <-- adjust to the real number + (ID_BYPASS_PAYLOAD, 6776), # <-- adjust to the real number ], ) async def test_end_to_end(qgraph, expected_hits): @@ -244,4 +249,4 @@ async def test_ubergraph_info_retrieval(): # print(k, v) assert "mapping" in info - assert len(info["mapping"]) == 581143 + assert len(info["mapping"]) == 122707