Skip to content

Removes | char from type hints to be compatible with Python 3.9 #387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
def __init__(
self,
llm: LLMInterface,
prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
create_lexical_graph: bool = True,
on_error: OnError = OnError.RAISE,
max_concurrency: int = 5,
Expand Down Expand Up @@ -291,7 +291,7 @@ async def run(
chunks: TextChunks,
document_info: Optional[DocumentInfo] = None,
lexical_graph_config: Optional[LexicalGraphConfig] = None,
schema: Union[GraphSchema, None] = None,
schema: Optional[GraphSchema] = None,
examples: str = "",
**kwargs: Any,
) -> Neo4jGraph:
Expand Down
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from __future__ import annotations

import inspect
from typing import Any, get_type_hints
from typing import Any, Union, get_type_hints

from pydantic import BaseModel

from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.utils.validation import issubclass_safe


Expand Down Expand Up @@ -80,8 +80,8 @@ class Component(metaclass=ComponentMeta):
# these variables are filled by the metaclass
# added here for the type checker
# DO NOT CHANGE
component_inputs: dict[str, dict[str, str | bool]]
component_outputs: dict[str, dict[str, str | bool | type]]
component_inputs: dict[str, dict[str, Union[str, bool]]]
component_outputs: dict[str, dict[str, Union[str, bool, type]]]

async def run(self, *args: Any, **kwargs: Any) -> DataModel:
"""Run the component and return its result.
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_graphrag/experimental/pipeline/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import logging
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr

Expand Down Expand Up @@ -58,5 +58,5 @@ def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]:
for param_name, param in params.items()
}

def parse(self, resolved_data: dict[str, Any] | None = None) -> Any:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Any:
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,7 @@

import importlib
import logging
from typing import (
Any,
ClassVar,
Generic,
Optional,
TypeVar,
Union,
cast,
)
from typing import Any, ClassVar, Generic, Optional, TypeVar, Union, cast

import neo4j
from pydantic import (
Expand All @@ -58,7 +50,6 @@
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.validation import issubclass_safe


logger = logging.getLogger(__name__)


Expand All @@ -73,7 +64,7 @@ class ObjectConfig(AbstractConfig, Generic[T]):
and its constructor parameters.
"""

class_: str | None = Field(default=None, validate_default=True)
class_: Optional[str] = Field(default=None, validate_default=True)
"""Path to class to be instantiated."""
params_: dict[str, ParamConfig] = {}
"""Initialization parameters."""
Expand Down Expand Up @@ -128,7 +119,7 @@ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> t
raise ValueError(f"Could not find {class_name} in {module_name}")
return cast(type, klass)

def parse(self, resolved_data: dict[str, Any] | None = None) -> T:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> T:
"""Import `class_`, resolve `params_` and instantiate object."""
self._global_data = resolved_data or {}
logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}")
Expand Down Expand Up @@ -162,7 +153,7 @@ def validate_class(cls, class_: Any) -> str:
# not used
return "not used"

def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> neo4j.Driver:
params = self.resolve_params(self.params_)
# we know these params are there because of the required params validator
uri = params.pop("uri")
Expand All @@ -185,7 +176,7 @@ class Neo4jDriverType(RootModel): # type: ignore[type-arg]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> neo4j.Driver:
if isinstance(self.root, neo4j.Driver):
return self.root
# self.root is a Neo4jDriverConfig object
Expand All @@ -212,7 +203,7 @@ class LLMType(RootModel): # type: ignore[type-arg]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> LLMInterface:
if isinstance(self.root, LLMInterface):
return self.root
return self.root.parse(resolved_data)
Expand All @@ -238,7 +229,7 @@ class EmbedderType(RootModel): # type: ignore[type-arg]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Embedder:
if isinstance(self.root, Embedder):
return self.root
return self.root.parse(resolved_data)
Expand Down Expand Up @@ -266,7 +257,7 @@ class ComponentType(RootModel): # type: ignore[type-arg]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> Component:
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Component:
if isinstance(self.root, Component):
return self.root
return self.root.parse(resolved_data)
Expand Down
13 changes: 9 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/config/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import (
Annotated,
Any,
Optional,
Union,
)

Expand All @@ -37,7 +38,6 @@
from typing_extensions import Self

from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.utils.file_handler import FileHandler
from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
AbstractPipelineConfig,
PipelineConfig,
Expand All @@ -48,6 +48,7 @@
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types.definitions import PipelineDefinition
from neo4j_graphrag.utils.file_handler import FileHandler
from neo4j_graphrag.utils.logging import prettify

logger = logging.getLogger(__name__)
Expand All @@ -70,7 +71,9 @@ class PipelineConfigWrapper(BaseModel):
Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)],
] = Field(discriminator=Discriminator(_get_discriminator_value))

def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
def parse(
self, resolved_data: Optional[dict[str, Any]] = None
) -> PipelineDefinition:
logger.debug("PIPELINE_CONFIG: start parsing config...")
return self.config.parse(resolved_data)

Expand All @@ -90,7 +93,7 @@ class PipelineRunner:
def __init__(
self,
pipeline_definition: PipelineDefinition,
config: AbstractPipelineConfig | None = None,
config: Optional[AbstractPipelineConfig] = None,
do_cleaning: bool = False,
) -> None:
self.config = config
Expand All @@ -100,7 +103,9 @@ def __init__(

@classmethod
def from_config(
cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False
cls,
config: Union[AbstractPipelineConfig, dict[str, Any]],
do_cleaning: bool = False,
) -> Self:
wrapper = PipelineConfigWrapper.model_validate({"config": config})
logger.debug(
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import uuid
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.exceptions import (
PipelineDefinitionError,
PipelineMissingDependencyError,
PipelineStatusUpdateError,
)
from neo4j_graphrag.experimental.pipeline.notification import EventNotifier
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.types.orchestration import (
RunResult,
RunStatus,
Expand Down Expand Up @@ -235,7 +235,7 @@ async def get_component_inputs(
return component_inputs

async def add_result_for_component(
self, name: str, result: dict[str, Any] | None, is_final: bool = False
self, name: str, result: Optional[dict[str, Any]], is_final: bool = False
) -> None:
"""This is where we save the results in the result store and, optionally,
in the final result store.
Expand Down
21 changes: 10 additions & 11 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import logging
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional, AsyncGenerator
import asyncio
from typing import Any, AsyncGenerator, Optional

from neo4j_graphrag.utils.logging import prettify

Expand All @@ -36,27 +36,26 @@
from neo4j_graphrag.experimental.pipeline.exceptions import (
PipelineDefinitionError,
)
from neo4j_graphrag.experimental.pipeline.notification import (
Event,
EventCallbackProtocol,
EventType,
PipelineEvent,
)
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
PipelineEdge,
PipelineGraph,
PipelineNode,
)
from neo4j_graphrag.experimental.pipeline.stores import InMemoryStore, ResultStore
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.types.definitions import (
ComponentDefinition,
ConnectionDefinition,
PipelineDefinition,
)
from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.notification import (
EventCallbackProtocol,
Event,
PipelineEvent,
EventType,
)


logger = logging.getLogger(__name__)

Expand All @@ -79,7 +78,7 @@ def __init__(self, name: str, component: Component):

async def execute(
self, context: RunContext, inputs: dict[str, Any]
) -> RunResult | None:
) -> Optional[RunResult]:
"""Execute the task

Returns:
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_graphrag/experimental/pipeline/pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from __future__ import annotations

from typing import Any, Generic, Optional, TypeVar
from typing import Any, Generic, Optional, TypeVar, Union


class PipelineNode:
Expand Down Expand Up @@ -124,7 +124,7 @@ def previous_edges(self, node: str) -> list[GenericEdgeType]:
res.append(edge)
return res

def __contains__(self, node: GenericNodeType | str) -> bool:
def __contains__(self, node: Union[GenericNodeType, str]) -> bool:
if isinstance(node, str):
return node in self._nodes
return node.name in self._nodes
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def search(
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
response_fallback: str | None = None,
return_context: Optional[bool] = None,
response_fallback: Optional[str] = None,
) -> RagResultModel:
"""
.. warning::
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/generation/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Union
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, field_validator

Expand Down Expand Up @@ -43,11 +43,11 @@ class RagSearchModel(BaseModel):
examples: str = ""
retriever_config: dict[str, Any] = {}
return_context: bool = False
response_fallback: str | None = None
response_fallback: Optional[str] = None


class RagResultModel(BaseModel):
answer: str
retriever_result: Union[RetrieverResult, None] = None
retriever_result: Optional[RetrieverResult] = None

model_config = ConfigDict(arbitrary_types_allowed=True)
14 changes: 7 additions & 7 deletions tests/unit/experimental/components/test_graph_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from unittest.mock import patch, Mock, ANY

import pytest
from typing import Any, Optional
from unittest.mock import ANY, Mock, patch

import pytest
from neo4j_graphrag.experimental.components.graph_pruning import (
GraphPruning,
GraphPruningResult,
PruningStats,
)
from neo4j_graphrag.experimental.components.schema import (
GraphSchema,
NodeType,
PropertyType,
RelationshipType,
GraphSchema,
)
from neo4j_graphrag.experimental.components.types import (
LexicalGraphConfig,
Neo4jGraph,
Neo4jNode,
Neo4jRelationship,
Neo4jGraph,
LexicalGraphConfig,
)


Expand Down Expand Up @@ -369,7 +369,7 @@ def test_graph_pruning_validate_relationship(
additional_relationship_types: bool,
patterns: tuple[tuple[str, str, str], ...],
additional_patterns: bool,
expected_relationship: str | None,
expected_relationship: Optional[str],
request: pytest.FixtureRequest,
) -> None:
relationship_obj = request.getfixturevalue(relationship)
Expand Down
Loading