1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #35638 -- Updated validate_constraints to consider db_default.

This commit is contained in:
David Sanders 2024-08-05 08:22:29 +02:00 committed by Sarah Boyce
parent 91a038754b
commit 509763c799
10 changed files with 130 additions and 13 deletions

View File

@ -1250,9 +1250,41 @@ class Star(Expression):
class DatabaseDefault(Expression): class DatabaseDefault(Expression):
"""Placeholder expression for the database default in an insert query.""" """
Expression to use DEFAULT keyword during insert otherwise the underlying expression.
"""
def __init__(self, expression, output_field=None):
super().__init__(output_field)
self.expression = expression
def get_source_expressions(self):
return [self.expression]
def set_source_expressions(self, exprs):
(self.expression,) = exprs
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
resolved_expression = self.expression.resolve_expression(
query=query,
allow_joins=allow_joins,
reuse=reuse,
summarize=summarize,
for_save=for_save,
)
# Defaults used outside an INSERT context should resolve to their
# underlying expression.
if not for_save:
return resolved_expression
return DatabaseDefault(
resolved_expression, output_field=self._output_field_or_none
)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if not connection.features.supports_default_keyword_in_insert:
return compiler.compile(self.expression)
return "DEFAULT", [] return "DEFAULT", []

View File

@ -983,13 +983,7 @@ class Field(RegisterLookupMixin):
def pre_save(self, model_instance, add): def pre_save(self, model_instance, add):
"""Return field's value just before saving.""" """Return field's value just before saving."""
value = getattr(model_instance, self.attname) return getattr(model_instance, self.attname)
if not connection.features.supports_default_keyword_in_insert:
from django.db.models.expressions import DatabaseDefault
if isinstance(value, DatabaseDefault):
return self._db_default_expression
return value
def get_prep_value(self, value): def get_prep_value(self, value):
"""Perform preliminary non-db specific value checks and conversions.""" """Perform preliminary non-db specific value checks and conversions."""
@ -1031,7 +1025,9 @@ class Field(RegisterLookupMixin):
if self.db_default is not NOT_PROVIDED: if self.db_default is not NOT_PROVIDED:
from django.db.models.expressions import DatabaseDefault from django.db.models.expressions import DatabaseDefault
return DatabaseDefault return lambda: DatabaseDefault(
self._db_default_expression, output_field=self
)
if ( if (
not self.empty_strings_allowed not self.empty_strings_allowed

View File

@ -28,3 +28,7 @@ Bugfixes
* Fixed a bug in Django 5.0 that caused a system check crash when * Fixed a bug in Django 5.0 that caused a system check crash when
``ModelAdmin.date_hierarchy`` was a ``GeneratedField`` with an ``ModelAdmin.date_hierarchy`` was a ``GeneratedField`` with an
``output_field`` of ``DateField`` or ``DateTimeField`` (:ticket:`35628`). ``output_field`` of ``DateField`` or ``DateTimeField`` (:ticket:`35628`).
* Fixed a bug in Django 5.0 which caused constraint validation to either crash
or incorrectly raise validation errors for constraints referring to fields
using ``Field.db_default`` (:ticket:`35638`).

View File

@ -128,3 +128,10 @@ class JSONFieldModel(models.Model):
class Meta: class Meta:
required_db_features = {"supports_json_field"} required_db_features = {"supports_json_field"}
class ModelWithDatabaseDefault(models.Model):
field = models.CharField(max_length=255)
field_with_db_default = models.CharField(
max_length=255, db_default=models.Value("field_with_db_default")
)

View File

@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models from django.db import IntegrityError, connection, models
from django.db.models import F from django.db.models import F
from django.db.models.constraints import BaseConstraint, UniqueConstraint from django.db.models.constraints import BaseConstraint, UniqueConstraint
from django.db.models.functions import Abs, Lower from django.db.models.functions import Abs, Lower, Upper
from django.db.transaction import atomic from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import ignore_warnings from django.test.utils import ignore_warnings
@ -14,6 +14,7 @@ from .models import (
ChildModel, ChildModel,
ChildUniqueConstraintProduct, ChildUniqueConstraintProduct,
JSONFieldModel, JSONFieldModel,
ModelWithDatabaseDefault,
Product, Product,
UniqueConstraintConditionProduct, UniqueConstraintConditionProduct,
UniqueConstraintDeferrable, UniqueConstraintDeferrable,
@ -396,6 +397,33 @@ class CheckConstraintTests(TestCase):
with self.assertWarnsRegex(RemovedInDjango60Warning, msg): with self.assertWarnsRegex(RemovedInDjango60Warning, msg):
self.assertIs(constraint.check, other_condition) self.assertIs(constraint.check, other_condition)
def test_database_default(self):
models.CheckConstraint(
condition=models.Q(field_with_db_default="field_with_db_default"),
name="check_field_with_db_default",
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())
# Ensure that a check also does not silently pass with either
# FieldError or DatabaseError when checking with a db_default.
with self.assertRaises(ValidationError):
models.CheckConstraint(
condition=models.Q(
field_with_db_default="field_with_db_default", field="field"
),
name="check_field_with_db_default_2",
).validate(
ModelWithDatabaseDefault, ModelWithDatabaseDefault(field="not-field")
)
with self.assertRaises(ValidationError):
models.CheckConstraint(
condition=models.Q(field_with_db_default="field_with_db_default"),
name="check_field_with_db_default",
).validate(
ModelWithDatabaseDefault,
ModelWithDatabaseDefault(field_with_db_default="other value"),
)
class UniqueConstraintTests(TestCase): class UniqueConstraintTests(TestCase):
@classmethod @classmethod
@ -1265,3 +1293,30 @@ class UniqueConstraintTests(TestCase):
msg = "A unique constraint must be named." msg = "A unique constraint must be named."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(fields=["field"]) models.UniqueConstraint(fields=["field"])
def test_database_default(self):
models.UniqueConstraint(
fields=["field_with_db_default"], name="unique_field_with_db_default"
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())
models.UniqueConstraint(
Upper("field_with_db_default"),
name="unique_field_with_db_default_expression",
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())
ModelWithDatabaseDefault.objects.create()
msg = (
"Model with database default with this Field with db default already "
"exists."
)
with self.assertRaisesMessage(ValidationError, msg):
models.UniqueConstraint(
fields=["field_with_db_default"], name="unique_field_with_db_default"
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())
msg = "Constraint “unique_field_with_db_default_expression” is violated."
with self.assertRaisesMessage(ValidationError, msg):
models.UniqueConstraint(
Upper("field_with_db_default"),
name="unique_field_with_db_default_expression",
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())

View File

@ -434,7 +434,7 @@ class Migration(migrations.Migration):
primary_key=True, primary_key=True,
), ),
), ),
("ints", IntegerRangeField(null=True, blank=True)), ("ints", IntegerRangeField(null=True, blank=True, db_default=(5, 10))),
("bigints", BigIntegerRangeField(null=True, blank=True)), ("bigints", BigIntegerRangeField(null=True, blank=True)),
("decimals", DecimalRangeField(null=True, blank=True)), ("decimals", DecimalRangeField(null=True, blank=True)),
("timestamps", DateTimeRangeField(null=True, blank=True)), ("timestamps", DateTimeRangeField(null=True, blank=True)),

View File

@ -130,7 +130,7 @@ class LineSavedSearch(PostgreSQLModel):
class RangesModel(PostgreSQLModel): class RangesModel(PostgreSQLModel):
ints = IntegerRangeField(blank=True, null=True) ints = IntegerRangeField(blank=True, null=True, db_default=(5, 10))
bigints = BigIntegerRangeField(blank=True, null=True) bigints = BigIntegerRangeField(blank=True, null=True)
decimals = DecimalRangeField(blank=True, null=True) decimals = DecimalRangeField(blank=True, null=True)
timestamps = DateTimeRangeField(blank=True, null=True) timestamps = DateTimeRangeField(blank=True, null=True)

View File

@ -1213,3 +1213,12 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
constraint_name, constraint_name,
self.get_constraints(ModelWithExclusionConstraint._meta.db_table), self.get_constraints(ModelWithExclusionConstraint._meta.db_table),
) )
def test_database_default(self):
constraint = ExclusionConstraint(
name="ints_equal", expressions=[("ints", RangeOperators.EQUAL)]
)
RangesModel.objects.create()
msg = "Constraint “ints_equal” is violated."
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(RangesModel, RangesModel())

View File

@ -48,7 +48,7 @@ class ModelToValidate(models.Model):
class UniqueFieldsModel(models.Model): class UniqueFieldsModel(models.Model):
unique_charfield = models.CharField(max_length=100, unique=True) unique_charfield = models.CharField(max_length=100, unique=True)
unique_integerfield = models.IntegerField(unique=True) unique_integerfield = models.IntegerField(unique=True, db_default=42)
non_unique_field = models.IntegerField() non_unique_field = models.IntegerField()

View File

@ -146,6 +146,20 @@ class PerformUniqueChecksTest(TestCase):
mtv = ModelToValidate(number=10, name="Some Name") mtv = ModelToValidate(number=10, name="Some Name")
mtv.full_clean() mtv.full_clean()
def test_unique_db_default(self):
UniqueFieldsModel.objects.create(unique_charfield="foo", non_unique_field=42)
um = UniqueFieldsModel(unique_charfield="bar", non_unique_field=42)
with self.assertRaises(ValidationError) as cm:
um.full_clean()
self.assertEqual(
cm.exception.message_dict,
{
"unique_integerfield": [
"Unique fields model with this Unique integerfield already exists."
]
},
)
def test_unique_for_date(self): def test_unique_for_date(self):
Post.objects.create( Post.objects.create(
title="Django 1.0 is released", title="Django 1.0 is released",