mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #35638 -- Updated validate_constraints to consider db_default.
This commit is contained in:
		
				
					committed by
					
						 Sarah Boyce
						Sarah Boyce
					
				
			
			
				
	
			
			
			
						parent
						
							91a038754b
						
					
				
				
					commit
					509763c799
				
			| @@ -1250,9 +1250,41 @@ class Star(Expression): | |||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseDefault(Expression): | class DatabaseDefault(Expression): | ||||||
|     """Placeholder expression for the database default in an insert query.""" |     """ | ||||||
|  |     Expression to use DEFAULT keyword during insert otherwise the underlying expression. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, expression, output_field=None): | ||||||
|  |         super().__init__(output_field) | ||||||
|  |         self.expression = expression | ||||||
|  |  | ||||||
|  |     def get_source_expressions(self): | ||||||
|  |         return [self.expression] | ||||||
|  |  | ||||||
|  |     def set_source_expressions(self, exprs): | ||||||
|  |         (self.expression,) = exprs | ||||||
|  |  | ||||||
|  |     def resolve_expression( | ||||||
|  |         self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False | ||||||
|  |     ): | ||||||
|  |         resolved_expression = self.expression.resolve_expression( | ||||||
|  |             query=query, | ||||||
|  |             allow_joins=allow_joins, | ||||||
|  |             reuse=reuse, | ||||||
|  |             summarize=summarize, | ||||||
|  |             for_save=for_save, | ||||||
|  |         ) | ||||||
|  |         # Defaults used outside an INSERT context should resolve to their | ||||||
|  |         # underlying expression. | ||||||
|  |         if not for_save: | ||||||
|  |             return resolved_expression | ||||||
|  |         return DatabaseDefault( | ||||||
|  |             resolved_expression, output_field=self._output_field_or_none | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|  |         if not connection.features.supports_default_keyword_in_insert: | ||||||
|  |             return compiler.compile(self.expression) | ||||||
|         return "DEFAULT", [] |         return "DEFAULT", [] | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -983,13 +983,7 @@ class Field(RegisterLookupMixin): | |||||||
|  |  | ||||||
|     def pre_save(self, model_instance, add): |     def pre_save(self, model_instance, add): | ||||||
|         """Return field's value just before saving.""" |         """Return field's value just before saving.""" | ||||||
|         value = getattr(model_instance, self.attname) |         return getattr(model_instance, self.attname) | ||||||
|         if not connection.features.supports_default_keyword_in_insert: |  | ||||||
|             from django.db.models.expressions import DatabaseDefault |  | ||||||
|  |  | ||||||
|             if isinstance(value, DatabaseDefault): |  | ||||||
|                 return self._db_default_expression |  | ||||||
|         return value |  | ||||||
|  |  | ||||||
|     def get_prep_value(self, value): |     def get_prep_value(self, value): | ||||||
|         """Perform preliminary non-db specific value checks and conversions.""" |         """Perform preliminary non-db specific value checks and conversions.""" | ||||||
| @@ -1031,7 +1025,9 @@ class Field(RegisterLookupMixin): | |||||||
|         if self.db_default is not NOT_PROVIDED: |         if self.db_default is not NOT_PROVIDED: | ||||||
|             from django.db.models.expressions import DatabaseDefault |             from django.db.models.expressions import DatabaseDefault | ||||||
|  |  | ||||||
|             return DatabaseDefault |             return lambda: DatabaseDefault( | ||||||
|  |                 self._db_default_expression, output_field=self | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         if ( |         if ( | ||||||
|             not self.empty_strings_allowed |             not self.empty_strings_allowed | ||||||
|   | |||||||
| @@ -28,3 +28,7 @@ Bugfixes | |||||||
| * Fixed a bug in Django 5.0 that caused a system check crash when | * Fixed a bug in Django 5.0 that caused a system check crash when | ||||||
|   ``ModelAdmin.date_hierarchy`` was a ``GeneratedField`` with an |   ``ModelAdmin.date_hierarchy`` was a ``GeneratedField`` with an | ||||||
|   ``output_field`` of ``DateField`` or ``DateTimeField`` (:ticket:`35628`). |   ``output_field`` of ``DateField`` or ``DateTimeField`` (:ticket:`35628`). | ||||||
|  |  | ||||||
|  | * Fixed a bug in Django 5.0 which caused constraint validation to either crash | ||||||
|  |   or incorrectly raise validation errors for constraints referring to fields | ||||||
|  |   using ``Field.db_default`` (:ticket:`35638`). | ||||||
|   | |||||||
| @@ -128,3 +128,10 @@ class JSONFieldModel(models.Model): | |||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         required_db_features = {"supports_json_field"} |         required_db_features = {"supports_json_field"} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ModelWithDatabaseDefault(models.Model): | ||||||
|  |     field = models.CharField(max_length=255) | ||||||
|  |     field_with_db_default = models.CharField( | ||||||
|  |         max_length=255, db_default=models.Value("field_with_db_default") | ||||||
|  |     ) | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError | |||||||
| from django.db import IntegrityError, connection, models | from django.db import IntegrityError, connection, models | ||||||
| from django.db.models import F | from django.db.models import F | ||||||
| from django.db.models.constraints import BaseConstraint, UniqueConstraint | from django.db.models.constraints import BaseConstraint, UniqueConstraint | ||||||
| from django.db.models.functions import Abs, Lower | from django.db.models.functions import Abs, Lower, Upper | ||||||
| from django.db.transaction import atomic | from django.db.transaction import atomic | ||||||
| from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature | from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature | ||||||
| from django.test.utils import ignore_warnings | from django.test.utils import ignore_warnings | ||||||
| @@ -14,6 +14,7 @@ from .models import ( | |||||||
|     ChildModel, |     ChildModel, | ||||||
|     ChildUniqueConstraintProduct, |     ChildUniqueConstraintProduct, | ||||||
|     JSONFieldModel, |     JSONFieldModel, | ||||||
|  |     ModelWithDatabaseDefault, | ||||||
|     Product, |     Product, | ||||||
|     UniqueConstraintConditionProduct, |     UniqueConstraintConditionProduct, | ||||||
|     UniqueConstraintDeferrable, |     UniqueConstraintDeferrable, | ||||||
| @@ -396,6 +397,33 @@ class CheckConstraintTests(TestCase): | |||||||
|         with self.assertWarnsRegex(RemovedInDjango60Warning, msg): |         with self.assertWarnsRegex(RemovedInDjango60Warning, msg): | ||||||
|             self.assertIs(constraint.check, other_condition) |             self.assertIs(constraint.check, other_condition) | ||||||
|  |  | ||||||
|  |     def test_database_default(self): | ||||||
|  |         models.CheckConstraint( | ||||||
|  |             condition=models.Q(field_with_db_default="field_with_db_default"), | ||||||
|  |             name="check_field_with_db_default", | ||||||
|  |         ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) | ||||||
|  |  | ||||||
|  |         # Ensure that a check also does not silently pass with either | ||||||
|  |         # FieldError or DatabaseError when checking with a db_default. | ||||||
|  |         with self.assertRaises(ValidationError): | ||||||
|  |             models.CheckConstraint( | ||||||
|  |                 condition=models.Q( | ||||||
|  |                     field_with_db_default="field_with_db_default", field="field" | ||||||
|  |                 ), | ||||||
|  |                 name="check_field_with_db_default_2", | ||||||
|  |             ).validate( | ||||||
|  |                 ModelWithDatabaseDefault, ModelWithDatabaseDefault(field="not-field") | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         with self.assertRaises(ValidationError): | ||||||
|  |             models.CheckConstraint( | ||||||
|  |                 condition=models.Q(field_with_db_default="field_with_db_default"), | ||||||
|  |                 name="check_field_with_db_default", | ||||||
|  |             ).validate( | ||||||
|  |                 ModelWithDatabaseDefault, | ||||||
|  |                 ModelWithDatabaseDefault(field_with_db_default="other value"), | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class UniqueConstraintTests(TestCase): | class UniqueConstraintTests(TestCase): | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -1265,3 +1293,30 @@ class UniqueConstraintTests(TestCase): | |||||||
|         msg = "A unique constraint must be named." |         msg = "A unique constraint must be named." | ||||||
|         with self.assertRaisesMessage(ValueError, msg): |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|             models.UniqueConstraint(fields=["field"]) |             models.UniqueConstraint(fields=["field"]) | ||||||
|  |  | ||||||
|  |     def test_database_default(self): | ||||||
|  |         models.UniqueConstraint( | ||||||
|  |             fields=["field_with_db_default"], name="unique_field_with_db_default" | ||||||
|  |         ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) | ||||||
|  |         models.UniqueConstraint( | ||||||
|  |             Upper("field_with_db_default"), | ||||||
|  |             name="unique_field_with_db_default_expression", | ||||||
|  |         ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) | ||||||
|  |  | ||||||
|  |         ModelWithDatabaseDefault.objects.create() | ||||||
|  |  | ||||||
|  |         msg = ( | ||||||
|  |             "Model with database default with this Field with db default already " | ||||||
|  |             "exists." | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(ValidationError, msg): | ||||||
|  |             models.UniqueConstraint( | ||||||
|  |                 fields=["field_with_db_default"], name="unique_field_with_db_default" | ||||||
|  |             ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) | ||||||
|  |  | ||||||
|  |         msg = "Constraint “unique_field_with_db_default_expression” is violated." | ||||||
|  |         with self.assertRaisesMessage(ValidationError, msg): | ||||||
|  |             models.UniqueConstraint( | ||||||
|  |                 Upper("field_with_db_default"), | ||||||
|  |                 name="unique_field_with_db_default_expression", | ||||||
|  |             ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) | ||||||
|   | |||||||
| @@ -434,7 +434,7 @@ class Migration(migrations.Migration): | |||||||
|                         primary_key=True, |                         primary_key=True, | ||||||
|                     ), |                     ), | ||||||
|                 ), |                 ), | ||||||
|                 ("ints", IntegerRangeField(null=True, blank=True)), |                 ("ints", IntegerRangeField(null=True, blank=True, db_default=(5, 10))), | ||||||
|                 ("bigints", BigIntegerRangeField(null=True, blank=True)), |                 ("bigints", BigIntegerRangeField(null=True, blank=True)), | ||||||
|                 ("decimals", DecimalRangeField(null=True, blank=True)), |                 ("decimals", DecimalRangeField(null=True, blank=True)), | ||||||
|                 ("timestamps", DateTimeRangeField(null=True, blank=True)), |                 ("timestamps", DateTimeRangeField(null=True, blank=True)), | ||||||
|   | |||||||
| @@ -130,7 +130,7 @@ class LineSavedSearch(PostgreSQLModel): | |||||||
|  |  | ||||||
|  |  | ||||||
| class RangesModel(PostgreSQLModel): | class RangesModel(PostgreSQLModel): | ||||||
|     ints = IntegerRangeField(blank=True, null=True) |     ints = IntegerRangeField(blank=True, null=True, db_default=(5, 10)) | ||||||
|     bigints = BigIntegerRangeField(blank=True, null=True) |     bigints = BigIntegerRangeField(blank=True, null=True) | ||||||
|     decimals = DecimalRangeField(blank=True, null=True) |     decimals = DecimalRangeField(blank=True, null=True) | ||||||
|     timestamps = DateTimeRangeField(blank=True, null=True) |     timestamps = DateTimeRangeField(blank=True, null=True) | ||||||
|   | |||||||
| @@ -1213,3 +1213,12 @@ class ExclusionConstraintTests(PostgreSQLTestCase): | |||||||
|             constraint_name, |             constraint_name, | ||||||
|             self.get_constraints(ModelWithExclusionConstraint._meta.db_table), |             self.get_constraints(ModelWithExclusionConstraint._meta.db_table), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_database_default(self): | ||||||
|  |         constraint = ExclusionConstraint( | ||||||
|  |             name="ints_equal", expressions=[("ints", RangeOperators.EQUAL)] | ||||||
|  |         ) | ||||||
|  |         RangesModel.objects.create() | ||||||
|  |         msg = "Constraint “ints_equal” is violated." | ||||||
|  |         with self.assertRaisesMessage(ValidationError, msg): | ||||||
|  |             constraint.validate(RangesModel, RangesModel()) | ||||||
|   | |||||||
| @@ -48,7 +48,7 @@ class ModelToValidate(models.Model): | |||||||
|  |  | ||||||
| class UniqueFieldsModel(models.Model): | class UniqueFieldsModel(models.Model): | ||||||
|     unique_charfield = models.CharField(max_length=100, unique=True) |     unique_charfield = models.CharField(max_length=100, unique=True) | ||||||
|     unique_integerfield = models.IntegerField(unique=True) |     unique_integerfield = models.IntegerField(unique=True, db_default=42) | ||||||
|     non_unique_field = models.IntegerField() |     non_unique_field = models.IntegerField() | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -146,6 +146,20 @@ class PerformUniqueChecksTest(TestCase): | |||||||
|             mtv = ModelToValidate(number=10, name="Some Name") |             mtv = ModelToValidate(number=10, name="Some Name") | ||||||
|             mtv.full_clean() |             mtv.full_clean() | ||||||
|  |  | ||||||
|  |     def test_unique_db_default(self): | ||||||
|  |         UniqueFieldsModel.objects.create(unique_charfield="foo", non_unique_field=42) | ||||||
|  |         um = UniqueFieldsModel(unique_charfield="bar", non_unique_field=42) | ||||||
|  |         with self.assertRaises(ValidationError) as cm: | ||||||
|  |             um.full_clean() | ||||||
|  |         self.assertEqual( | ||||||
|  |             cm.exception.message_dict, | ||||||
|  |             { | ||||||
|  |                 "unique_integerfield": [ | ||||||
|  |                     "Unique fields model with this Unique integerfield already exists." | ||||||
|  |                 ] | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_unique_for_date(self): |     def test_unique_for_date(self): | ||||||
|         Post.objects.create( |         Post.objects.create( | ||||||
|             title="Django 1.0 is released", |             title="Django 1.0 is released", | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user