diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8fe284bc84..7a15c06c61 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1435,12 +1435,22 @@ def get_unique_together_constraints(self, model): for unique_together in parent_class._meta.unique_together: yield unique_together, model._default_manager, [], None for constraint in parent_class._meta.constraints: - if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: + if isinstance(constraint, models.UniqueConstraint): if constraint.condition is None: condition_fields = [] else: - condition_fields = list(get_referenced_base_fields_from_q(constraint.condition)) - yield (constraint.fields, model._default_manager, condition_fields, constraint.condition) + condition_fields = list( + get_referenced_base_fields_from_q(constraint.condition) + ) + + required_fields = {*constraint.fields, *condition_fields} + if len(required_fields) > 1: + yield ( + constraint.fields, + model._default_manager, + condition_fields, + constraint.condition, + ) def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 15c4b91055..61da1e6a56 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -8,7 +8,9 @@ from django.db import models from django.utils.text import capfirst -from rest_framework.compat import postgres_fields +from rest_framework.compat import ( + get_referenced_base_fields_from_q, postgres_fields +) from rest_framework.validators import UniqueValidator NUMERIC_FIELD_TYPES = ( @@ -79,10 +81,16 @@ def get_unique_validators(field_name, model_field): unique_error_message = get_unique_error_message(model_field) queryset = model_field.model._default_manager for condition in conditions: - yield UniqueValidator( - queryset=queryset if condition is None else queryset.filter(condition), - message=unique_error_message + condition_fields = ( + get_referenced_base_fields_from_q(condition) + if condition is not None + else set() ) + if len(field_set | condition_fields) == 1: + yield UniqueValidator( + queryset=queryset if condition is None else queryset.filter(condition), + message=unique_error_message, + ) def get_field_kwargs(field_name, model_field): diff --git a/tests/test_validators.py b/tests/test_validators.py index c594eecbe5..d6d74e67f0 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -170,6 +170,24 @@ class Meta: unique_together = ('race_name', 'position') +class ConditionUniquenessTogetherModel(models.Model): + """ + Used to ensure that unique constraints with single fields but at least one other + distinct condition field are included when checking unique_together constraints. + """ + race_name = models.CharField(max_length=100) + position = models.IntegerField() + + class Meta: + constraints = [ + models.UniqueConstraint( + name="condition_uniqueness_together_model_race_name", + fields=('race_name',), + condition=models.Q(position__lte=1) + ) + ] + + class UniquenessTogetherSerializer(serializers.ModelSerializer): class Meta: model = UniquenessTogetherModel @@ -182,6 +200,12 @@ class Meta: fields = '__all__' +class ConditionUniquenessTogetherSerializer(serializers.ModelSerializer): + class Meta: + model = ConditionUniquenessTogetherModel + fields = '__all__' + + class TestUniquenessTogetherValidation(TestCase): def setUp(self): self.instance = UniquenessTogetherModel.objects.create( @@ -222,6 +246,22 @@ def test_is_not_unique_together(self): ] } + def test_is_not_unique_together_condition_based(self): + """ + Failing unique together validation should result in non field errors when a condition-based + unique together constraint is violated. + """ + ConditionUniquenessTogetherModel.objects.create(race_name='example', position=1) + + data = {'race_name': 'example', 'position': 1} + serializer = ConditionUniquenessTogetherSerializer(data=data) + assert not serializer.is_valid() + assert serializer.errors == { + 'non_field_errors': [ + 'The fields race_name must make a unique set.' + ] + } + def test_is_unique_together(self): """ In a unique together validation, one field may be non-unique @@ -235,6 +275,21 @@ def test_is_unique_together(self): 'position': 2 } + def test_unique_together_condition_based(self): + """ + In a unique together validation, one field may be non-unique + so long as the set as a whole is unique. + """ + ConditionUniquenessTogetherModel.objects.create(race_name='example', position=1) + + data = {'race_name': 'other', 'position': 1} + serializer = ConditionUniquenessTogetherSerializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'race_name': 'other', + 'position': 1 + } + def test_updated_instance_excluded_from_unique_together(self): """ When performing an update, the existing instance does not count @@ -248,6 +303,21 @@ def test_updated_instance_excluded_from_unique_together(self): 'position': 1 } + def test_updated_instance_excluded_from_unique_together_condition_based(self): + """ + When performing an update, the existing instance does not count + as a match against uniqueness. + """ + ConditionUniquenessTogetherModel.objects.create(race_name='example', position=1) + + data = {'race_name': 'example', 'position': 0} + serializer = ConditionUniquenessTogetherSerializer(self.instance, data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'race_name': 'example', + 'position': 0 + } + def test_unique_together_is_required(self): """ In a unique together validation, all fields are required. @@ -699,20 +769,17 @@ class Meta: def test_single_field_uniq_validators(self): """ UniqueConstraint with single field must be transformed into - field's UniqueValidator + field's UniqueValidator if no distinct condition fields exist (else UniqueTogetherValidator) """ # Django 5 includes Max and Min values validators for IntegerField extra_validators_qty = 2 if django_version[0] >= 5 else 0 serializer = UniqueConstraintSerializer() - assert len(serializer.validators) == 2 + assert len(serializer.validators) == 4 validators = serializer.fields['global_id'].validators assert len(validators) == 1 + extra_validators_qty assert validators[0].queryset == UniqueConstraintModel.objects - - validators = serializer.fields['fancy_conditions'].validators - assert len(validators) == 2 + extra_validators_qty ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators if hasattr(v, "queryset")} - assert ids_in_qs == {frozenset([1]), frozenset([3])} + assert ids_in_qs == {frozenset([1, 2, 3])} def test_nullable_unique_constraint_fields_are_not_required(self): serializer = UniqueConstraintNullableSerializer(data={'title': 'Bob'})