From 92e4c74cd50a955f9832067411c080f8f36d8ef5 Mon Sep 17 00:00:00 2001 From: Ryan Michael Date: Thu, 15 Aug 2024 16:25:27 -0400 Subject: [PATCH 1/5] WIP --- .../ragstack_knowledge_store/graph_store.py | 29 ++++++++++++++----- .../integration_tests/test_graph_store.py | 6 ++++ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 02110a62e..068d7ee0f 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -367,6 +367,7 @@ def mmr_traversal_search( lambda_mult: float = 0.5, score_threshold: float = float("-inf"), metadata_filter: dict[str, Any] = {}, # noqa: B006 + tag_filter: set[tuple[str, str]], ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -398,6 +399,7 @@ def mmr_traversal_search( score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to -infinity. metadata_filter: Optional metadata to filter the results. + tag_filter: Optional tags to filter graph edges to be traversed. """ query_embedding = self._embedding.embed_query(query) helper = MmrHelper( @@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + if tag_filter.len() == 0: + outgoing_tags[adjacent.target_content_id] = ( + adjacent.target_link_to_tags + ) + else: + outgoing_tags[adjacent.target_content_id] = ( + tag_filter.intersection(adjacent.target_link_to_tags) + ) new_candidates[adjacent.target_content_id] = ( adjacent.target_text_embedding @@ -474,7 +481,10 @@ def fetch_initial_candidates() -> None: for row in fetched: if row.content_id not in outgoing_tags: candidates[row.content_id] = row.text_embedding - outgoing_tags[row.content_id] = set(row.link_to_tags or []) + if tag_filter.len() == 0: + outgoing_tags[row.content_id] = set(row.link_to_tags or []) + else: + outgoing_tags[row.content_id] = tag_filter.intersection(set(row.link_to_tags or [])) helper.add_candidates(candidates) if initial_roots: @@ -522,9 +532,14 @@ def fetch_initial_candidates() -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + if tag_filter.len() == 0: + outgoing_tags[adjacent.target_content_id] = ( + adjacent.target_link_to_tags + ) + else: + outgoing_tags[adjacent.target_content_id] = ( + tag_filter.intersection(adjacent.target_link_to_tags) + ) new_candidates[adjacent.target_content_id] = ( adjacent.target_text_embedding ) diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 17d5a7a77..84b66abce 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -211,6 +211,12 @@ def test_mmr_traversal( results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"]) assert _result_ids(results) == ["v1", "v3", "v2"] + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("explicit", "link"))) + assert _result_ids(results) == ["v0", "v2"] + + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("no", "match"))) + assert _result_ids(results) == [] + def test_write_retrieve_keywords( graph_store_factory: Callable[[MetadataIndexingType], GraphStore], From 6ab3d7de5768e6d0ebd039793cef2a2f207b18d7 Mon Sep 17 00:00:00 2001 From: Ryan Michael Date: Thu, 15 Aug 2024 16:58:35 -0400 Subject: [PATCH 2/5] lint --- .../ragstack_knowledge_store/graph_store.py | 4 +++- .../tests/integration_tests/test_graph_store.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 068d7ee0f..047984567 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -484,7 +484,9 @@ def fetch_initial_candidates() -> None: if tag_filter.len() == 0: outgoing_tags[row.content_id] = set(row.link_to_tags or []) else: - outgoing_tags[row.content_id] = tag_filter.intersection(set(row.link_to_tags or [])) + outgoing_tags[row.content_id] = tag_filter.intersection( + set(row.link_to_tags or []) + ) helper.add_candidates(candidates) if initial_roots: diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 84b66abce..7cf4efd18 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -211,10 +211,14 @@ def test_mmr_traversal( results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"]) assert _result_ids(results) == ["v1", "v3", "v2"] - results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("explicit", "link"))) + results = gs.mmr_traversal_search( + "0.0", k=2, fetch_k=2, tag_filter={("explicit", "link")} + ) assert _result_ids(results) == ["v0", "v2"] - results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("no", "match"))) + results = gs.mmr_traversal_search( + "0.0", k=2, fetch_k=2, tag_filter={("no", "match")} + ) assert _result_ids(results) == [] From 192e8e4e21485ac51f073f9bcb9be2166bacb2f7 Mon Sep 17 00:00:00 2001 From: Ryan Michael Date: Fri, 16 Aug 2024 10:32:46 -0400 Subject: [PATCH 3/5] WIP --- .../ragstack_knowledge_store/graph_store.py | 8 +++++++- .../tests/integration_tests/test_graph_store.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 047984567..d3fb369c7 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -570,6 +570,7 @@ def traversal_search( k: int = 4, depth: int = 1, metadata_filter: dict[str, Any] = {}, # noqa: B006 + tag_filter: set[tuple[str, str]], ) -> Iterable[Node]: """Retrieve documents from this knowledge store. @@ -583,6 +584,7 @@ def traversal_search( Defaults to 4. depth: The maximum depth of edges to traverse. Defaults to 1. metadata_filter: Optional metadata to filter the results. + tag_filter: Optional tags to filter graph edges to be traversed. Returns: Collection of retrieved documents. @@ -647,7 +649,11 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None: # given depth, so we don't fetch it again # (unless we find it an earlier depth) visited_tags[(kind, value)] = d - outgoing_tags.add((kind, value)) + if ( + tag_filter.len() == 0 + or (kind, value) in tag_filter + ): + outgoing_tags.add((kind, value)) if outgoing_tags: # If there are new tags to visit at the next depth, query for the diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 7cf4efd18..6bb8ba506 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -292,6 +292,14 @@ def test_write_retrieve_keywords( results = gs.traversal_search("Earth", k=1, depth=1) assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"} + results = gs.traversal_search( + "Earth", k=1, depth=1, tag_filter={("parent", "parent")} + ) + assert set(_result_ids(results)) == {"doc2", "greetings"} + + results = gs.traversal_search("Earth", k=1, depth=1, tag_filter={("no", "match")}) + assert _result_ids(results) == [] + def test_metadata( graph_store_factory: Callable[[MetadataIndexingType], GraphStore], From 12a857e17102fd0e242fbf8991cbe8bb17845000 Mon Sep 17 00:00:00 2001 From: Ryan Michael Date: Fri, 16 Aug 2024 13:22:03 -0400 Subject: [PATCH 4/5] default --- libs/knowledge-store/ragstack_knowledge_store/graph_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index d3fb369c7..2bc8fc0a3 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -367,7 +367,7 @@ def mmr_traversal_search( lambda_mult: float = 0.5, score_threshold: float = float("-inf"), metadata_filter: dict[str, Any] = {}, # noqa: B006 - tag_filter: set[tuple[str, str]], + tag_filter: set[tuple[str, str]] = {}, ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -570,7 +570,7 @@ def traversal_search( k: int = 4, depth: int = 1, metadata_filter: dict[str, Any] = {}, # noqa: B006 - tag_filter: set[tuple[str, str]], + tag_filter: set[tuple[str, str]] = {}, ) -> Iterable[Node]: """Retrieve documents from this knowledge store. From c857513c242242321c58d9aaf33448eb60aa9f8c Mon Sep 17 00:00:00 2001 From: Ryan Michael Date: Fri, 16 Aug 2024 13:27:26 -0400 Subject: [PATCH 5/5] fix len --- .../ragstack_knowledge_store/graph_store.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 2bc8fc0a3..0e391214c 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -446,7 +446,7 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - if tag_filter.len() == 0: + if len(tag_filter) == 0: outgoing_tags[adjacent.target_content_id] = ( adjacent.target_link_to_tags ) @@ -481,7 +481,7 @@ def fetch_initial_candidates() -> None: for row in fetched: if row.content_id not in outgoing_tags: candidates[row.content_id] = row.text_embedding - if tag_filter.len() == 0: + if len(tag_filter) == 0: outgoing_tags[row.content_id] = set(row.link_to_tags or []) else: outgoing_tags[row.content_id] = tag_filter.intersection( @@ -534,7 +534,7 @@ def fetch_initial_candidates() -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - if tag_filter.len() == 0: + if len(tag_filter) == 0: outgoing_tags[adjacent.target_content_id] = ( adjacent.target_link_to_tags ) @@ -650,7 +650,7 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None: # (unless we find it an earlier depth) visited_tags[(kind, value)] = d if ( - tag_filter.len() == 0 + len(tag_filter) == 0 or (kind, value) in tag_filter ): outgoing_tags.add((kind, value))