Skip to content

Function to visualize a GraphSchema object or dict #398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
- Fixed documentation for PdfLoader
- Fixed a bug where the `format` argument for `OllamaLLM` was not propagated to the client.

### Added

- Added `schema_visualization` function to visualize a graph schema using neo4j-viz.


## 1.9.0

Expand Down
6 changes: 6 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ SchemaFromTextExtractor
.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaFromTextExtractor
:members: run

schema_visualization
--------------------

.. autofunction:: neo4j_graphrag.experimental.utils.schema.schema_visualization


EntityRelationExtractor
=======================

Expand Down
22 changes: 22 additions & 0 deletions docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,28 @@ You can also save and reload the extracted schema:
restored_schema = GraphSchema.from_file("my_schema.json") # or my_schema.yaml


Schema Visualization
--------------------

It is possible to visualize a validated schema or a schema dict using the `schema_visualization` function. This function
returns a VisualizationGraph object (from the neo4j-viz package) that can be visualized like this:

.. code:: python

from neo4j_graphrag.experimental.utils.schema import schema_visualization

VG = schema_visualization(schema)
html = VG.render()

# in Jupyter:
display(html)

# to save the generated HTML
with open("my_schema.html", "w") as f:
f.write(html.data)



Entity and Relation Extractor
=============================

Expand Down
639 changes: 561 additions & 78 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pyyaml = "^6.0.2"
types-pyyaml = "^6.0.12.20240917"
# optional deps
langchain-text-splitters = {version = "^0.3.0", optional = true }
neo4j-viz = {version = "^0.2.2", optional = true }
neo4j-viz = {version = "^0.4.2", optional = true }
weaviate-client = {version = "^4.6.1", optional = true }
pinecone-client = {version = "^4.1.0", optional = true }
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
Expand Down Expand Up @@ -74,7 +74,6 @@ sphinx = { version = "^7.2.6", python = "^3.9" }
langchain-openai = {version = "^0.2.2", optional = true }
langchain-huggingface = {version = "^0.1.0", optional = true }
enum-tools = {extras = ["sphinx"], version = "^0.12.0"}
neo4j-viz = "^0.2.2"

[tool.poetry.extras]
weaviate = ["weaviate-client"]
Expand Down
14 changes: 14 additions & 0 deletions src/neo4j_graphrag/experimental/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
115 changes: 115 additions & 0 deletions src/neo4j_graphrag/experimental/utils/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union

try:
from neo4j_viz import VisualizationGraph, Node, Relationship
except ImportError:
VisualizationGraph = Node = Relationship = None # type: ignore

from neo4j_graphrag.experimental.components.schema import (
GraphSchema,
NodeType,
PropertyType,
)


def schema_visualization(
    schema: Union[dict[str, Any], GraphSchema],
) -> VisualizationGraph:
    """Build a neo4j-viz graph that depicts a graph schema.

    The input is first passed through ``GraphSchema.model_validate``. Each node
    type of the validated schema becomes a node and each pattern becomes a
    relationship. Property names are displayed with a trailing ``*`` when the
    property is required.

    Usage:

    .. code:: python

        VG = schema_visualization(schema)
        html = VG.render()

        # in Jupyter:
        display(html)

        # to save the generated HTML
        with open("my_schema.html", "w") as f:
            f.write(html.data)

    Args:
        schema (Union[dict[str, Any], GraphSchema]): the schema to draw, given
            either as a GraphSchema object or as a dict describing one.

    Returns:
        VisualizationGraph: the graph to render, with nodes colored by caption.

    Raises:
        ImportError: if the neo4j-viz package is not installed.
        pydantic.ValidationError: if the schema fails validation.
    """
    # The module-level import falls back to None when neo4j-viz is missing.
    if VisualizationGraph is None:
        raise ImportError(
            "Please install neo4j-viz to use the graph schema visualization feature: pip install neo4j-viz"
        )

    validated = GraphSchema.model_validate(schema)

    def _display_name(prop: PropertyType) -> str:
        """Return the property name, suffixed with '*' if the property is required.

        Args:
            prop (PropertyType): the property to be formatted

        Returns:
            str: the name to display for this property
        """
        marker = "*" if prop.required else ""
        return prop.name + marker

    def _node_properties(node_type: NodeType) -> dict[str, str]:
        """Return the {display name: type} mapping of a node type's properties.

        Args:
            node_type (NodeType): the node type object

        Returns:
            dict[str, str]: the node properties mapping for display
        """
        return {_display_name(prop): prop.type for prop in node_type.properties}

    def _relationship_properties(rel_type: str) -> dict[str, str]:
        """Return the {display name: type} mapping of a relationship type's properties.

        Args:
            rel_type (str): the relationship type label

        Returns:
            dict[str, str]: the relationship properties mapping for display,
                empty when no relationship type carries this label
        """
        # Only the first relationship type with a matching label is used.
        declared = next(
            (
                candidate
                for candidate in validated.relationship_types
                if candidate.label == rel_type
            ),
            None,
        )
        if declared is None:
            return {}
        return {_display_name(prop): prop.type for prop in declared.properties}

    nodes = []
    for node_type in validated.node_types:
        # The label serves as both the node id and the displayed caption.
        node = Node(  # type: ignore
            id=node_type.label,
            caption=node_type.label,
            properties=_node_properties(node_type),
        )
        nodes.append(node)

    relationships = []
    for pattern in validated.patterns:
        # A pattern reads (source label, relationship label, target label).
        start_label, rel_label, end_label = pattern[0], pattern[1], pattern[2]
        relationship = Relationship(  # type: ignore
            source=start_label,
            target=end_label,
            caption=rel_label,
            properties=_relationship_properties(rel_label),
        )
        relationships.append(relationship)

    graph = VisualizationGraph(nodes=nodes, relationships=relationships)
    graph.color_nodes(field="caption")
    return graph
14 changes: 14 additions & 0 deletions tests/unit/experimental/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
107 changes: 107 additions & 0 deletions tests/unit/experimental/utils/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from unittest.mock import patch

import pytest
from pydantic import ValidationError

from neo4j_viz import VisualizationGraph
from neo4j_graphrag.experimental.components.schema import GraphSchema
from neo4j_graphrag.experimental.utils.schema import schema_visualization


@pytest.fixture(scope="module")
def valid_schema_dict() -> dict[str, Any]:
    """Schema dict with two node types, two relationship types and two patterns.

    Node and relationship types are given both as bare labels and as dicts
    with properties.
    """
    person_node = {
        "label": "Person",
        "properties": [
            {"name": "name", "type": "STRING", "required": True},
            {"name": "birthYear", "type": "INTEGER"},
        ],
    }
    knows_relationship = {
        "label": "KNOWS",
        "properties": [
            {"name": "since", "type": "LOCAL_DATETIME"},
        ],
    }
    return {
        "node_types": ["Location", person_node],
        "relationship_types": ["BORN_IN", knows_relationship],
        "patterns": [
            ("Person", "BORN_IN", "Location"),
            ("Person", "KNOWS", "Person"),
        ],
    }


@pytest.fixture(scope="module")
def invalid_schema_dict() -> dict[str, Any]:
    """Schema dict whose only pattern points at an undeclared node type."""
    person_node = {
        "label": "Person",
        "properties": [
            {"name": "name", "type": "STRING", "required": True},
            {"name": "birthYear", "type": "INTEGER"},
        ],
    }
    # invalid pattern, "Location" node type not defined
    dangling_pattern = ("Person", "BORN_IN", "Location")
    return {
        "node_types": [person_node],
        "relationship_types": ["BORN_IN"],
        "patterns": [dangling_pattern],
    }


# Patch the name that schema_visualization actually checks. The module under
# test does `from neo4j_viz import VisualizationGraph, ...`, so it has no
# `neo4j_viz` attribute to patch; targeting it makes `patch` raise
# AttributeError before the test body runs.
@patch("neo4j_graphrag.experimental.utils.schema.VisualizationGraph", None)
def test_schema_visualization_import_error() -> None:
    """schema_visualization raises ImportError when neo4j-viz is unavailable.

    Setting ``VisualizationGraph`` to ``None`` reproduces the state the module
    is left in when the optional ``neo4j_viz`` import fails.
    """
    with pytest.raises(ImportError):
        schema_visualization({})


def test_schema_visualization_invalid_schema_dict(
    invalid_schema_dict: dict[str, Any],
) -> None:
    """A schema dict that fails validation surfaces as a ValidationError."""
    expected_failure = pytest.raises(ValidationError)
    with expected_failure:
        schema_visualization(invalid_schema_dict)


def test_schema_visualization_valid_schema_dict(
    valid_schema_dict: dict[str, Any],
) -> None:
    """A valid schema dict yields a graph with one node per node type and one
    relationship per pattern."""
    graph = schema_visualization(valid_schema_dict)
    assert isinstance(graph, VisualizationGraph)
    node_count = len(graph.nodes)
    relationship_count = len(graph.relationships)
    assert node_count == 2
    assert relationship_count == 2


def test_schema_visualization_schema_object(valid_schema_dict: dict[str, Any]) -> None:
    """A GraphSchema object is accepted as input and gives a graph with two
    nodes and two relationships."""
    schema_object = GraphSchema.model_validate(valid_schema_dict)
    graph = schema_visualization(schema_object)
    assert isinstance(graph, VisualizationGraph)
    node_count = len(graph.nodes)
    relationship_count = len(graph.relationships)
    assert node_count == 2
    assert relationship_count == 2
Loading