Skip to content

Commit 9a3182b

Browse files
committed
Replaced Union types with Optional
1 parent 1b9c9bf commit 9a3182b

File tree

10 files changed

+25
-24
lines changed

10 files changed

+25
-24
lines changed

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ async def run(
291291
chunks: TextChunks,
292292
document_info: Optional[DocumentInfo] = None,
293293
lexical_graph_config: Optional[LexicalGraphConfig] = None,
294-
schema: Union[GraphSchema, None] = None,
294+
schema: Optional[GraphSchema] = None,
295295
examples: str = "",
296296
**kwargs: Any,
297297
) -> Neo4jGraph:

src/neo4j_graphrag/experimental/pipeline/config/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import annotations
1818

1919
import logging
20-
from typing import Any, Union
20+
from typing import Any, Optional
2121

2222
from pydantic import BaseModel, PrivateAttr
2323

@@ -58,5 +58,5 @@ def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]:
5858
for param_name, param in params.items()
5959
}
6060

61-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Any:
61+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Any:
6262
raise NotImplementedError()

src/neo4j_graphrag/experimental/pipeline/config/object_config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class ObjectConfig(AbstractConfig, Generic[T]):
6464
and its constructor parameters.
6565
"""
6666

67-
class_: Union[str, None] = Field(default=None, validate_default=True)
67+
class_: Optional[str] = Field(default=None, validate_default=True)
6868
"""Path to class to be instantiated."""
6969
params_: dict[str, ParamConfig] = {}
7070
"""Initialization parameters."""
@@ -119,7 +119,7 @@ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> t
119119
raise ValueError(f"Could not find {class_name} in {module_name}")
120120
return cast(type, klass)
121121

122-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> T:
122+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> T:
123123
"""Import `class_`, resolve `params_` and instantiate object."""
124124
self._global_data = resolved_data or {}
125125
logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}")
@@ -153,7 +153,7 @@ def validate_class(cls, class_: Any) -> str:
153153
# not used
154154
return "not used"
155155

156-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> neo4j.Driver:
156+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> neo4j.Driver:
157157
params = self.resolve_params(self.params_)
158158
# we know these params are there because of the required params validator
159159
uri = params.pop("uri")
@@ -176,7 +176,7 @@ class Neo4jDriverType(RootModel): # type: ignore[type-arg]
176176

177177
model_config = ConfigDict(arbitrary_types_allowed=True)
178178

179-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> neo4j.Driver:
179+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> neo4j.Driver:
180180
if isinstance(self.root, neo4j.Driver):
181181
return self.root
182182
# self.root is a Neo4jDriverConfig object
@@ -203,7 +203,7 @@ class LLMType(RootModel): # type: ignore[type-arg]
203203

204204
model_config = ConfigDict(arbitrary_types_allowed=True)
205205

206-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> LLMInterface:
206+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> LLMInterface:
207207
if isinstance(self.root, LLMInterface):
208208
return self.root
209209
return self.root.parse(resolved_data)
@@ -229,7 +229,7 @@ class EmbedderType(RootModel): # type: ignore[type-arg]
229229

230230
model_config = ConfigDict(arbitrary_types_allowed=True)
231231

232-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Embedder:
232+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Embedder:
233233
if isinstance(self.root, Embedder):
234234
return self.root
235235
return self.root.parse(resolved_data)
@@ -257,7 +257,7 @@ class ComponentType(RootModel): # type: ignore[type-arg]
257257

258258
model_config = ConfigDict(arbitrary_types_allowed=True)
259259

260-
def parse(self, resolved_data: Union[dict[str, Any], None] = None) -> Component:
260+
def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> Component:
261261
if isinstance(self.root, Component):
262262
return self.root
263263
return self.root.parse(resolved_data)

src/neo4j_graphrag/experimental/pipeline/config/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import (
2525
Annotated,
2626
Any,
27+
Optional,
2728
Union,
2829
)
2930

@@ -71,7 +72,7 @@ class PipelineConfigWrapper(BaseModel):
7172
] = Field(discriminator=Discriminator(_get_discriminator_value))
7273

7374
def parse(
74-
self, resolved_data: Union[dict[str, Any], None] = None
75+
self, resolved_data: Optional[dict[str, Any]] = None
7576
) -> PipelineDefinition:
7677
logger.debug("PIPELINE_CONFIG: start parsing config...")
7778
return self.config.parse(resolved_data)

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import uuid
2020
import warnings
2121
from functools import partial
22-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Union
22+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
2323

2424
from neo4j_graphrag.experimental.pipeline.exceptions import (
2525
PipelineDefinitionError,
@@ -235,7 +235,7 @@ async def get_component_inputs(
235235
return component_inputs
236236

237237
async def add_result_for_component(
238-
self, name: str, result: Union[dict[str, Any], None], is_final: bool = False
238+
self, name: str, result: Optional[dict[str, Any]], is_final: bool = False
239239
) -> None:
240240
"""This is where we save the results in the result store and, optionally,
241241
in the final result store.

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import warnings
2020
from collections import defaultdict
2121
from timeit import default_timer
22-
from typing import Any, AsyncGenerator, Optional, Union
22+
from typing import Any, AsyncGenerator, Optional
2323

2424
from neo4j_graphrag.utils.logging import prettify
2525

@@ -78,7 +78,7 @@ def __init__(self, name: str, component: Component):
7878

7979
async def execute(
8080
self, context: RunContext, inputs: dict[str, Any]
81-
) -> Union[RunResult, None]:
81+
) -> Optional[RunResult]:
8282
"""Execute the task
8383
8484
Returns:

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def search(
8888
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8989
examples: str = "",
9090
retriever_config: Optional[dict[str, Any]] = None,
91-
return_context: Union[bool, None] = None,
92-
response_fallback: Union[str, None] = None,
91+
return_context: Optional[bool] = None,
92+
response_fallback: Optional[str] = None,
9393
) -> RagResultModel:
9494
"""
9595
.. warning::

src/neo4j_graphrag/generation/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Any, Union
17+
from typing import Any, Optional
1818

1919
from pydantic import BaseModel, ConfigDict, field_validator
2020

@@ -43,11 +43,11 @@ class RagSearchModel(BaseModel):
4343
examples: str = ""
4444
retriever_config: dict[str, Any] = {}
4545
return_context: bool = False
46-
response_fallback: Union[str, None] = None
46+
response_fallback: Optional[str] = None
4747

4848

4949
class RagResultModel(BaseModel):
5050
answer: str
51-
retriever_result: Union[RetrieverResult, None] = None
51+
retriever_result: Optional[RetrieverResult] = None
5252

5353
model_config = ConfigDict(arbitrary_types_allowed=True)

tests/unit/experimental/components/test_graph_pruning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Any, Union
17+
from typing import Any, Optional
1818
from unittest.mock import ANY, Mock, patch
1919

2020
import pytest
@@ -369,7 +369,7 @@ def test_graph_pruning_validate_relationship(
369369
additional_relationship_types: bool,
370370
patterns: tuple[tuple[str, str, str], ...],
371371
additional_patterns: bool,
372-
expected_relationship: Union[str, None],
372+
expected_relationship: Optional[str],
373373
request: pytest.FixtureRequest,
374374
) -> None:
375375
relationship_obj = request.getfixturevalue(relationship)

tests/unit/retrievers/test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations # Reminder: May be removed after Python 3.9 is EOL.
1616

1717
import inspect
18-
from typing import Any, Union
18+
from typing import Any, Optional
1919
from unittest.mock import MagicMock, patch
2020

2121
import pytest
@@ -41,7 +41,7 @@ def test_retriever_version_support(
4141
mock_get_version: MagicMock,
4242
driver: MagicMock,
4343
db_version: tuple[tuple[int, ...], bool],
44-
expected_exception: Union[type[ValueError], None],
44+
expected_exception: Optional[type[ValueError]],
4545
) -> None:
4646
mock_get_version.return_value = db_version
4747

0 commit comments

Comments
 (0)