diff --git a/docs/how-to-guides/ogm.md b/docs/how-to-guides/ogm.md index a75da273..8a7081c4 100644 --- a/docs/how-to-guides/ogm.md +++ b/docs/how-to-guides/ogm.md @@ -395,6 +395,37 @@ To check which constraints have been created, run: print(db.get_constraints()) ``` +## Using enums + +Memgraph's built-in [enum data type](https://memgraph.com/docs/fundamentals/data-types#enum) can be utilized on your GQLAlchemy OGM models. GQLAlchemy's enum implementation extends Python's [enum support](https://docs.python.org/3.11/library/enum.html). + +First, create an enum. + +```python +from enum import Enum + +class SubscriptionType(Enum): + FREE = 1 + BASIC = 2 + EXTENDED = 3 +``` + +Then, use the defined enum class in your model definition. Using the `Field` class, set the `enum` attribute to `True`. This will indicate that GQLAlchemy should treat the property value stored as a Memgraph enum. If the enum does not exist in the database, it will be created. + +```python +class User(Node): + id: str = Field(index=True, db=db) + username: str + subscription: SubscriptionType = Field(enum=True, db=db) +``` + +Enum types may be defined for properties on Nodes and Relationships. + +!!! info + If the `Field` class specification on the property isn't specified, or if `enum` is explicitly set to `False`, GQLAlchemy will use the `value` of the enum member when serializing to a Cypher query. A corresponding enum will not be created in the database. + + This functionality allows for flexiblity when using the Python `Enum` class, and would, for instance, respect an overridden `__getattribute__` method to customize the value passed to Cypher. + ## Full code example The above mentioned examples can be merged into a working code example which you can run. Here is the code: @@ -402,12 +433,19 @@ The above mentioned examples can be merged into a working code example which you ```python from gqlalchemy import Memgraph, Node, Relationship, Field from typing import Optional +from enum import Enum db = Memgraph() +class SubscriptionType(Enum): + FREE = 1 + BASIC = 2 + EXTENDED = 3 + class User(Node): id: str = Field(index=True, db=db) username: str = Field(exists=True, db=db) + subscription: SubscriptionType = Field(enum=True, db=db) class Streamer(User): id: str @@ -423,8 +461,8 @@ class ChatsWith(Relationship, type="CHATS_WITH"): class Speaks(Relationship, type="SPEAKS"): since: Optional[str] -john = User(id="1", username="John").save(db) -jane = Streamer(id="2", username="janedoe", followers=111).save(db) +john = User(id="1", username="John", subscription=SubscriptionType(1)).save(db) +jane = Streamer(id="2", username="janedoe", subscription=SubscriptionType(3), followers=111).save(db) language = Language(name="en").save(db) ChatsWith( @@ -449,7 +487,7 @@ try: streamer = Streamer(id="3").load(db=db) except: print("Creating new Streamer node in the database.") - streamer = Streamer(id="3", username="anne", followers=222).save(db=db) + streamer = Streamer(id="3", username="anne", subscription=SubscriptionType(2), followers=222).save(db=db) try: speaks = Speaks(_start_node_id=streamer._id, _end_node_id=language._id).load(db) diff --git a/gqlalchemy/__init__.py b/gqlalchemy/__init__.py index bccf728d..be797d06 100644 --- a/gqlalchemy/__init__.py +++ b/gqlalchemy/__init__.py @@ -20,6 +20,7 @@ MemgraphConstraintExists, MemgraphConstraintUnique, MemgraphIndex, + MemgraphEnum, MemgraphKafkaStream, MemgraphPulsarStream, MemgraphTrigger, diff --git a/gqlalchemy/models.py b/gqlalchemy/models.py index eaf8a146..aa0c80d2 100644 --- a/gqlalchemy/models.py +++ b/gqlalchemy/models.py @@ -16,9 +16,9 @@ from collections import defaultdict from dataclasses import dataclass from datetime import datetime, date, time, timedelta -from enum import Enum import json from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from enum import Enum, EnumMeta from pydantic.v1 import BaseModel, Extra, Field, PrivateAttr # noqa F401 @@ -59,6 +59,36 @@ def _format_timedelta(duration: timedelta) -> str: return f"P{days}DT{hours}H{minutes}M{remainder_sec}S" +class GraphEnum(ABC): + def __init__(self, enum): + + if not isinstance(enum, (Enum, EnumMeta)): + raise TypeError() + + self.enum = enum if isinstance(enum, Enum) else None + self.cls = enum.__class__ if isinstance(enum, Enum) else enum + + @property + def name(self): + return self.cls.__name__ + + @property + def members(self): + return self.cls.__members__ + + @abstractmethod + def _to_cypher(self): + pass + + +class MemgraphEnum(GraphEnum): + def _to_cypher(self): + return f"{{ {', '.join(self.cls._member_names_)} }}" + + def __repr__(self): + return f"" if self.enum is None else f"{self.name}::{self.enum.name}" + + class TriggerEventType: """An enum representing types of trigger events.""" @@ -308,6 +338,17 @@ class GraphObject(BaseModel): class Config: extra = Extra.allow + def __init__(self, **data): + for field in self.__class__.__fields__: + attrs = self.__class__.__fields__[field].field_info.extra + cls = self.__fields__[field].type_ + if issubclass(cls, Enum) and not attrs.get("enum", False): + value = data.get(field) + if isinstance(value, dict): + member = value.get("__value").split("::")[1] + data[field] = cls[member].value + super().__init__(**data) + def __init_subclass__(cls, type=None, label=None, labels=None, index=None, db=None): """Stores the subclass by type if type is specified, or by class name when instantiating a subclass. @@ -372,6 +413,8 @@ def escape_value( return repr(value) elif value_type == float: return repr(value) + elif isinstance(value, Enum): + return repr(MemgraphEnum(value)) elif isinstance(value, str): return json.dumps(value) elif isinstance(value, list): @@ -446,7 +489,11 @@ def _get_cypher_set_properties(self, variable_name: str) -> str: cypher_set_properties = [] for field in self.__fields__: attributes = self.__fields__[field].field_info.extra - value = getattr(self, field) + cls = self.__fields__[field].type_ + if issubclass(cls, Enum) and not attributes.get("enum", False): + value = getattr(self, field).value + else: + value = getattr(self, field) if value is not None and not attributes.get("on_disk", False): cypher_set_properties.append(f" SET {variable_name}.{field} = {self.escape_value(value)}") @@ -512,6 +559,9 @@ def get_base_labels() -> Set[str]: cls.labels = get_base_labels().union({cls.label}, kwargs.get("labels", set())) db = kwargs.get("db") + + cls.enums = None + if cls.index is True: if db is None: raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls) @@ -522,12 +572,25 @@ def get_base_labels() -> Set[str]: for field in cls.__fields__: attrs = cls.__fields__[field].field_info.extra field_type = cls.__fields__[field].type_.__name__ + field_cls = cls.__fields__[field].type_ label = attrs.get("label", cls.label) skip_constraints = False if db is None: db = attrs.get("db") + if issubclass(field_cls, Enum) and attrs.get("enum", False): + if db is None: + raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls) + if cls.enums is None: + cls.enums = db.get_enums() + enum_names = [x.name for x in cls.enums] + if field_cls.__name__ in enum_names: + existing = cls.enums[enum_names.index(field_cls.__name__)] + db.sync_enum(existing, MemgraphEnum(field_cls)) + else: + db.create_enum(MemgraphEnum(field_cls)) + for constraint in FieldAttrsConstants.list(): if constraint in attrs and db is None: base = field_in_superclass(field, constraint) @@ -663,6 +726,30 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 if name != "Relationship": cls.type = kwargs.get("type", name) + db = kwargs.get("db") + + cls.enums = None + + for field in cls.__fields__: + attrs = cls.__fields__[field].field_info.extra + field_type = cls.__fields__[field].type_.__name__ + field_cls = cls.__fields__[field].type_ + + if db is None: + db = attrs.get("db") + + if issubclass(field_cls, Enum) and attrs.get("enum", False): + if db is None: + raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls) + if cls.enums is None: + cls.enums = db.get_enums() + enum_names = [x.name for x in cls.enums] + if field_type in enum_names: + existing = cls.enums[enum_names.index(field_type)] + db.sync_enum(existing, MemgraphEnum(field_cls)) + else: + db.create_enum(MemgraphEnum(field_cls)) + return cls diff --git a/gqlalchemy/vendors/database_client.py b/gqlalchemy/vendors/database_client.py index df907f67..182c9cda 100644 --- a/gqlalchemy/vendors/database_client.py +++ b/gqlalchemy/vendors/database_client.py @@ -17,12 +17,7 @@ from gqlalchemy.connection import Connection from gqlalchemy.exceptions import GQLAlchemyError -from gqlalchemy.models import ( - Constraint, - Index, - Node, - Relationship, -) +from gqlalchemy.models import Constraint, Index, GraphEnum, Node, Relationship class DatabaseClient(ABC): @@ -128,6 +123,30 @@ def ensure_constraints( for missing_constraint in new_constraints.difference(old_constraints): self.create_constraint(missing_constraint) + @abstractmethod + def create_enum(self, enum: GraphEnum) -> None: + pass + + @abstractmethod + def get_enums(self) -> List[GraphEnum]: + """Returns a list of all enums defined in the database.""" + pass + + @abstractmethod + def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: + """Ensures that database enum matches input enum.""" + pass + + @abstractmethod + def drop_enum(self, enum: GraphEnum) -> None: + """Drops a single enum in the database.""" + pass + + @abstractmethod + def drop_enums(self) -> None: + """Drops all enums in the database""" + pass + def drop_database(self): """Drops database by removing all nodes and edges.""" self.execute("MATCH (n) DETACH DELETE n;") diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index fb5fb7a9..ba2a3f82 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -31,6 +31,7 @@ MemgraphIndex, MemgraphStream, MemgraphTrigger, + MemgraphEnum, Node, Relationship, ) @@ -167,6 +168,32 @@ def get_constraints( ) return constraints + def create_enum(self, graph_enum: MemgraphEnum) -> None: + query = f"CREATE ENUM {graph_enum.name} VALUES {graph_enum._to_cypher()};" + self.execute(query) + + def get_enums(self) -> List[MemgraphEnum]: + """Returns a list of all enums defined in the database.""" + enums: List[MemgraphEnum] = [] + for result in self.execute_and_fetch("SHOW ENUMS;"): + enums.append(MemgraphEnum(Enum(result["Enum Name"], result["Enum Values"]))) + return enums + + def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None: + """Ensures that database enum matches input enum.""" + for value in new.members: + if value not in existing.members: + query = f"ALTER ENUM {existing.name} ADD VALUE {value};" + self.execute(query) + + def drop_enum(self, graph_enum: MemgraphEnum): + raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.name} is persisted in the database.") + + def drop_enums(self, graph_enums: List[MemgraphEnum]): + raise GQLAlchemyError( + f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database." + ) + def get_exists_constraints( self, ) -> List[MemgraphConstraintExists]: diff --git a/gqlalchemy/vendors/neo4j.py b/gqlalchemy/vendors/neo4j.py index a5b5aa58..5fc0b2c2 100644 --- a/gqlalchemy/vendors/neo4j.py +++ b/gqlalchemy/vendors/neo4j.py @@ -24,6 +24,7 @@ Neo4jConstraintExists, Neo4jConstraintUnique, Neo4jIndex, + GraphEnum, Node, Relationship, ) @@ -99,6 +100,23 @@ def ensure_indexes(self, indexes: List[Neo4jIndex]) -> None: for missing_index in new_indexes.difference(old_indexes): self.create_index(missing_index) + def create_enum(self, graph_enum: GraphEnum) -> None: + raise GQLAlchemyError(f"CREATE ENUM not yet implemented in Neo4j.") + + def get_enums(self) -> List[GraphEnum]: + """Returns a list of all enums defined in the database.""" + raise GQLAlchemyError(f"SHOW ENUMS not yet implemented in Neo4j.") + + def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: + """Ensures that database enum matches input enum.""" + raise GQLAlchemyError(f"ALTER ENUM not yet implemented in Neo4j.") + + def drop_enum(self, graph_enum: GraphEnum): + raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.") + + def drop_enums(self, graph_enums: List[GraphEnum]): + raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.") + def get_constraints( self, ) -> List[Union[Neo4jConstraintExists, Neo4jConstraintUnique]]: diff --git a/tests/ogm/test_custom_fields.py b/tests/ogm/test_custom_fields.py index d346300d..d4987e1a 100644 --- a/tests/ogm/test_custom_fields.py +++ b/tests/ogm/test_custom_fields.py @@ -13,10 +13,13 @@ from pydantic.v1 import Field +from enum import Enum + from gqlalchemy import ( MemgraphConstraintExists, MemgraphConstraintUnique, MemgraphIndex, + MemgraphEnum, Neo4jConstraintUnique, Neo4jIndex, Node, @@ -56,6 +59,19 @@ def test_create_index(memgraph): assert actual_constraints == [memgraph_index] +def test_create_graph_enum(memgraph): + enum1 = Enum("MgEnum", (("MEMBER1", "value1"), ("MEMBER2", "value2"), ("MEMBER3", "value3"))) + + class Node3(Node): + type: enum1 + + memgraph_enum = MemgraphEnum(enum1) + + actual_enums = memgraph.get_enums() + + assert actual_enums == [memgraph_enum] + + def test_create_constraint_unique_neo4j(neo4j): class Node2(Node): id: int = Field(db=neo4j)