From 9c62c421265b22da54c0e70eba64ea2f93eff435 Mon Sep 17 00:00:00 2001 From: Renzo Manganiello Date: Mon, 10 Jul 2023 17:31:31 +0300 Subject: [PATCH 1/6] Add operation_type support --- docs/source/schemas.rst | 2 +- psqlextra/compiler.py | 75 +++++++++++++++++++++++++---------- psqlextra/query.py | 83 +++++++++++++++++++++++++++++---------- psqlextra/types.py | 11 ++++++ tests/test_on_conflict.py | 32 +++++++++++++++ tests/test_upsert.py | 38 ++++++++++++++++++ 6 files changed, 200 insertions(+), 41 deletions(-) diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst index 01fdd345..8abc4a90 100644 --- a/docs/source/schemas.rst +++ b/docs/source/schemas.rst @@ -141,7 +141,7 @@ Deleting a schema Any schema can be dropped, including ones not created by :class:`~psqlextra.schema.PostgresSchema`. -The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` erorr will be raised if you attempt to drop the ``public`` schema. +The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` error will be raised if you attempt to drop the ``public`` schema. .. warning:: diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index 88a65e9a..fa7c3811 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -3,22 +3,22 @@ import sys from collections.abc import Iterable -from typing import Tuple, Union +from typing import List, Optional, Tuple, Union import django from django.conf import settings from django.core.exceptions import SuspiciousOperation -from django.db.models import Expression, Model, Q +from django.db.models import Expression, Field, Model, Q from django.db.models.fields.related import RelatedField from django.db.models.sql import compiler as django_compiler from django.db.utils import ProgrammingError from .expressions import HStoreValue -from .types import ConflictAction +from .types import ConflictAction, UpsertOperation -def append_caller_to_sql(sql): +def append_caller_to_sql(sql) -> str: """Append the caller to SQL queries. Adds the calling file and function as an SQL comment to each query. @@ -162,26 +162,43 @@ def as_sql(self, *args, **kwargs): class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" + RETURNING_OPERATION_TYPE_CLAUSE = ( + f"CASE WHEN xmax::text::int > 0 " + f"THEN '{UpsertOperation.UPDATE.value}' " + f"ELSE '{UpsertOperation.INSERT.value}' END" + ) + RETURNING_OPERATION_TYPE_FIELD = "_operation_type" + def __init__(self, *args, **kwargs): """Initializes a new instance of :see:PostgresInsertOnConflictCompiler.""" super().__init__(*args, **kwargs) self.qn = self.connection.ops.quote_name - def as_sql(self, return_id=False, *args, **kwargs): + def as_sql( + self, + return_id: bool = False, + return_operation_type: bool = False, + *args, + **kwargs, + ): """Builds the SQL INSERT statement.""" queries = [ - self._rewrite_insert(sql, params, return_id) + self._rewrite_insert(sql, params, return_id, return_operation_type) for sql, params in super().as_sql(*args, **kwargs) ] return queries - def execute_sql(self, return_id=False): + def execute_sql( + self, + return_id: bool = False, + return_operation_type: bool = False, + ) -> List[dict]: # execute all the generate queries with self.connection.cursor() as cursor: rows = [] - for sql, params in self.as_sql(return_id): + for sql, params in self.as_sql(return_id, return_operation_type): cursor.execute(sql, params) try: rows.extend(cursor.fetchall()) @@ -199,7 +216,13 @@ def execute_sql(self, return_id=False): for row in rows ] - def _rewrite_insert(self, sql, params, return_id=False): + def _rewrite_insert( + self, + sql: str, + params: list, + return_id: bool = False, + return_operation_type: bool = False, + ) -> Tuple[str, list]: """Rewrites a formed SQL INSERT query to include the ON CONFLICT clause. @@ -221,16 +244,27 @@ def _rewrite_insert(self, sql, params, return_id=False): returning = ( self.qn(self.query.model._meta.pk.attname) if return_id else "*" ) + # Return metadata about the row, so we can tell if it was inserted or + # updated by checking the `xmax` Postgres system column. + if return_operation_type: + returning += f", ({self.RETURNING_OPERATION_TYPE_CLAUSE}) AS {self.RETURNING_OPERATION_TYPE_FIELD}" (sql, params) = self._rewrite_insert_on_conflict( - sql, params, self.query.conflict_action.value, returning + sql, + params, + self.query.conflict_action.value, + returning, ) return append_caller_to_sql(sql), params def _rewrite_insert_on_conflict( - self, sql, params, conflict_action: ConflictAction, returning - ): + self, + sql: str, + params: list, + conflict_action: ConflictAction, + returning: str, + ) -> Tuple[str, list]: """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT' clause.""" @@ -256,7 +290,7 @@ def _rewrite_insert_on_conflict( rewritten_sql += f" DO {conflict_action}" - if conflict_action == "UPDATE": + if conflict_action == ConflictAction.UPDATE.value: rewritten_sql += f" SET {update_columns}" if update_condition: @@ -270,7 +304,7 @@ def _rewrite_insert_on_conflict( return (rewritten_sql, params) - def _build_on_conflict_clause(self): + def _build_on_conflict_clause(self) -> str: if django.VERSION >= (2, 2): from django.db.models.constraints import BaseConstraint from django.db.models.indexes import Index @@ -285,7 +319,7 @@ def _build_on_conflict_clause(self): conflict_target = self._build_conflict_target() return f"ON CONFLICT {conflict_target}" - def _build_conflict_target(self): + def _build_conflict_target(self) -> str: """Builds the `conflict_target` for the ON CONFLICT clause.""" if not isinstance(self.query.conflict_target, Iterable): @@ -304,7 +338,7 @@ def _build_conflict_target(self): return self._build_conflict_target_by_fields() - def _build_conflict_target_by_fields(self): + def _build_conflict_target_by_fields(self) -> str: """Builds the `conflict_target` for the ON CONFLICT clauses by matching the fields specified in the specified conflict target against the model's fields. @@ -329,7 +363,7 @@ def _build_conflict_target_by_fields(self): return "(%s)" % ",".join(conflict_target) - def _build_conflict_target_by_index(self): + def _build_conflict_target_by_index(self) -> Optional[str]: """Builds the `conflict_target` for the ON CONFLICT clause by trying to find an index that matches the specified conflict target on the query. @@ -353,7 +387,7 @@ def _build_conflict_target_by_index(self): stmt = matching_index.create_sql(self.query.model, schema_editor) return "(%s)" % stmt.parts["columns"] - def _get_model_field(self, name: str): + def _get_model_field(self, name: str) -> Optional[Field]: """Gets the field on a model with the specified name. Arguments: @@ -432,7 +466,8 @@ def _format_field_value(self, field_name) -> str: ) def _compile_expression( - self, expression: Union[Expression, Q, str] + self, + expression: Union[Expression, Q, str], ) -> Tuple[str, Union[tuple, list]]: """Compiles an expression, Q object or raw SQL string into SQL and tuple of parameters.""" @@ -452,7 +487,7 @@ def _compile_expression( return expression, tuple() - def _assert_valid_field(self, field_name: str): + def _assert_valid_field(self, field_name: str) -> None: """Asserts that a field with the specified name exists on the model and raises :see:SuspiciousOperation if it does not.""" diff --git a/psqlextra/query.py b/psqlextra/query.py index b3feec1d..ecc361df 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -2,6 +2,7 @@ from itertools import chain from typing import ( TYPE_CHECKING, + Any, Dict, Generic, Iterable, @@ -17,7 +18,11 @@ from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED -from .sql import PostgresInsertQuery, PostgresQuery +from .sql import ( + PostgresInsertOnConflictCompiler, + PostgresInsertQuery, + PostgresQuery, +) from .types import ConflictAction if TYPE_CHECKING: @@ -139,7 +144,8 @@ def bulk_insert( rows: Iterable[dict], return_model: bool = False, using: Optional[str] = None, - ): + return_operation_type: bool = False, + ) -> Union[List[Dict], List[TModel]]: """Creates multiple new records in the database. This allows specifying custom conflict behavior using .on_conflict(). @@ -158,6 +164,13 @@ def bulk_insert( Optional name of the database connection to use for this query. + return_operation_type (default: False): + If the operation type should be returned for each row. + This is only supported when return_model is False. + The operation_type is either 'INSERT' or 'UPDATE' and + the value will be contained in the '_operation_type' key + of the returned dict. + Returns: A list of either the dicts of the rows inserted, including the pk or the models of the rows inserted with defaults for any fields not specified @@ -195,16 +208,28 @@ def is_empty(r): deduped_rows.append(row) compiler = self._build_insert_compiler(deduped_rows, using=using) - objs = compiler.execute_sql(return_id=not return_model) + objs = compiler.execute_sql( + return_id=not return_model, + return_operation_type=return_operation_type and not return_model, + ) if return_model: - return [ - self._create_model_instance(dict(row, **obj), compiler.using) - for row, obj in zip(deduped_rows, objs) - ] + models = [] + for row, obj in zip(deduped_rows, objs): + models.append( + self._create_model_instance( + dict(row, **obj), + compiler.using, + ) + ) + return models return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] - def insert(self, using: Optional[str] = None, **fields): + def insert( + self, + using: Optional[str] = None, + **fields: Any, + ) -> Optional[int]: """Creates a new record in the database. This allows specifying custom conflict behavior using .on_conflict(). @@ -238,7 +263,7 @@ def insert(self, using: Optional[str] = None, **fields): # no special action required, use the standard Django create(..) return super().create(**fields).pk - def insert_and_get(self, using: Optional[str] = None, **fields): + def insert_and_get(self, using: Optional[str] = None, **fields: Any) -> Optional[TModel]: """Creates a new record in the database and then gets the entire row. This allows specifying custom conflict behavior using .on_conflict(). @@ -261,7 +286,7 @@ def insert_and_get(self, using: Optional[str] = None, **fields): return super().create(**fields) compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=False) + rows = compiler.execute_sql(return_id=False, return_operation_type=False) if not rows: return None @@ -293,7 +318,7 @@ def upsert( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - ) -> int: + ) -> Optional[int]: """Creates a new record or updates the existing one with the specified data. @@ -336,7 +361,7 @@ def upsert_and_get( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - ): + ) -> Optional[TModel]: """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -381,7 +406,8 @@ def bulk_upsert( return_model: bool = False, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, - ): + return_operation_type: bool = False, + ) -> Union[List[Dict], List[TModel]]: """Creates a set of new records or updates the existing ones with the specified data. @@ -407,6 +433,13 @@ def bulk_upsert( update_condition: Only update if this SQL expression evaluates to true. + return_operation_type (default: False): + If the operation type should be returned for each row. + This is only supported when return_model is False. + The operation_type is either 'INSERT' or 'UPDATE' and + the value will be contained in the '_operation_type' key + of the returned dict. + Returns: A list of either the dicts of the rows upserted, including the pk or the models of the rows upserted @@ -418,15 +451,23 @@ def bulk_upsert( index_predicate=index_predicate, update_condition=update_condition, ) - return self.bulk_insert(rows, return_model, using=using) + return self.bulk_insert( + rows, + return_model, + using=using, + return_operation_type=return_operation_type, + ) def _create_model_instance( - self, field_values: dict, using: str, apply_converters: bool = True - ): + self, + field_values: dict, + using: str, + apply_converters: bool = True + ) -> TModel: """Creates a new instance of the model with the specified field. - Use this after the row was inserted into the database. The new - instance will marked as "saved". + Use this after the row was inserted/updated into the database. The new + instance will be marked as "saved". """ converted_field_values = field_values.copy() @@ -459,8 +500,10 @@ def _create_model_instance( return instance def _build_insert_compiler( - self, rows: Iterable[Dict], using: Optional[str] = None - ): + self, + rows: Iterable[Dict], + using: Optional[str] = None, + ) -> PostgresInsertOnConflictCompiler: """Builds the SQL compiler for a insert query. Arguments: diff --git a/psqlextra/types.py b/psqlextra/types.py index a325fd9e..72fe20e3 100644 --- a/psqlextra/types.py +++ b/psqlextra/types.py @@ -29,6 +29,17 @@ def all(cls) -> List["ConflictAction"]: return [choice for choice in cls] +class UpsertOperation(StrEnum): + """Possible operations to take on an upsert.""" + + INSERT = "INSERT" + UPDATE = "UPDATE" + + @classmethod + def all(cls) -> List["UpsertOperation"]: + return [choice for choice in cls] + + class PostgresPartitioningMethod(StrEnum): """Methods of partitioning supported by PostgreSQL 11.x native support for table partitioning.""" diff --git a/tests/test_on_conflict.py b/tests/test_on_conflict.py index 02eda62f..4b876ea6 100644 --- a/tests/test_on_conflict.py +++ b/tests/test_on_conflict.py @@ -9,6 +9,7 @@ from psqlextra.fields import HStoreField from psqlextra.models import PostgresModel from psqlextra.query import ConflictAction +from psqlextra.types import UpsertOperation from .fake_model import get_fake_model @@ -397,6 +398,37 @@ def test_bulk_return(): assert obj["id"] == index +def test_bulk_return_with_operation_type(): + """Tests if the _operation_type is properly returned from 'bulk_insert'.""" + + model = get_fake_model( + { + "id": models.BigAutoField(primary_key=True), + "name": models.CharField(max_length=255, unique=True), + } + ) + + rows = [dict(name="John Smith"), dict(name="Jane Doe")] + + objs = model.objects.on_conflict( + ["name"], ConflictAction.UPDATE + ).bulk_insert(rows, return_operation_type=True) + + for index, obj in enumerate(objs, 1): + assert obj["id"] == index + assert obj["_operation_type"] == UpsertOperation.INSERT.value + + # Add objects again, update should return the same ids + # as we're just updating. + objs = model.objects.on_conflict( + ["name"], ConflictAction.UPDATE + ).bulk_insert(rows, return_operation_type=True) + + for index, obj in enumerate(objs, 1): + assert obj["id"] == index + assert obj["_operation_type"] == UpsertOperation.UPDATE.value + + @pytest.mark.parametrize("conflict_action", ConflictAction.all()) def test_bulk_return_models(conflict_action): """Tests whether models are returned instead of dictionaries when diff --git a/tests/test_upsert.py b/tests/test_upsert.py index b9176da1..177a7952 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -8,6 +8,7 @@ from psqlextra.expressions import ExcludedCol from psqlextra.fields import HStoreField from psqlextra.query import ConflictAction +from psqlextra.types import UpsertOperation from .fake_model import get_fake_model @@ -259,6 +260,43 @@ def test_upsert_bulk_no_rows(): ) +def test_upsert_bulk_returns_operation_type(): + """Tests whether bulk_upsert works properly with the return_operation_type flag.""" + + model = get_fake_model( + { + "first_name": models.CharField( + max_length=255, null=True, unique=True + ), + "last_name": models.CharField(max_length=255, null=True), + } + ) + + rows = model.objects.bulk_upsert( + conflict_target=["first_name"], + rows=[ + dict(first_name="Swen", last_name="Kooij"), + dict(first_name="Henk", last_name="Test"), + ], + return_operation_type=True, + ) + + for row in rows: + assert row["_operation_type"] == UpsertOperation.INSERT.value + + rows = model.objects.bulk_upsert( + conflict_target=["first_name"], + rows=[ + dict(first_name="Swen", last_name="Test"), + dict(first_name="Henk", last_name="Kooij"), + ], + return_operation_type=True, + ) + + for row in rows: + assert row["_operation_type"] == UpsertOperation.UPDATE.value + + def test_bulk_upsert_return_models(): """Tests whether models are returned instead of dictionaries when specifying the return_model=True argument.""" From c801bfbad17de74c1584e65c65399e22d7d19b9d Mon Sep 17 00:00:00 2001 From: Renzo Manganiello Date: Mon, 10 Jul 2023 17:35:08 +0300 Subject: [PATCH 2/6] Undo unneded change --- psqlextra/query.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index ecc361df..6f5b1a83 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -213,15 +213,10 @@ def is_empty(r): return_operation_type=return_operation_type and not return_model, ) if return_model: - models = [] - for row, obj in zip(deduped_rows, objs): - models.append( - self._create_model_instance( - dict(row, **obj), - compiler.using, - ) - ) - return models + return [ + self._create_model_instance(dict(row, **obj), compiler.using) + for row, obj in zip(deduped_rows, objs) + ] return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] From 039486cf9b27a8755c8ddfe51ba1cc0fbe42fa27 Mon Sep 17 00:00:00 2001 From: Renzo Manganiello Date: Mon, 10 Jul 2023 18:00:20 +0300 Subject: [PATCH 3/6] Fix black issues --- psqlextra/query.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index 6f5b1a83..3ac6176a 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -258,7 +258,9 @@ def insert( # no special action required, use the standard Django create(..) return super().create(**fields).pk - def insert_and_get(self, using: Optional[str] = None, **fields: Any) -> Optional[TModel]: + def insert_and_get( + self, using: Optional[str] = None, **fields: Any + ) -> Optional[TModel]: """Creates a new record in the database and then gets the entire row. This allows specifying custom conflict behavior using .on_conflict(). @@ -281,7 +283,9 @@ def insert_and_get(self, using: Optional[str] = None, **fields: Any) -> Optional return super().create(**fields) compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=False, return_operation_type=False) + rows = compiler.execute_sql( + return_id=False, return_operation_type=False + ) if not rows: return None @@ -454,10 +458,7 @@ def bulk_upsert( ) def _create_model_instance( - self, - field_values: dict, - using: str, - apply_converters: bool = True + self, field_values: dict, using: str, apply_converters: bool = True ) -> TModel: """Creates a new instance of the model with the specified field. From daaa32fe00e23c73c73e035160e93a768cab66ce Mon Sep 17 00:00:00 2001 From: Renzo Manganiello Date: Mon, 10 Jul 2023 18:19:44 +0300 Subject: [PATCH 4/6] Fix docstrings --- psqlextra/query.py | 4 ++-- tests/test_upsert.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index 3ac6176a..3854b7d8 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -462,8 +462,8 @@ def _create_model_instance( ) -> TModel: """Creates a new instance of the model with the specified field. - Use this after the row was inserted/updated into the database. The new - instance will be marked as "saved". + Use this after the row was inserted/updated into the database. + The new instance will be marked as "saved". """ converted_field_values = field_values.copy() diff --git a/tests/test_upsert.py b/tests/test_upsert.py index 177a7952..2ba89e47 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -261,7 +261,8 @@ def test_upsert_bulk_no_rows(): def test_upsert_bulk_returns_operation_type(): - """Tests whether bulk_upsert works properly with the return_operation_type flag.""" + """Tests whether bulk_upsert works properly with the return_operation_type + flag.""" model = get_fake_model( { From 3fa881752acf39a063c153a2e80a5b42535c0cff Mon Sep 17 00:00:00 2001 From: Renzo Manganiello Date: Mon, 10 Jul 2023 18:44:46 +0300 Subject: [PATCH 5/6] Fix ci issues by removing typing There are pre-existing typing issues in the project that seems to be ignored because the functions don't have typing hints. --- psqlextra/compiler.py | 32 ++++++++++++-------------------- psqlextra/query.py | 24 +++++++----------------- psqlextra/types.py | 4 ---- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index fa7c3811..bc529d7c 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -3,7 +3,7 @@ import sys from collections.abc import Iterable -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import django @@ -177,8 +177,8 @@ def __init__(self, *args, **kwargs): def as_sql( self, - return_id: bool = False, - return_operation_type: bool = False, + return_id=False, + return_operation_type=False, *args, **kwargs, ): @@ -190,11 +190,7 @@ def as_sql( return queries - def execute_sql( - self, - return_id: bool = False, - return_operation_type: bool = False, - ) -> List[dict]: + def execute_sql(self, return_id=False, return_operation_type=False): # execute all the generate queries with self.connection.cursor() as cursor: rows = [] @@ -217,12 +213,8 @@ def execute_sql( ] def _rewrite_insert( - self, - sql: str, - params: list, - return_id: bool = False, - return_operation_type: bool = False, - ) -> Tuple[str, list]: + self, sql, params, return_id=False, return_operation_type=False + ): """Rewrites a formed SQL INSERT query to include the ON CONFLICT clause. @@ -304,7 +296,7 @@ def _rewrite_insert_on_conflict( return (rewritten_sql, params) - def _build_on_conflict_clause(self) -> str: + def _build_on_conflict_clause(self): if django.VERSION >= (2, 2): from django.db.models.constraints import BaseConstraint from django.db.models.indexes import Index @@ -319,7 +311,7 @@ def _build_on_conflict_clause(self) -> str: conflict_target = self._build_conflict_target() return f"ON CONFLICT {conflict_target}" - def _build_conflict_target(self) -> str: + def _build_conflict_target(self): """Builds the `conflict_target` for the ON CONFLICT clause.""" if not isinstance(self.query.conflict_target, Iterable): @@ -338,7 +330,7 @@ def _build_conflict_target(self) -> str: return self._build_conflict_target_by_fields() - def _build_conflict_target_by_fields(self) -> str: + def _build_conflict_target_by_fields(self): """Builds the `conflict_target` for the ON CONFLICT clauses by matching the fields specified in the specified conflict target against the model's fields. @@ -363,7 +355,7 @@ def _build_conflict_target_by_fields(self) -> str: return "(%s)" % ",".join(conflict_target) - def _build_conflict_target_by_index(self) -> Optional[str]: + def _build_conflict_target_by_index(self): """Builds the `conflict_target` for the ON CONFLICT clause by trying to find an index that matches the specified conflict target on the query. @@ -418,7 +410,7 @@ def _get_model_field(self, name: str) -> Optional[Field]: return None - def _format_field_name(self, field_name) -> str: + def _format_field_name(self, field_name): """Formats a field's name for usage in SQL. Arguments: @@ -433,7 +425,7 @@ def _format_field_name(self, field_name) -> str: field = self._get_model_field(field_name) return self.qn(field.column) - def _format_field_value(self, field_name) -> str: + def _format_field_value(self, field_name): """Formats a field's value for usage in SQL. Arguments: diff --git a/psqlextra/query.py b/psqlextra/query.py index 3854b7d8..e1345410 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -2,7 +2,6 @@ from itertools import chain from typing import ( TYPE_CHECKING, - Any, Dict, Generic, Iterable, @@ -19,7 +18,6 @@ from django.db.models.fields import NOT_PROVIDED from .sql import ( - PostgresInsertOnConflictCompiler, PostgresInsertQuery, PostgresQuery, ) @@ -145,7 +143,7 @@ def bulk_insert( return_model: bool = False, using: Optional[str] = None, return_operation_type: bool = False, - ) -> Union[List[Dict], List[TModel]]: + ): """Creates multiple new records in the database. This allows specifying custom conflict behavior using .on_conflict(). @@ -220,11 +218,7 @@ def is_empty(r): return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] - def insert( - self, - using: Optional[str] = None, - **fields: Any, - ) -> Optional[int]: + def insert(self, using: Optional[str] = None, **fields): """Creates a new record in the database. This allows specifying custom conflict behavior using .on_conflict(). @@ -258,9 +252,7 @@ def insert( # no special action required, use the standard Django create(..) return super().create(**fields).pk - def insert_and_get( - self, using: Optional[str] = None, **fields: Any - ) -> Optional[TModel]: + def insert_and_get(self, using: Optional[str] = None, **fields): """Creates a new record in the database and then gets the entire row. This allows specifying custom conflict behavior using .on_conflict(). @@ -406,7 +398,7 @@ def bulk_upsert( using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, return_operation_type: bool = False, - ) -> Union[List[Dict], List[TModel]]: + ): """Creates a set of new records or updates the existing ones with the specified data. @@ -459,7 +451,7 @@ def bulk_upsert( def _create_model_instance( self, field_values: dict, using: str, apply_converters: bool = True - ) -> TModel: + ): """Creates a new instance of the model with the specified field. Use this after the row was inserted/updated into the database. @@ -496,10 +488,8 @@ def _create_model_instance( return instance def _build_insert_compiler( - self, - rows: Iterable[Dict], - using: Optional[str] = None, - ) -> PostgresInsertOnConflictCompiler: + self, rows: Iterable[Dict], using: Optional[str] = None + ): """Builds the SQL compiler for a insert query. Arguments: diff --git a/psqlextra/types.py b/psqlextra/types.py index 72fe20e3..1c007ba9 100644 --- a/psqlextra/types.py +++ b/psqlextra/types.py @@ -35,10 +35,6 @@ class UpsertOperation(StrEnum): INSERT = "INSERT" UPDATE = "UPDATE" - @classmethod - def all(cls) -> List["UpsertOperation"]: - return [choice for choice in cls] - class PostgresPartitioningMethod(StrEnum): """Methods of partitioning supported by PostgreSQL 11.x native support for From 30affc00e7afd8448714cb11c58231c4ea70b874 Mon Sep 17 00:00:00 2001 From: Renzo Manganiello Date: Mon, 10 Jul 2023 20:28:10 +0300 Subject: [PATCH 6/6] Fix isort --- psqlextra/query.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index e1345410..aeff8c0e 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -17,10 +17,7 @@ from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED -from .sql import ( - PostgresInsertQuery, - PostgresQuery, -) +from .sql import PostgresInsertQuery, PostgresQuery from .types import ConflictAction if TYPE_CHECKING: