From 1b9c9bf160e6d467b487db2f37c997ea6710840b Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 17 Jul 2025 14:39:04 +0100 Subject: [PATCH 1/3] Removes | char from type hints to be compatible with Python 3.9 --- .../components/entity_relation_extractor.py | 2 +- .../experimental/pipeline/component.py | 8 +++--- .../experimental/pipeline/config/base.py | 4 +-- .../pipeline/config/object_config.py | 25 ++++++------------- .../experimental/pipeline/config/runner.py | 12 ++++++--- .../experimental/pipeline/orchestrator.py | 6 ++--- .../experimental/pipeline/pipeline.py | 21 ++++++++-------- .../experimental/pipeline/pipeline_graph.py | 4 +-- src/neo4j_graphrag/generation/graphrag.py | 4 +-- src/neo4j_graphrag/generation/types.py | 2 +- .../components/test_graph_pruning.py | 14 +++++------ 11 files changed, 48 insertions(+), 54 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index fd5fb276e..60bc2b647 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -187,7 +187,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor): def __init__( self, llm: LLMInterface, - prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(), + prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(), create_lexical_graph: bool = True, on_error: OnError = OnError.RAISE, max_concurrency: int = 5, diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 39a2816ef..ebbf7e36c 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -15,12 +15,12 @@ from __future__ import annotations import inspect -from typing import Any, get_type_hints +from typing import Any, Union, get_type_hints from pydantic import BaseModel -from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.utils.validation import issubclass_safe @@ -80,8 +80,8 @@ class Component(metaclass=ComponentMeta): # these variables are filled by the metaclass # added here for the type checker # DO NOT CHANGE - component_inputs: dict[str, dict[str, str | bool]] - component_outputs: dict[str, dict[str, str | bool | type]] + component_inputs: dict[str, dict[str, Union[str, bool]]] + component_outputs: dict[str, dict[str, Union[str, bool, type]]] async def run(self, *args: Any, **kwargs: Any) -> DataModel: """Run the component and return its result. diff --git a/src/neo4j_graphrag/experimental/pipeline/config/base.py b/src/neo4j_graphrag/experimental/pipeline/config/base.py index 665a56d0f..4a287a75d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/base.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/base.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, Union from pydantic import BaseModel, PrivateAttr @@ -58,5 +58,5 @@ def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]: for param_name, param in params.items() } - def parse(self, resolved_data: dict[str, Any] | None = None) -> Any: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Any: raise NotImplementedError() diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py index 95d69888d..4fd9571c7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py @@ -31,15 +31,7 @@ import importlib import logging -from typing import ( - Any, - ClassVar, - Generic, - Optional, - TypeVar, - Union, - cast, -) +from typing import Any, ClassVar, Generic, Optional, TypeVar, Union, cast import neo4j from pydantic import ( @@ -58,7 +50,6 @@ from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.utils.validation import issubclass_safe - logger = logging.getLogger(__name__) @@ -73,7 +64,7 @@ class ObjectConfig(AbstractConfig, Generic[T]): and its constructor parameters. """ - class_: str | None = Field(default=None, validate_default=True) + class_: Union[str, None] = Field(default=None, validate_default=True) """Path to class to be instantiated.""" params_: dict[str, ParamConfig] = {} """Initialization parameters.""" @@ -128,7 +119,7 @@ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> t raise ValueError(f"Could not find {class_name} in {module_name}") return cast(type, klass) - def parse(self, resolved_data: dict[str, Any] | None = None) -> T: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> T: """Import `class_`, resolve `params_` and instantiate object.""" self._global_data = resolved_data or {} logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}") @@ -162,7 +153,7 @@ def validate_class(cls, class_: Any) -> str: # not used return "not used" - def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> neo4j.Driver: params = self.resolve_params(self.params_) # we know these params are there because of the required params validator uri = params.pop("uri") @@ -185,7 +176,7 @@ class Neo4jDriverType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> neo4j.Driver: if isinstance(self.root, neo4j.Driver): return self.root # self.root is a Neo4jDriverConfig object @@ -212,7 +203,7 @@ class LLMType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> LLMInterface: if isinstance(self.root, LLMInterface): return self.root return self.root.parse(resolved_data) @@ -238,7 +229,7 @@ class EmbedderType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Embedder: if isinstance(self.root, Embedder): return self.root return self.root.parse(resolved_data) @@ -266,7 +257,7 @@ class ComponentType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: dict[str, Any] | None = None) -> Component: + def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Component: if isinstance(self.root, Component): return self.root return self.root.parse(resolved_data) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index fb0544bce..da9d8c598 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -37,7 +37,6 @@ from typing_extensions import Self from neo4j_graphrag.experimental.pipeline import Pipeline -from neo4j_graphrag.utils.file_handler import FileHandler from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( AbstractPipelineConfig, PipelineConfig, @@ -48,6 +47,7 @@ from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.experimental.pipeline.types.definitions import PipelineDefinition +from neo4j_graphrag.utils.file_handler import FileHandler from neo4j_graphrag.utils.logging import prettify logger = logging.getLogger(__name__) @@ -70,7 +70,9 @@ class PipelineConfigWrapper(BaseModel): Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], ] = Field(discriminator=Discriminator(_get_discriminator_value)) - def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: + def parse( + self, resolved_data: Union[dict[str, Any], None] = None + ) -> PipelineDefinition: logger.debug("PIPELINE_CONFIG: start parsing config...") return self.config.parse(resolved_data) @@ -90,7 +92,7 @@ class PipelineRunner: def __init__( self, pipeline_definition: PipelineDefinition, - config: AbstractPipelineConfig | None = None, + config: Union[AbstractPipelineConfig, None] = None, do_cleaning: bool = False, ) -> None: self.config = config @@ -100,7 +102,9 @@ def __init__( @classmethod def from_config( - cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False + cls, + config: Union[AbstractPipelineConfig, dict[str, Any]], + do_cleaning: bool = False, ) -> Self: wrapper = PipelineConfigWrapper.model_validate({"config": config}) logger.debug( diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index de9468bf7..8917da187 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -19,15 +19,15 @@ import uuid import warnings from functools import partial -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, Union -from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.exceptions import ( PipelineDefinitionError, PipelineMissingDependencyError, PipelineStatusUpdateError, ) from neo4j_graphrag.experimental.pipeline.notification import EventNotifier +from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.types.orchestration import ( RunResult, RunStatus, @@ -235,7 +235,7 @@ async def get_component_inputs( return component_inputs async def add_result_for_component( - self, name: str, result: dict[str, Any] | None, is_final: bool = False + self, name: str, result: Union[dict[str, Any], None], is_final: bool = False ) -> None: """This is where we save the results in the result store and, optionally, in the final result store. diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 91dab34ee..957355869 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -14,12 +14,12 @@ # limitations under the License. from __future__ import annotations +import asyncio import logging import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, Optional, AsyncGenerator -import asyncio +from typing import Any, AsyncGenerator, Optional, Union from neo4j_graphrag.utils.logging import prettify @@ -36,6 +36,12 @@ from neo4j_graphrag.experimental.pipeline.exceptions import ( PipelineDefinitionError, ) +from neo4j_graphrag.experimental.pipeline.notification import ( + Event, + EventCallbackProtocol, + EventType, + PipelineEvent, +) from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator from neo4j_graphrag.experimental.pipeline.pipeline_graph import ( PipelineEdge, @@ -43,20 +49,13 @@ PipelineNode, ) from neo4j_graphrag.experimental.pipeline.stores import InMemoryStore, ResultStore +from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.types.definitions import ( ComponentDefinition, ConnectionDefinition, PipelineDefinition, ) from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult -from neo4j_graphrag.experimental.pipeline.types.context import RunContext -from neo4j_graphrag.experimental.pipeline.notification import ( - EventCallbackProtocol, - Event, - PipelineEvent, - EventType, -) - logger = logging.getLogger(__name__) @@ -79,7 +78,7 @@ def __init__(self, name: str, component: Component): async def execute( self, context: RunContext, inputs: dict[str, Any] - ) -> RunResult | None: + ) -> Union[RunResult, None]: """Execute the task Returns: diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline_graph.py b/src/neo4j_graphrag/experimental/pipeline/pipeline_graph.py index 37fb2072b..591e70053 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline_graph.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline_graph.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar, Union class PipelineNode: @@ -124,7 +124,7 @@ def previous_edges(self, node: str) -> list[GenericEdgeType]: res.append(edge) return res - def __contains__(self, node: GenericNodeType | str) -> bool: + def __contains__(self, node: Union[GenericNodeType, str]) -> bool: if isinstance(node, str): return node in self._nodes return node.name in self._nodes diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 3e649cc13..c5d86d8ca 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -88,8 +88,8 @@ def search( message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, - return_context: bool | None = None, - response_fallback: str | None = None, + return_context: Union[bool, None] = None, + response_fallback: Union[str, None] = None, ) -> RagResultModel: """ .. warning:: diff --git a/src/neo4j_graphrag/generation/types.py b/src/neo4j_graphrag/generation/types.py index a03983c27..094753314 100644 --- a/src/neo4j_graphrag/generation/types.py +++ b/src/neo4j_graphrag/generation/types.py @@ -43,7 +43,7 @@ class RagSearchModel(BaseModel): examples: str = "" retriever_config: dict[str, Any] = {} return_context: bool = False - response_fallback: str | None = None + response_fallback: Union[str, None] = None class RagResultModel(BaseModel): diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index c4c779cc4..37eb106c8 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -13,27 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -from typing import Any -from unittest.mock import patch, Mock, ANY -import pytest +from typing import Any, Union +from unittest.mock import ANY, Mock, patch +import pytest from neo4j_graphrag.experimental.components.graph_pruning import ( GraphPruning, GraphPruningResult, PruningStats, ) from neo4j_graphrag.experimental.components.schema import ( + GraphSchema, NodeType, PropertyType, RelationshipType, - GraphSchema, ) from neo4j_graphrag.experimental.components.types import ( + LexicalGraphConfig, + Neo4jGraph, Neo4jNode, Neo4jRelationship, - Neo4jGraph, - LexicalGraphConfig, ) @@ -369,7 +369,7 @@ def test_graph_pruning_validate_relationship( additional_relationship_types: bool, patterns: tuple[tuple[str, str, str], ...], additional_patterns: bool, - expected_relationship: str | None, + expected_relationship: Union[str, None], request: pytest.FixtureRequest, ) -> None: relationship_obj = request.getfixturevalue(relationship) From 9a3182b95d56beb24978c366be35cce33775f24e Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 17 Jul 2025 16:18:40 +0100 Subject: [PATCH 2/3] Replaced Union types with Optional --- .../components/entity_relation_extractor.py | 2 +- .../experimental/pipeline/config/base.py | 4 ++-- .../experimental/pipeline/config/object_config.py | 14 +++++++------- .../experimental/pipeline/config/runner.py | 3 ++- .../experimental/pipeline/orchestrator.py | 4 ++-- .../experimental/pipeline/pipeline.py | 4 ++-- src/neo4j_graphrag/generation/graphrag.py | 4 ++-- src/neo4j_graphrag/generation/types.py | 6 +++--- .../experimental/components/test_graph_pruning.py | 4 ++-- tests/unit/retrievers/test_base.py | 4 ++-- 10 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 60bc2b647..9b83a790d 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -291,7 +291,7 @@ async def run( chunks: TextChunks, document_info: Optional[DocumentInfo] = None, lexical_graph_config: Optional[LexicalGraphConfig] = None, - schema: Union[GraphSchema, None] = None, + schema: Optional[GraphSchema] = None, examples: str = "", **kwargs: Any, ) -> Neo4jGraph: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/base.py b/src/neo4j_graphrag/experimental/pipeline/config/base.py index 4a287a75d..586372a9d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/base.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/base.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Union +from typing import Any, Optional from pydantic import BaseModel, PrivateAttr @@ -58,5 +58,5 @@ def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]: for param_name, param in params.items() } - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Any: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Any: raise NotImplementedError() diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py index 4fd9571c7..0533be3f8 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py @@ -64,7 +64,7 @@ class ObjectConfig(AbstractConfig, Generic[T]): and its constructor parameters. """ - class_: Union[str, None] = Field(default=None, validate_default=True) + class_: Optional[str] = Field(default=None, validate_default=True) """Path to class to be instantiated.""" params_: dict[str, ParamConfig] = {} """Initialization parameters.""" @@ -119,7 +119,7 @@ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> t raise ValueError(f"Could not find {class_name} in {module_name}") return cast(type, klass) - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> T: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> T: """Import `class_`, resolve `params_` and instantiate object.""" self._global_data = resolved_data or {} logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}") @@ -153,7 +153,7 @@ def validate_class(cls, class_: Any) -> str: # not used return "not used" - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> neo4j.Driver: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> neo4j.Driver: params = self.resolve_params(self.params_) # we know these params are there because of the required params validator uri = params.pop("uri") @@ -176,7 +176,7 @@ class Neo4jDriverType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> neo4j.Driver: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> neo4j.Driver: if isinstance(self.root, neo4j.Driver): return self.root # self.root is a Neo4jDriverConfig object @@ -203,7 +203,7 @@ class LLMType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> LLMInterface: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> LLMInterface: if isinstance(self.root, LLMInterface): return self.root return self.root.parse(resolved_data) @@ -229,7 +229,7 @@ class EmbedderType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Embedder: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Embedder: if isinstance(self.root, Embedder): return self.root return self.root.parse(resolved_data) @@ -257,7 +257,7 @@ class ComponentType(RootModel): # type: ignore[type-arg] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Component: + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Component: if isinstance(self.root, Component): return self.root return self.root.parse(resolved_data) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index da9d8c598..af1367265 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -24,6 +24,7 @@ from typing import ( Annotated, Any, + Optional, Union, ) @@ -71,7 +72,7 @@ class PipelineConfigWrapper(BaseModel): ] = Field(discriminator=Discriminator(_get_discriminator_value)) def parse( - self, resolved_data: Union[dict[str, Any], None] = None + self, resolved_data: Optional[dict[str, Any]] = None ) -> PipelineDefinition: logger.debug("PIPELINE_CONFIG: start parsing config...") return self.config.parse(resolved_data) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index 8917da187..b5536b537 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -19,7 +19,7 @@ import uuid import warnings from functools import partial -from typing import TYPE_CHECKING, Any, AsyncGenerator, Union +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional from neo4j_graphrag.experimental.pipeline.exceptions import ( PipelineDefinitionError, @@ -235,7 +235,7 @@ async def get_component_inputs( return component_inputs async def add_result_for_component( - self, name: str, result: Union[dict[str, Any], None], is_final: bool = False + self, name: str, result: Optional[dict[str, Any]], is_final: bool = False ) -> None: """This is where we save the results in the result store and, optionally, in the final result store. diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 957355869..150e70789 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -19,7 +19,7 @@ import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, AsyncGenerator, Optional, Union +from typing import Any, AsyncGenerator, Optional from neo4j_graphrag.utils.logging import prettify @@ -78,7 +78,7 @@ def __init__(self, name: str, component: Component): async def execute( self, context: RunContext, inputs: dict[str, Any] - ) -> Union[RunResult, None]: + ) -> Optional[RunResult]: """Execute the task Returns: diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index c5d86d8ca..08f08a368 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -88,8 +88,8 @@ def search( message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, - return_context: Union[bool, None] = None, - response_fallback: Union[str, None] = None, + return_context: Optional[bool] = None, + response_fallback: Optional[str] = None, ) -> RagResultModel: """ .. warning:: diff --git a/src/neo4j_graphrag/generation/types.py b/src/neo4j_graphrag/generation/types.py index 094753314..8569cb9e8 100644 --- a/src/neo4j_graphrag/generation/types.py +++ b/src/neo4j_graphrag/generation/types.py @@ -14,7 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Union +from typing import Any, Optional from pydantic import BaseModel, ConfigDict, field_validator @@ -43,11 +43,11 @@ class RagSearchModel(BaseModel): examples: str = "" retriever_config: dict[str, Any] = {} return_context: bool = False - response_fallback: Union[str, None] = None + response_fallback: Optional[str] = None class RagResultModel(BaseModel): answer: str - retriever_result: Union[RetrieverResult, None] = None + retriever_result: Optional[RetrieverResult] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index 37eb106c8..4aee8949f 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -14,7 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Union +from typing import Any, Optional from unittest.mock import ANY, Mock, patch import pytest @@ -369,7 +369,7 @@ def test_graph_pruning_validate_relationship( additional_relationship_types: bool, patterns: tuple[tuple[str, str, str], ...], additional_patterns: bool, - expected_relationship: Union[str, None], + expected_relationship: Optional[str], request: pytest.FixtureRequest, ) -> None: relationship_obj = request.getfixturevalue(relationship) diff --git a/tests/unit/retrievers/test_base.py b/tests/unit/retrievers/test_base.py index 252283c5e..715252620 100644 --- a/tests/unit/retrievers/test_base.py +++ b/tests/unit/retrievers/test_base.py @@ -15,7 +15,7 @@ from __future__ import annotations # Reminder: May be removed after Python 3.9 is EOL. import inspect -from typing import Any, Union +from typing import Any, Optional from unittest.mock import MagicMock, patch import pytest @@ -41,7 +41,7 @@ def test_retriever_version_support( mock_get_version: MagicMock, driver: MagicMock, db_version: tuple[tuple[int, ...], bool], - expected_exception: Union[type[ValueError], None], + expected_exception: Optional[type[ValueError]], ) -> None: mock_get_version.return_value = db_version From bdaf062c3b9f7270b5f3d208af5f5941444522f9 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 17 Jul 2025 16:23:24 +0100 Subject: [PATCH 3/3] Missed one --- src/neo4j_graphrag/experimental/pipeline/config/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index af1367265..376de2973 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -93,7 +93,7 @@ class PipelineRunner: def __init__( self, pipeline_definition: PipelineDefinition, - config: Union[AbstractPipelineConfig, None] = None, + config: Optional[AbstractPipelineConfig] = None, do_cleaning: bool = False, ) -> None: self.config = config