From 228128618bd895ecad235d2215f4ad4e3232595d Mon Sep 17 00:00:00 2001 From: Mark Gensler Date: Thu, 18 Jul 2024 08:38:06 +0100 Subject: [PATCH] Fixed #35575 -- Added support for constraint validation on GeneratedFields. --- django/contrib/postgres/constraints.py | 12 +-- django/db/models/base.py | 35 +++++-- django/db/models/constraints.py | 81 ++++++++++++----- docs/releases/5.2.txt | 3 + tests/constraints/models.py | 41 +++++++++ tests/constraints/tests.py | 111 ++++++++++++++++++++++- tests/postgres_tests/test_constraints.py | 34 +++++++ 7 files changed, 273 insertions(+), 44 deletions(-) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index 2701c4ba48..49124adc15 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -183,17 +183,11 @@ class ExclusionConstraint(BaseConstraint): ) replacements = {F(field): value for field, value in replacement_map.items()} lookups = [] - for idx, (expression, operator) in enumerate(self.expressions): + for expression, operator in self.expressions: if isinstance(expression, str): expression = F(expression) - if exclude: - if isinstance(expression, F): - if expression.name in exclude: - return - else: - for expr in expression.flatten(): - if isinstance(expr, F) and expr.name in exclude: - return + if exclude and self._expression_refs_exclude(model, expression, exclude): + return rhs_expression = expression.replace_expressions(replacements) if hasattr(expression, "get_expression_for_validation"): expression = expression.get_expression_for_validation() diff --git a/django/db/models/base.py b/django/db/models/base.py index d4b8bab963..a89ceafbef 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase): if exclude is None: exclude = set() meta = meta or self._meta - field_map = { - field.name: ( - value - if (value := getattr(self, field.attname)) - and hasattr(value, "resolve_expression") - else Value(value, field) - ) - for field in meta.local_concrete_fields - if field.name not in exclude and not field.generated - } + field_map = {} + generated_fields = [] + for field in meta.local_concrete_fields: + if field.name in exclude: + continue + if field.generated: + if any( + ref[0] in exclude + for ref in self._get_expr_references(field.expression) + ): + continue + generated_fields.append(field) + continue + value = getattr(self, field.attname) + if not value or not hasattr(value, "resolve_expression"): + value = Value(value, field) + field_map[field.name] = value if "pk" not in exclude: field_map["pk"] = Value(self.pk, meta.pk) + if generated_fields: + replacements = {F(name): value for name, value in field_map.items()} + for generated_field in generated_fields: + field_map[generated_field.name] = ExpressionWrapper( + generated_field.expression.replace_expressions(replacements), + generated_field.output_field, + ) + return field_map def prepare_database_save(self, field): diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 915ace5129..b5952def6a 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -68,6 +68,19 @@ class BaseConstraint: def remove_sql(self, model, schema_editor): raise NotImplementedError("This method must be implemented by a subclass.") + @classmethod + def _expression_refs_exclude(cls, model, expression, exclude): + get_field = model._meta.get_field + for field_name, *__ in model._get_expr_references(expression): + if field_name in exclude: + return True + field = get_field(field_name) + if field.generated and cls._expression_refs_exclude( + model, field.expression, exclude + ): + return True + return False + def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): raise NotImplementedError("This method must be implemented by a subclass.") @@ -606,36 +619,56 @@ class UniqueConstraint(BaseConstraint): queryset = model._default_manager.using(using) if self.fields: lookup_kwargs = {} + generated_field_names = [] for field_name in self.fields: if exclude and field_name in exclude: return field = model._meta.get_field(field_name) - lookup_value = getattr(instance, field.attname) - if ( - self.nulls_distinct is not False - and lookup_value is None - or ( - lookup_value == "" - and connections[ - using - ].features.interprets_empty_strings_as_nulls - ) - ): - # A composite constraint containing NULL value cannot cause - # a violation since NULL != NULL in SQL. - return - lookup_kwargs[field.name] = lookup_value - queryset = queryset.filter(**lookup_kwargs) + if field.generated: + if exclude and self._expression_refs_exclude( + model, field.expression, exclude + ): + return + generated_field_names.append(field.name) + else: + lookup_value = getattr(instance, field.attname) + if ( + self.nulls_distinct is not False + and lookup_value is None + or ( + lookup_value == "" + and connections[ + using + ].features.interprets_empty_strings_as_nulls + ) + ): + # A composite constraint containing NULL value cannot cause + # a violation since NULL != NULL in SQL. + return + lookup_kwargs[field.name] = lookup_value + lookup_args = [] + if generated_field_names: + field_expression_map = instance._get_field_expression_map( + meta=model._meta, exclude=exclude + ) + for field_name in generated_field_names: + expression = field_expression_map[field_name] + if self.nulls_distinct is False: + lhs = F(field_name) + condition = Q(Exact(lhs, expression)) | Q( + IsNull(lhs, True), IsNull(expression, True) + ) + lookup_args.append(condition) + else: + lookup_kwargs[field_name] = expression + queryset = queryset.filter(*lookup_args, **lookup_kwargs) else: # Ignore constraints with excluded fields. - if exclude: - for expression in self.expressions: - if hasattr(expression, "flatten"): - for expr in expression.flatten(): - if isinstance(expr, F) and expr.name in exclude: - return - elif isinstance(expression, F) and expression.name in exclude: - return + if exclude and any( + self._expression_refs_exclude(model, expression, exclude) + for expression in self.expressions + ): + return replacements = { F(field): value for field, value in instance._get_field_expression_map( diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index 7ae44cfd97..02a068e5af 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -215,6 +215,9 @@ Models methods such as :meth:`QuerySet.union()` unpredictable. +* Added support for validation of model constraints which use a + :class:`~django.db.models.GeneratedField`. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/constraints/models.py b/tests/constraints/models.py index 983d550502..829f671cdd 100644 --- a/tests/constraints/models.py +++ b/tests/constraints/models.py @@ -1,4 +1,5 @@ from django.db import models +from django.db.models.functions import Coalesce, Lower class Product(models.Model): @@ -28,6 +29,46 @@ class Product(models.Model): ] +class GeneratedFieldStoredProduct(models.Model): + name = models.CharField(max_length=255, null=True) + price = models.IntegerField(null=True) + discounted_price = models.IntegerField(null=True) + rebate = models.GeneratedField( + expression=Coalesce("price", 0) + - Coalesce("discounted_price", Coalesce("price", 0)), + output_field=models.IntegerField(), + db_persist=True, + ) + lower_name = models.GeneratedField( + expression=Lower(models.F("name")), + output_field=models.CharField(max_length=255, null=True), + db_persist=True, + ) + + class Meta: + required_db_features = {"supports_stored_generated_columns"} + + +class GeneratedFieldVirtualProduct(models.Model): + name = models.CharField(max_length=255, null=True) + price = models.IntegerField(null=True) + discounted_price = models.IntegerField(null=True) + rebate = models.GeneratedField( + expression=Coalesce("price", 0) + - Coalesce("discounted_price", Coalesce("price", 0)), + output_field=models.IntegerField(), + db_persist=False, + ) + lower_name = models.GeneratedField( + expression=Lower(models.F("name")), + output_field=models.CharField(max_length=255, null=True), + db_persist=False, + ) + + class Meta: + required_db_features = {"supports_virtual_generated_columns"} + + class UniqueConstraintProduct(models.Model): name = models.CharField(max_length=255) color = models.CharField(max_length=32, null=True) diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 350f05f2b8..9ca889ca6d 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError from django.db import IntegrityError, connection, models from django.db.models import F from django.db.models.constraints import BaseConstraint, UniqueConstraint -from django.db.models.functions import Abs, Lower, Upper +from django.db.models.functions import Abs, Lower, Sqrt, Upper from django.db.transaction import atomic from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import ignore_warnings @@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning from .models import ( ChildModel, ChildUniqueConstraintProduct, + GeneratedFieldStoredProduct, + GeneratedFieldVirtualProduct, JSONFieldModel, ModelWithDatabaseDefault, Product, @@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase): with self.assertRaisesMessage(ValidationError, msg): json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data)) + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_validate_generated_field_stored(self): + self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_validate_generated_field_virtual(self): + self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct) + + def assertGeneratedFieldIsValidated(self, model): + constraint = models.CheckConstraint( + condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate" + ) + constraint.validate(model, model(price=50, discounted_price=20)) + + invalid_product = model(price=1200, discounted_price=500) + msg = f"Constraint “{constraint.name}” is violated." + with self.assertRaisesMessage(ValidationError, msg): + constraint.validate(model, invalid_product) + + # Excluding referenced or generated fields should skip validation. + constraint.validate(model, invalid_product, exclude={"price"}) + constraint.validate(model, invalid_product, exclude={"rebate"}) + def test_check_deprecation(self): msg = "CheckConstraint.check is deprecated in favor of `.condition`." condition = models.Q(foo="bar") @@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase): exclude={"name"}, ) + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_validate_expression_generated_field_stored(self): + self.assertGeneratedFieldWithExpressionIsValidated( + model=GeneratedFieldStoredProduct + ) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_validate_expression_generated_field_virtual(self): + self.assertGeneratedFieldWithExpressionIsValidated( + model=GeneratedFieldVirtualProduct + ) + + def assertGeneratedFieldWithExpressionIsValidated(self, model): + constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt") + model.objects.create(price=100, discounted_price=84) + + valid_product = model(price=100, discounted_price=75) + constraint.validate(model, valid_product) + + invalid_product = model(price=20, discounted_price=4) + with self.assertRaisesMessage( + ValidationError, f"Constraint “{constraint.name}” is violated." + ): + constraint.validate(model, invalid_product) + + # Excluding referenced or generated fields should skip validation. + constraint.validate(model, invalid_product, exclude={"rebate"}) + constraint.validate(model, invalid_product, exclude={"price"}) + + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_validate_fields_generated_field_stored(self): + self.assertGeneratedFieldWithFieldsIsValidated( + model=GeneratedFieldStoredProduct + ) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_validate_fields_generated_field_virtual(self): + self.assertGeneratedFieldWithFieldsIsValidated( + model=GeneratedFieldVirtualProduct + ) + + def assertGeneratedFieldWithFieldsIsValidated(self, model): + constraint = models.UniqueConstraint( + fields=["lower_name"], name="lower_name_unique" + ) + model.objects.create(name="Box") + constraint.validate(model, model(name="Case")) + + invalid_product = model(name="BOX") + msg = str(invalid_product.unique_error_message(model, ["lower_name"])) + with self.assertRaisesMessage(ValidationError, msg): + constraint.validate(model, invalid_product) + + # Excluding referenced or generated fields should skip validation. + constraint.validate(model, invalid_product, exclude={"lower_name"}) + constraint.validate(model, invalid_product, exclude={"name"}) + + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_validate_fields_generated_field_stored_nulls_distinct(self): + self.assertGeneratedFieldNullsDistinctIsValidated( + model=GeneratedFieldStoredProduct + ) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_validate_fields_generated_field_virtual_nulls_distinct(self): + self.assertGeneratedFieldNullsDistinctIsValidated( + model=GeneratedFieldVirtualProduct + ) + + def assertGeneratedFieldNullsDistinctIsValidated(self, model): + constraint = models.UniqueConstraint( + fields=["lower_name"], + name="lower_name_unique_nulls_distinct", + nulls_distinct=False, + ) + model.objects.create(name=None) + valid_product = model(name="Box") + constraint.validate(model, valid_product) + + invalid_product = model(name=None) + msg = str(invalid_product.unique_error_message(model, ["lower_name"])) + with self.assertRaisesMessage(ValidationError, msg): + constraint.validate(model, invalid_product) + @skipUnlessDBFeature("supports_table_check_constraints") def test_validate_nullable_textfield_with_isnull_true(self): is_null_constraint = models.UniqueConstraint( diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index f571a96f35..ab5bf2bab1 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -14,6 +14,7 @@ from django.db.models import ( F, ForeignKey, Func, + GeneratedField, IntegerField, Model, Q, @@ -32,6 +33,7 @@ try: from django.contrib.postgres.constraints import ExclusionConstraint from django.contrib.postgres.fields import ( DateTimeRangeField, + IntegerRangeField, RangeBoundary, RangeOperators, ) @@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase): constraint.validate(RangesModel, RangesModel(ints=(51, 60))) constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"}) + @skipUnlessDBFeature("supports_stored_generated_columns") + @isolate_apps("postgres_tests") + def test_validate_generated_field_range_adjacent(self): + class RangesModelGeneratedField(Model): + ints = IntegerRangeField(blank=True, null=True) + ints_generated = GeneratedField( + expression=F("ints"), + output_field=IntegerRangeField(null=True), + db_persist=True, + ) + + with connection.schema_editor() as editor: + editor.create_model(RangesModelGeneratedField) + + constraint = ExclusionConstraint( + name="ints_adjacent", + expressions=[("ints_generated", RangeOperators.ADJACENT_TO)], + violation_error_code="custom_code", + violation_error_message="Custom error message.", + ) + RangesModelGeneratedField.objects.create(ints=(20, 50)) + + range_obj = RangesModelGeneratedField(ints=(3, 20)) + with self.assertRaisesMessage(ValidationError, "Custom error message."): + constraint.validate(RangesModelGeneratedField, range_obj) + + # Excluding referenced or generated field should skip validation. + constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"}) + constraint.validate( + RangesModelGeneratedField, range_obj, exclude={"ints_generated"} + ) + def test_validate_with_custom_code_and_condition(self): constraint = ExclusionConstraint( name="ints_adjacent",