Skip to content

Commit 7d65d5e

Browse files
authored
Harden search filters against Cypher injection (#1312)
* harden search filter inputs
* validate entity node labels on save
* tighten security regression coverage
1 parent b10b488 commit 7d65d5e

File tree

11 files changed

+234
-7
lines changed

11 files changed

+234
-7
lines changed

graphiti_core/driver/falkordb_driver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from graphiti_core.driver.operations.saga_node_ops import SagaNodeOperations
6767
from graphiti_core.driver.operations.search_ops import SearchOperations
6868
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
69+
from graphiti_core.helpers import validate_group_ids
6970
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
7071

7172
logger = logging.getLogger(__name__)
@@ -397,6 +398,8 @@ def build_fulltext_query(
397398
- AND is implicit with space: (@group_id:value) (text)
398399
- OR uses pipe within parentheses: (@group_id:value1|value2)
399400
"""
401+
validate_group_ids(group_ids)
402+
400403
if group_ids is None or len(group_ids) == 0:
401404
group_filter = ''
402405
else:

graphiti_core/driver/neo4j/operations/search_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
get_relationships_query,
3333
get_vector_cosine_func_query,
3434
)
35-
from graphiti_core.helpers import lucene_sanitize
35+
from graphiti_core.helpers import lucene_sanitize, validate_group_ids
3636
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
3737
from graphiti_core.models.nodes.node_db_queries import (
3838
COMMUNITY_NODE_RETURN,
@@ -56,6 +56,8 @@ def _build_neo4j_fulltext_query(
5656
group_ids: list[str] | None = None,
5757
max_query_length: int = MAX_QUERY_LENGTH,
5858
) -> str:
59+
validate_group_ids(group_ids)
60+
5961
group_ids_filter_list = [f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
6062
group_ids_filter = ''
6163
for f in group_ids_filter_list:

graphiti_core/errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,15 @@ class GroupIdValidationError(GraphitiError):
8181
def __init__(self, group_id: str):
8282
self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores'
8383
super().__init__(self.message)
84+
85+
86+
class NodeLabelValidationError(GraphitiError, ValueError):
    """Signal that one or more node labels are not acceptable identifiers.

    The error message lists every rejected label, each wrapped in double
    quotes and separated by commas. The same text is kept on ``self.message``.

    NOTE(review): the ``ValueError`` base presumably lets pydantic field
    validators surface this as a ``ValidationError`` -- confirm.
    """

    def __init__(self, node_labels: list[str]):
        quoted = [f'"{name}"' for name in node_labels]
        rejected = ', '.join(quoted)
        self.message = (
            'node_labels must start with a letter or underscore and contain only '
            'alphanumeric characters or underscores: ' + rejected
        )
        super().__init__(self.message)

graphiti_core/helpers.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
from pydantic import BaseModel
2929

3030
from graphiti_core.driver.driver import GraphProvider
31-
from graphiti_core.errors import GroupIdValidationError
31+
from graphiti_core.errors import GroupIdValidationError, NodeLabelValidationError
3232

3333
load_dotenv()
3434

35+
SAFE_CYPHER_IDENTIFIER_PATTERN = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
36+
3537
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
3638
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
3739
DEFAULT_PAGE_LIMIT = 20
@@ -157,6 +159,33 @@ def validate_group_id(group_id: str | None) -> bool:
157159
return True
158160

159161

162+
def validate_group_ids(group_ids: list[str] | None) -> bool:
    """Run ``validate_group_id`` over each entry of a search path's group ids.

    ``None`` means no group filter was supplied and is accepted without any
    checks. Otherwise every id is passed to ``validate_group_id`` in order;
    any exception it raises propagates to the caller.

    Returns:
        ``True`` whenever the function returns normally.
    """
    if group_ids is not None:
        for candidate in group_ids:
            validate_group_id(candidate)

    return True
172+
173+
174+
def validate_node_labels(node_labels: list[str] | None) -> bool:
    """Validate that node labels are safe to interpolate into Cypher label expressions.

    Args:
        node_labels: Labels to check. ``None`` or an empty list is accepted
            without further checks.

    Returns:
        ``True`` when every label is a string that matches
        ``SAFE_CYPHER_IDENTIFIER_PATTERN`` in its entirety.

    Raises:
        NodeLabelValidationError: If at least one label is rejected. The
            error is constructed with all offending labels, not just the first.
    """
    if not node_labels:
        return True

    # fullmatch(), not match(): with match() the pattern's trailing `$` also
    # succeeds just before a final newline, so a label such as 'Entity\n'
    # would slip through. Non-string entries are rejected here too, so callers
    # that bypassed model validation get the validation error, not a TypeError.
    invalid_labels = [
        label
        for label in node_labels
        if not isinstance(label, str) or not SAFE_CYPHER_IDENTIFIER_PATTERN.fullmatch(label)
    ]
    if invalid_labels:
        raise NodeLabelValidationError(invalid_labels)

    return True
187+
188+
160189
def validate_excluded_entity_types(
161190
excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
162191
) -> bool:

graphiti_core/models/nodes/node_db_queries.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
from typing import Any
1818

1919
from graphiti_core.driver.driver import GraphProvider
20+
from graphiti_core.helpers import validate_node_labels
21+
22+
23+
def _validate_entity_labels(labels: str | list[str]) -> list[str]:
    """Normalise entity labels to a list and validate them.

    Accepts either a colon-joined string (for example ``'Entity:Person'``) or
    a list of labels. Empty entries are dropped before the remaining labels
    are handed to ``validate_node_labels``, whose exceptions propagate.

    Returns:
        The non-empty labels, in their original order.
    """
    if isinstance(labels, str):
        candidates = labels.split(':')
    else:
        candidates = labels

    non_empty = list(filter(None, candidates))
    validate_node_labels(non_empty)
    return non_empty
2028

2129

2230
def get_episode_node_save_query(provider: GraphProvider) -> str:
@@ -127,6 +135,9 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
127135

128136

129137
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
138+
validated_labels = _validate_entity_labels(labels)
139+
labels = ':'.join(validated_labels)
140+
130141
match provider:
131142
case GraphProvider.FALKORDB:
132143
return f"""
@@ -152,7 +163,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
152163
"""
153164
case GraphProvider.NEPTUNE:
154165
label_subquery = ''
155-
for label in labels.split(':'):
166+
for label in validated_labels:
156167
label_subquery += f' SET n:{label}\n'
157168
return f"""
158169
MERGE (n:Entity {{uuid: $entity_data.uuid}})
@@ -183,6 +194,9 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
183194
def get_entity_node_save_bulk_query(
184195
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
185196
) -> str | Any:
197+
for node in nodes:
198+
_validate_entity_labels(node.get('labels', []))
199+
186200
match provider:
187201
case GraphProvider.FALKORDB:
188202
queries = []

graphiti_core/nodes.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import Any
2424
from uuid import uuid4
2525

26-
from pydantic import BaseModel, Field
26+
from pydantic import BaseModel, ConfigDict, Field, field_validator
2727
from typing_extensions import LiteralString
2828

2929
from graphiti_core.driver.driver import (
@@ -32,7 +32,7 @@
3232
)
3333
from graphiti_core.embedder import EmbedderClient
3434
from graphiti_core.errors import NodeNotFoundError
35-
from graphiti_core.helpers import parse_db_date
35+
from graphiti_core.helpers import parse_db_date, validate_node_labels
3636
from graphiti_core.models.nodes.node_db_queries import (
3737
COMMUNITY_NODE_RETURN,
3838
COMMUNITY_NODE_RETURN_NEPTUNE,
@@ -94,6 +94,14 @@ class Node(BaseModel, ABC):
9494
labels: list[str] = Field(default_factory=list)
9595
created_at: datetime = Field(default_factory=lambda: utc_now())
9696

97+
model_config = ConfigDict(validate_assignment=True)
98+
99+
@field_validator('labels')
@classmethod
def validate_labels(cls, value: list[str]) -> list[str]:
    """Check the ``labels`` field with ``validate_node_labels``.

    The value is returned unchanged when the check passes; any exception
    raised by ``validate_node_labels`` propagates.
    """
    validate_node_labels(value)
    return value
104+
97105
@abstractmethod
98106
async def save(self, driver: GraphDriver): ...
99107

graphiti_core/search/search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from graphiti_core.embedder.client import EMBEDDING_DIM
2525
from graphiti_core.errors import SearchRerankerError
2626
from graphiti_core.graphiti_types import GraphitiClients
27-
from graphiti_core.helpers import semaphore_gather
27+
from graphiti_core.helpers import semaphore_gather, validate_group_ids
2828
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
2929
from graphiti_core.search.search_config import (
3030
DEFAULT_SEARCH_LIMIT,
@@ -77,6 +77,7 @@ async def search(
7777
driver: GraphDriver | None = None,
7878
) -> SearchResults:
7979
start = time()
80+
validate_group_ids(group_ids)
8081

8182
driver = driver or clients.driver
8283
embedder = clients.embedder

graphiti_core/search/search_filters.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
from enum import Enum
1919
from typing import Any
2020

21-
from pydantic import BaseModel, Field
21+
from pydantic import BaseModel, Field, field_validator
2222

2323
from graphiti_core.driver.driver import GraphProvider
24+
from graphiti_core.helpers import validate_node_labels
2425

2526

2627
class ComparisonOperator(Enum):
@@ -65,6 +66,12 @@ class SearchFilters(BaseModel):
6566
edge_uuids: list[str] | None = Field(default=None)
6667
property_filters: list[PropertyFilter] | None = Field(default=None)
6768

69+
@field_validator('node_labels')
@classmethod
def validate_node_label_filters(cls, value: list[str] | None) -> list[str] | None:
    """Check the ``node_labels`` filter with ``validate_node_labels``.

    The value (including ``None``) is returned unchanged when the check
    passes; any exception raised by ``validate_node_labels`` propagates.
    """
    validate_node_labels(value)
    return value
74+
6875

6976
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
7077
mapping = {
@@ -84,6 +91,8 @@ def node_search_filter_query_constructor(
8491
filter_params: dict[str, Any] = {}
8592

8693
if filters.node_labels is not None:
94+
# Defense-in-depth for model_construct()/other validation bypasses.
95+
validate_node_labels(filters.node_labels)
8796
if provider == GraphProvider.KUZU:
8897
node_label_filter = 'list_has_all(n.labels, $labels)'
8998
filter_params['labels'] = filters.node_labels
@@ -125,6 +134,8 @@ def edge_search_filter_query_constructor(
125134
filter_params['edge_uuids'] = filters.edge_uuids
126135

127136
if filters.node_labels is not None:
137+
# Defense-in-depth for model_construct()/other validation bypasses.
138+
validate_node_labels(filters.node_labels)
128139
if provider == GraphProvider.KUZU:
129140
node_label_filter = (
130141
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'

graphiti_core/search/search_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
lucene_sanitize,
3838
normalize_l2,
3939
semaphore_gather,
40+
validate_group_ids,
4041
)
4142
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
4243
from graphiti_core.models.nodes.node_db_queries import (
@@ -82,6 +83,8 @@ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> f
8283

8384

8485
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
86+
validate_group_ids(group_ids)
87+
8588
if driver.provider == GraphProvider.KUZU:
8689
# Kuzu only supports simple queries.
8790
if len(query.split(' ')) > MAX_QUERY_LENGTH:

tests/test_node_label_security.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
4+
from graphiti_core.driver.driver import GraphProvider
5+
from graphiti_core.errors import NodeLabelValidationError
6+
from graphiti_core.models.nodes.node_db_queries import (
7+
get_entity_node_save_bulk_query,
8+
get_entity_node_save_query,
9+
)
10+
from graphiti_core.nodes import EntityNode
11+
12+
13+
def test_entity_node_rejects_unsafe_labels():
    """Constructing an EntityNode with an injection payload as a label must fail."""
    injected_label = 'Entity`) WITH n MATCH (x) DETACH DELETE x //'
    expected_message = 'node_labels must start with a letter or underscore'

    with pytest.raises(ValidationError, match=expected_message):
        EntityNode(name='Alice', group_id='group', labels=[injected_label])
20+
21+
22+
def test_entity_node_assignment_rejects_unsafe_labels():
    """Assigning an injection payload to ``labels`` on an existing node must fail."""
    injected_label = 'Entity`) WITH n MATCH (x) DETACH DELETE x //'
    expected_message = 'node_labels must start with a letter or underscore'

    # Start from a node that was built with a harmless label.
    node = EntityNode(name='Alice', group_id='group', labels=['Person'])

    with pytest.raises(ValidationError, match=expected_message):
        node.labels = [injected_label]
27+
28+
29+
def test_entity_node_save_query_rejects_unsafe_labels_when_validation_is_bypassed():
    """The single-node save query builder must reject an unsafe label string itself."""
    # Colon-joined label string whose second segment is an injection payload.
    joined_labels = 'Entity:Entity`) WITH n MATCH (x) DETACH DELETE x //'
    expected_message = 'node_labels must start with a letter or underscore'

    with pytest.raises(NodeLabelValidationError, match=expected_message):
        get_entity_node_save_query(GraphProvider.NEO4J, joined_labels)
37+
38+
39+
def test_entity_node_save_bulk_query_rejects_unsafe_labels_when_validation_is_bypassed():
    """The bulk save query builder must reject a raw node dict with an unsafe label."""
    expected_message = 'node_labels must start with a letter or underscore'

    # Plain dict payload, so no model-level validation has run on the labels.
    unsafe_node = {
        'uuid': 'node-1',
        'name': 'Alice',
        'group_id': 'group',
        'summary': 'summary',
        'created_at': '2024-01-01T00:00:00Z',
        'name_embedding': [0.1, 0.2],
        'labels': ['Entity', 'Entity`) WITH n MATCH (x) DETACH DELETE x //'],
    }

    with pytest.raises(NodeLabelValidationError, match=expected_message):
        get_entity_node_save_bulk_query(GraphProvider.FALKORDB, [unsafe_node])

0 commit comments

Comments
 (0)