Skip to content

Commit 96fd08b

Browse files
committed
Restore renamed annotations in the original order without temporary renaming
b70d239 Broke annotations. This was because even annotations that didn't need renaming were renamed temporarily to `_new`. This made it impossible to use those fields in an expression. This fixes it by restoring the original behaviour, but fixing the ordering problem by restoring the renamed annotations in the original order.
1 parent 9738c69 commit 96fd08b

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

psqlextra/query.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,19 @@ def annotate(self, **annotations):
4040
the annotations are stored in an OrderedDict. Renaming only the
4141
conflicts will mess up the order.
4242
"""
43+
fields = {field.name: field for field in self.model._meta.get_fields()}
44+
4345
new_annotations = OrderedDict()
46+
4447
renames = {}
48+
4549
for name, value in annotations.items():
46-
new_name = "%s_new" % name
47-
new_annotations[new_name] = value
48-
renames[new_name] = name
50+
if name in fields:
51+
new_name = "%s_new" % name
52+
new_annotations[new_name] = value
53+
renames[new_name] = name
54+
else:
55+
new_annotations[name] = value
4956

5057
# run the base class's annotate function
5158
result = super().annotate(**new_annotations)

psqlextra/sql.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def rename_annotations(self, annotations) -> None:
4141
old name to the new name.
4242
"""
4343

44+
# safety check only, make sure there are no renames
45+
# left that cannot be mapped back to the original name
4446
for old_name, new_name in annotations.items():
4547
annotation = self.annotations.get(old_name)
46-
4748
if not annotation:
4849
raise SuspiciousOperation(
4950
(
@@ -52,13 +53,19 @@ def rename_annotations(self, annotations) -> None:
5253
).format(old_name=old_name, new_name=new_name)
5354
)
5455

55-
self.annotations[new_name] = annotation
56-
del self.annotations[old_name]
56+
# rebuild the annotations according to the original order
57+
new_annotations = dict()
58+
for old_name, annotation in self.annotations.items():
59+
new_name = annotations.get(old_name)
60+
new_annotations[new_name or old_name] = annotation
5761

58-
if self.annotation_select_mask:
62+
if new_name and self.annotation_select_mask:
5963
self.annotation_select_mask.discard(old_name)
6064
self.annotation_select_mask.add(new_name)
6165

66+
self.annotations.clear()
67+
self.annotations.update(new_annotations)
68+
6269
def add_fields(self, field_names: List[str], *args, **kwargs) -> bool:
6370
"""Adds the given (model) fields to the select set.
6471

tests/fake_model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,21 @@ def define_fake_partitioned_model(
8989
return model
9090

9191

92+
def get_fake_partitioned_model(
93+
fields=None, partitioning_options={}, meta_options={}
94+
):
95+
"""Defines a fake partitioned model and creates it in the database."""
96+
97+
model = define_fake_partitioned_model(
98+
fields, partitioning_options, meta_options
99+
)
100+
101+
with connection.schema_editor() as schema_editor:
102+
schema_editor.create_model(model)
103+
104+
return model
105+
106+
92107
def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}):
93108
"""Defines a fake model and creates it in the database."""
94109

tests/test_query.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from django.db import models
2-
from django.db.models import F
2+
from django.db.models import Case, F, Q, Value, When
33

44
from psqlextra.expressions import HStoreRef
55
from psqlextra.fields import HStoreField
@@ -75,6 +75,26 @@ def test_query_annotate_rename_order():
7575
assert list(qs.query.annotations.keys()) == ["value", "value_2"]
7676

7777

78+
def test_query_annotate_in_expression():
79+
"""Tests whether annotations can be used in expressions."""
80+
81+
model = get_fake_model({"name": models.CharField(max_length=10)})
82+
83+
model.objects.create(name="henk")
84+
85+
result = model.objects.annotate(
86+
real_name=F("name"),
87+
is_he_henk=Case(
88+
When(Q(real_name="henk"), then=Value("really henk")),
89+
default=Value("definitely not henk"),
90+
output_field=models.CharField(),
91+
),
92+
).first()
93+
94+
assert result.real_name == "henk"
95+
assert result.is_he_henk == "really henk"
96+
97+
7898
def test_query_hstore_value_update_f_ref():
7999
"""Tests whether F(..) expressions can be used in hstore values when
80100
performing update queries."""

0 commit comments

Comments
 (0)