Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
33 changes: 33 additions & 0 deletions databases/backends/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import typing

from sqlalchemy import ColumnDefault
from sqlalchemy.engine.default import DefaultDialect


class ConstructDefaultParamsMixin:
"""
A mixin to support column default values for insert queries for asyncpg,
aiomysql and aiosqlite
"""

prefetch: typing.List
dialect: DefaultDialect

def construct_params(
self,
params: typing.Optional[typing.Mapping] = None,
_group_number: typing.Any = None,
_check: bool = True,
) -> typing.Dict:
pd = super().construct_params(params, _group_number, _check) # type: ignore

for column in self.prefetch:
pd[column.key] = self._exec_default(column.default)

return pd

def _exec_default(self, default: ColumnDefault) -> typing.Any:
if default.is_callable:
return default.arg(self.dialect)
else:
return default.arg
17 changes: 14 additions & 3 deletions databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,35 @@
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.backends.common import ConstructDefaultParamsMixin
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")


class MySQLCompiler(ConstructDefaultParamsMixin, pymysql.dialect.statement_compiler):
pass


class MySQLBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = pymysql.dialect(paramstyle="pyformat")
self._dialect.supports_native_decimal = True
self._dialect = self._get_dialect()
self._pool = None

def _get_dialect(self) -> Dialect:
dialect = pymysql.dialect(paramstyle="pyformat")

dialect.statement_compiler = MySQLCompiler
dialect.supports_native_decimal = True

return dialect

def _get_connection_kwargs(self) -> dict:
url_options = self._database_url.options

Expand Down
10 changes: 10 additions & 0 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,20 @@
from sqlalchemy.sql.schema import Column
from sqlalchemy.types import TypeEngine

from databases.backends.common import ConstructDefaultParamsMixin
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

_result_processors = {} # type: dict


class APGCompiler_psycopg2(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems strange to call this APGCompiler_psycopg2, given that the db backend used for postgres in databases is asyncpg, not psycopg2.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right but that's because aiopg is just a wrapper around psycopg2, asyncpg is not.

ConstructDefaultParamsMixin, pypostgresql.dialect.statement_compiler
):
pass


class PostgresBackend(DatabaseBackend):
def __init__(
Expand All @@ -28,6 +37,7 @@ def __init__(
def _get_dialect(self) -> Dialect:
dialect = pypostgresql.dialect(paramstyle="pyformat")

dialect.statement_compiler = APGCompiler_psycopg2
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
Expand Down
19 changes: 15 additions & 4 deletions databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,36 @@
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.backends.common import ConstructDefaultParamsMixin
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")


class SQLiteCompiler(ConstructDefaultParamsMixin, pysqlite.dialect.statement_compiler):
pass


class SQLiteBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = pysqlite.dialect(paramstyle="qmark")
# aiosqlite does not support decimals
self._dialect.supports_native_decimal = False
self._dialect = self._get_dialect()
self._pool = SQLitePool(self._database_url, **self._options)

def _get_dialect(self) -> Dialect:
dialect = pysqlite.dialect(paramstyle="qmark")

# aiosqlite does not support decimals
dialect.supports_native_decimal = False
dialect.statement_compiler = SQLiteCompiler

return dialect

async def connect(self) -> None:
pass
# assert self._pool is None, "DatabaseBackend is already running"
Expand Down
5 changes: 5 additions & 0 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def _build_query(
elif values:
return query.values(**values)

# for case when `table.insert()` called without `.values()` it has to be
# called to produce `insert_prefetch` for compiled query
if query.__visit_name__ == "insert":
return query.values()

return query


Expand Down
93 changes: 93 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ def process_result_value(self, value, dialect):
sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)),
)

# Used to test column default values
default_values = sqlalchemy.Table(
"default_values",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("with_default", sqlalchemy.Integer, default=42),
sqlalchemy.Column(
"with_callable_default",
sqlalchemy.String(length=100),
default=lambda: "default_value",
),
sqlalchemy.Column("without_default", sqlalchemy.Integer),
)


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
Expand Down Expand Up @@ -651,6 +665,84 @@ async def test_json_field(database_url):
assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1}


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_with_scalar_default(database_url):
"""
Test insert with scalar column default value
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
values = {"without_default": 1}
await database.execute(query, values)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_default"] == 42
assert result["without_default"] == values["without_default"]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_default_values_with_no_values_called(database_url):
"""
Test insert default values without calling ``values()`` on insert and
without passing ``values`` to ``execute()``.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
await database.execute(query)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_default"] == 42
assert result["without_default"] is None


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_default_values_with_overriden_default(database_url):
"""
Test if we provide value for a column having default value, the first one
should be set, not default one.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
values = {"with_default": 5}
await database.execute(query, values)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_default"] == values["with_default"]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_callable_default(database_url):
"""
Test insert with column having callable default.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
await database.execute(query)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_callable_default"] == "default_value"


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_custom_field(database_url):
Expand Down Expand Up @@ -917,6 +1009,7 @@ async def run_database_queries():
async with database:

async def db_lookup():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left \n

await database.fetch_one("SELECT pg_sleep(1)")

await asyncio.gather(db_lookup(), db_lookup())
Expand Down