Skip to content

Commit 961e19a

Browse files
committed
feat: handle SQLAlchemy types with missing python_type as typing.Any
fix: SQLModel.metadata as metadataref
1 parent f41d8b4 commit 961e19a

File tree

5 files changed

+104
-4
lines changed

5 files changed

+104
-4
lines changed

CHANGES.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
Version history
22
===============
33

4+
**UNRELEASED**
5+
- Handle SQLAlchemy type with unimplemented python_type as typing.Any (PR by @danplischke)
6+
- Fix SQLModel metadata reference (PR by @danplischke)
7+
48
**3.1.0**
59

610
- Type annotations for ARRAY column attributes now include the Python type of

src/sqlacodegen/generators.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,11 @@ def render_python_type(column_type: TypeEngine[Any]) -> str:
12431243
if isinstance(column_type, DOMAIN):
12441244
python_type = column_type.data_type.python_type
12451245
else:
1246-
python_type = column_type.python_type
1246+
try:
1247+
python_type = column_type.python_type
1248+
except NotImplementedError:
1249+
self.add_literal_import("typing", "Any")
1250+
python_type = Any
12471251

12481252
python_type_name = python_type.__name__
12491253
python_type_module = python_type.__module__
@@ -1435,7 +1439,7 @@ def generate_base(self) -> None:
14351439
self.base = Base(
14361440
literal_imports=[],
14371441
declarations=[],
1438-
metadata_ref="",
1442+
metadata_ref="SQLModel.metadata",
14391443
)
14401444

14411445
def collect_imports(self, models: Iterable[Model]) -> None:

tests/test_generator_dataclass.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5-
from sqlalchemy.dialects.postgresql import UUID
5+
from sqlalchemy.dialects.postgresql import TSVECTOR, UUID
66
from sqlalchemy.engine import Engine
77
from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table
88
from sqlalchemy.sql.expression import text
@@ -267,3 +267,35 @@ class Simple(Base):
267267
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
268268
""",
269269
)
270+
271+
272+
def test_tsvector_missing_python_type(generator: CodeGenerator) -> None:
273+
Table(
274+
"simple_tsvector",
275+
generator.metadata,
276+
Column("id", UUID, primary_key=True),
277+
Column("vector", TSVECTOR),
278+
)
279+
280+
validate_code(
281+
generator.generate(),
282+
"""\
283+
from typing import Any, Optional
284+
import typing
285+
import uuid
286+
287+
from sqlalchemy import UUID
288+
from sqlalchemy.dialects.postgresql import TSVECTOR
289+
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
290+
291+
class Base(MappedAsDataclass, DeclarativeBase):
292+
pass
293+
294+
295+
class SimpleTsvector(Base):
296+
__tablename__ = 'simple_tsvector'
297+
298+
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
299+
vector: Mapped[Optional[typing.Any]] = mapped_column(TSVECTOR)
300+
""",
301+
)

tests/test_generator_declarative.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from _pytest.fixtures import FixtureRequest
55
from sqlalchemy import BIGINT, PrimaryKeyConstraint
66
from sqlalchemy.dialects import postgresql
7-
from sqlalchemy.dialects.postgresql import JSON, JSONB
7+
from sqlalchemy.dialects.postgresql import JSON, JSONB, TSVECTOR
88
from sqlalchemy.engine import Engine
99
from sqlalchemy.schema import (
1010
CheckConstraint,
@@ -1706,3 +1706,34 @@ class TestDomainJson(Base):
17061706
foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', {domain_type.__name__}(astext_type=Text(length=128)), not_null=False))
17071707
""",
17081708
)
1709+
1710+
1711+
def test_tsvector_missing_python_type(generator: CodeGenerator) -> None:
1712+
Table(
1713+
"test_tsvector",
1714+
generator.metadata,
1715+
Column("id", BIGINT, primary_key=True),
1716+
Column("vector", TSVECTOR()),
1717+
)
1718+
1719+
validate_code(
1720+
generator.generate(),
1721+
"""\
1722+
from typing import Any, Optional
1723+
import typing
1724+
1725+
from sqlalchemy import BigInteger
1726+
from sqlalchemy.dialects.postgresql import TSVECTOR
1727+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1728+
1729+
class Base(DeclarativeBase):
1730+
pass
1731+
1732+
1733+
class TestTsvector(Base):
1734+
__tablename__ = 'test_tsvector'
1735+
1736+
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
1737+
vector: Mapped[Optional[typing.Any]] = mapped_column(TSVECTOR)
1738+
""",
1739+
)

tests/test_generator_sqlmodel.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from _pytest.fixtures import FixtureRequest
55
from sqlalchemy import Uuid
6+
from sqlalchemy.dialects.postgresql import TSVECTOR
67
from sqlalchemy.engine import Engine
78
from sqlalchemy.schema import (
89
CheckConstraint,
@@ -204,3 +205,31 @@ class SimpleUuid(SQLModel, table=True):
204205
id: uuid.UUID = Field(sa_column=Column('id', Uuid, primary_key=True))
205206
""",
206207
)
208+
209+
210+
def test_tsvector_missing_python_type(generator: CodeGenerator) -> None:
211+
Table(
212+
"simple_tsvector",
213+
generator.metadata,
214+
Column("id", Uuid, primary_key=True),
215+
Column("search", TSVECTOR),
216+
)
217+
218+
validate_code(
219+
generator.generate(),
220+
"""\
221+
from typing import Any, Optional
222+
import typing
223+
import uuid
224+
225+
from sqlalchemy import Column, Uuid
226+
from sqlalchemy.dialects.postgresql import TSVECTOR
227+
from sqlmodel import Field, SQLModel
228+
229+
class SimpleTsvector(SQLModel, table=True):
230+
__tablename__ = 'simple_tsvector'
231+
232+
id: uuid.UUID = Field(sa_column=Column('id', Uuid, primary_key=True))
233+
search: Optional[typing.Any] = Field(default=None, sa_column=Column('search', TSVECTOR))
234+
""",
235+
)

0 commit comments

Comments
 (0)