Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,4 @@
from .sql.expression import type_coerce as type_coerce
from .sql.expression import within_group as within_group
from .sql.sqltypes import AutoString as AutoString
from .sql.sqltypes import IntEnum as IntEnum
33 changes: 32 additions & 1 deletion sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any, cast
import enum
from typing import Any, Optional, Type, TypeVar, cast

from sqlalchemy import types
from sqlalchemy.engine.interfaces import Dialect

_TIntEnum = TypeVar("_TIntEnum", bound=enum.IntEnum)


class AutoString(types.TypeDecorator): # type: ignore
impl = types.String
Expand All @@ -14,3 +17,31 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length))
return super().load_dialect_impl(dialect)


class IntEnum(types.TypeDecorator[Optional[_TIntEnum]]):
impl = types.SmallInteger
cache_ok = True

def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

# validate the input enum type
if not issubclass(enum_type, enum.IntEnum):
raise TypeError("Input must be enum.IntEnum")

self.enum_type = enum_type

def process_result_value(
self,
value: Optional[int],
dialect: Dialect,
) -> Optional[_TIntEnum]:
return None if (value is None) else self.enum_type(value)

def process_bind_param(
self,
value: Optional[_TIntEnum],
dialect: Dialect,
) -> Optional[int]:
return None if (value is None) else value.value
36 changes: 28 additions & 8 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_postgres_ddl_sql(clear_sqlmodel, capsys: pytest.CaptureFixture[str]):
captured = capsys.readouterr()
assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out
assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out
assert "int_enum_field SMALLINT NOT NULL" in captured.out


def test_sqlite_ddl_sql(clear_sqlmodel, capsys: pytest.CaptureFixture[str]):
Expand All @@ -52,6 +53,7 @@ def test_sqlite_ddl_sql(clear_sqlmodel, capsys: pytest.CaptureFixture[str]):

captured = capsys.readouterr()
assert "enum_field VARCHAR(1) NOT NULL" in captured.out, captured
assert "int_enum_field SMALLINT NOT NULL" in captured.out, captured
assert "CREATE TYPE" not in captured.out


Expand All @@ -63,15 +65,22 @@ def test_json_schema_flat_model_pydantic_v1():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/definitions/MyEnum1"},
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"definitions": {
"MyEnum1": {
"title": "MyEnum1",
"description": "An enumeration.",
"enum": ["A", "B"],
"type": "string",
}
},
"MyEnum3": {
"title": "MyEnum3",
"description": "An enumeration.",
"enum": [1, 2],
"type": "integer",
},
},
}

Expand All @@ -84,15 +93,22 @@ def test_json_schema_inherit_model_pydantic_v1():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/definitions/MyEnum2"},
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"definitions": {
"MyEnum2": {
"title": "MyEnum2",
"description": "An enumeration.",
"enum": ["C", "D"],
"type": "string",
}
},
"MyEnum3": {
"title": "MyEnum3",
"description": "An enumeration.",
"enum": [1, 2],
"type": "integer",
},
},
}

Expand All @@ -105,10 +121,12 @@ def test_json_schema_flat_model_pydantic_v2():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/$defs/MyEnum1"},
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"$defs": {
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"},
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
},
}

Expand All @@ -121,9 +139,11 @@ def test_json_schema_inherit_model_pydantic_v2():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/$defs/MyEnum2"},
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"$defs": {
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"},
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
},
}
9 changes: 8 additions & 1 deletion tests/test_enums_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
import uuid

from sqlmodel import Field, SQLModel
from sqlmodel import Field, IntEnum, SQLModel


class MyEnum1(str, enum.Enum):
Expand All @@ -14,14 +14,21 @@ class MyEnum2(str, enum.Enum):
D = "D"


class MyEnum3(enum.IntEnum):
E = 1
F = 2


class BaseModel(SQLModel):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum2
int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3))


class FlatModel(SQLModel, table=True):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum1
int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3))


class InheritModel(BaseModel, table=True):
Expand Down
Loading