1
0
mirror of https://github.com/django/django.git synced 2025-06-05 11:39:13 +00:00

Fixed #34338 -- Allowed customizing code of ValidationError in BaseConstraint and subclasses.

This commit is contained in:
Xavier Fernandez 2023-02-14 21:06:45 +01:00 committed by Mariusz Felisiak
parent 51c9bb7cd1
commit 5b3d3e400a
7 changed files with 268 additions and 23 deletions

View File

@ -32,6 +32,7 @@ class ExclusionConstraint(BaseConstraint):
condition=None, condition=None,
deferrable=None, deferrable=None,
include=None, include=None,
violation_error_code=None,
violation_error_message=None, violation_error_message=None,
): ):
if index_type and index_type.lower() not in {"gist", "spgist"}: if index_type and index_type.lower() not in {"gist", "spgist"}:
@ -60,7 +61,11 @@ class ExclusionConstraint(BaseConstraint):
self.condition = condition self.condition = condition
self.deferrable = deferrable self.deferrable = deferrable
self.include = tuple(include) if include else () self.include = tuple(include) if include else ()
super().__init__(name=name, violation_error_message=violation_error_message) super().__init__(
name=name,
violation_error_code=violation_error_code,
violation_error_message=violation_error_message,
)
def _get_expressions(self, schema_editor, query): def _get_expressions(self, schema_editor, query):
expressions = [] expressions = []
@ -149,12 +154,13 @@ class ExclusionConstraint(BaseConstraint):
and self.condition == other.condition and self.condition == other.condition
and self.deferrable == other.deferrable and self.deferrable == other.deferrable
and self.include == other.include and self.include == other.include
and self.violation_error_code == other.violation_error_code
and self.violation_error_message == other.violation_error_message and self.violation_error_message == other.violation_error_message
) )
return super().__eq__(other) return super().__eq__(other)
def __repr__(self): def __repr__(self):
return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % ( return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s%s>" % (
self.__class__.__qualname__, self.__class__.__qualname__,
repr(self.index_type), repr(self.index_type),
repr(self.expressions), repr(self.expressions),
@ -162,6 +168,11 @@ class ExclusionConstraint(BaseConstraint):
"" if self.condition is None else " condition=%s" % self.condition, "" if self.condition is None else " condition=%s" % self.condition,
"" if self.deferrable is None else " deferrable=%r" % self.deferrable, "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
"" if not self.include else " include=%s" % repr(self.include), "" if not self.include else " include=%s" % repr(self.include),
(
""
if self.violation_error_code is None
else " violation_error_code=%r" % self.violation_error_code
),
( (
"" ""
if self.violation_error_message is None if self.violation_error_message is None
@ -204,9 +215,13 @@ class ExclusionConstraint(BaseConstraint):
queryset = queryset.exclude(pk=model_class_pk) queryset = queryset.exclude(pk=model_class_pk)
if not self.condition: if not self.condition:
if queryset.exists(): if queryset.exists():
raise ValidationError(self.get_violation_error_message()) raise ValidationError(
self.get_violation_error_message(), code=self.violation_error_code
)
else: else:
if (self.condition & Exists(queryset.filter(self.condition))).check( if (self.condition & Exists(queryset.filter(self.condition))).check(
replacement_map, using=using replacement_map, using=using
): ):
raise ValidationError(self.get_violation_error_message()) raise ValidationError(
self.get_violation_error_message(), code=self.violation_error_code
)

View File

@ -18,11 +18,16 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"
class BaseConstraint: class BaseConstraint:
default_violation_error_message = _("Constraint “%(name)s” is violated.") default_violation_error_message = _("Constraint “%(name)s” is violated.")
violation_error_code = None
violation_error_message = None violation_error_message = None
# RemovedInDjango60Warning: When the deprecation ends, replace with: # RemovedInDjango60Warning: When the deprecation ends, replace with:
# def __init__(self, *, name, violation_error_message=None): # def __init__(
def __init__(self, *args, name=None, violation_error_message=None): # self, *, name, violation_error_code=None, violation_error_message=None
# ):
def __init__(
self, *args, name=None, violation_error_code=None, violation_error_message=None
):
# RemovedInDjango60Warning. # RemovedInDjango60Warning.
if name is None and not args: if name is None and not args:
raise TypeError( raise TypeError(
@ -30,6 +35,8 @@ class BaseConstraint:
f"argument: 'name'" f"argument: 'name'"
) )
self.name = name self.name = name
if violation_error_code is not None:
self.violation_error_code = violation_error_code
if violation_error_message is not None: if violation_error_message is not None:
self.violation_error_message = violation_error_message self.violation_error_message = violation_error_message
else: else:
@ -74,6 +81,8 @@ class BaseConstraint:
and self.violation_error_message != self.default_violation_error_message and self.violation_error_message != self.default_violation_error_message
): ):
kwargs["violation_error_message"] = self.violation_error_message kwargs["violation_error_message"] = self.violation_error_message
if self.violation_error_code is not None:
kwargs["violation_error_code"] = self.violation_error_code
return (path, (), kwargs) return (path, (), kwargs)
def clone(self): def clone(self):
@ -82,13 +91,19 @@ class BaseConstraint:
class CheckConstraint(BaseConstraint): class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name, violation_error_message=None): def __init__(
self, *, check, name, violation_error_code=None, violation_error_message=None
):
self.check = check self.check = check
if not getattr(check, "conditional", False): if not getattr(check, "conditional", False):
raise TypeError( raise TypeError(
"CheckConstraint.check must be a Q instance or boolean expression." "CheckConstraint.check must be a Q instance or boolean expression."
) )
super().__init__(name=name, violation_error_message=violation_error_message) super().__init__(
name=name,
violation_error_code=violation_error_code,
violation_error_message=violation_error_message,
)
def _get_check_sql(self, model, schema_editor): def _get_check_sql(self, model, schema_editor):
query = Query(model=model, alias_cols=False) query = Query(model=model, alias_cols=False)
@ -112,15 +127,22 @@ class CheckConstraint(BaseConstraint):
against = instance._get_field_value_map(meta=model._meta, exclude=exclude) against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
try: try:
if not Q(self.check).check(against, using=using): if not Q(self.check).check(against, using=using):
raise ValidationError(self.get_violation_error_message()) raise ValidationError(
self.get_violation_error_message(), code=self.violation_error_code
)
except FieldError: except FieldError:
pass pass
def __repr__(self): def __repr__(self):
return "<%s: check=%s name=%s%s>" % ( return "<%s: check=%s name=%s%s%s>" % (
self.__class__.__qualname__, self.__class__.__qualname__,
self.check, self.check,
repr(self.name), repr(self.name),
(
""
if self.violation_error_code is None
else " violation_error_code=%r" % self.violation_error_code
),
( (
"" ""
if self.violation_error_message is None if self.violation_error_message is None
@ -134,6 +156,7 @@ class CheckConstraint(BaseConstraint):
return ( return (
self.name == other.name self.name == other.name
and self.check == other.check and self.check == other.check
and self.violation_error_code == other.violation_error_code
and self.violation_error_message == other.violation_error_message and self.violation_error_message == other.violation_error_message
) )
return super().__eq__(other) return super().__eq__(other)
@ -163,6 +186,7 @@ class UniqueConstraint(BaseConstraint):
deferrable=None, deferrable=None,
include=None, include=None,
opclasses=(), opclasses=(),
violation_error_code=None,
violation_error_message=None, violation_error_message=None,
): ):
if not name: if not name:
@ -213,7 +237,11 @@ class UniqueConstraint(BaseConstraint):
F(expression) if isinstance(expression, str) else expression F(expression) if isinstance(expression, str) else expression
for expression in expressions for expression in expressions
) )
super().__init__(name=name, violation_error_message=violation_error_message) super().__init__(
name=name,
violation_error_code=violation_error_code,
violation_error_message=violation_error_message,
)
@property @property
def contains_expressions(self): def contains_expressions(self):
@ -293,7 +321,7 @@ class UniqueConstraint(BaseConstraint):
) )
def __repr__(self): def __repr__(self):
return "<%s:%s%s%s%s%s%s%s%s>" % ( return "<%s:%s%s%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__, self.__class__.__qualname__,
"" if not self.fields else " fields=%s" % repr(self.fields), "" if not self.fields else " fields=%s" % repr(self.fields),
"" if not self.expressions else " expressions=%s" % repr(self.expressions), "" if not self.expressions else " expressions=%s" % repr(self.expressions),
@ -302,6 +330,11 @@ class UniqueConstraint(BaseConstraint):
"" if self.deferrable is None else " deferrable=%r" % self.deferrable, "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
"" if not self.include else " include=%s" % repr(self.include), "" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
(
""
if self.violation_error_code is None
else " violation_error_code=%r" % self.violation_error_code
),
( (
"" ""
if self.violation_error_message is None if self.violation_error_message is None
@ -320,6 +353,7 @@ class UniqueConstraint(BaseConstraint):
and self.include == other.include and self.include == other.include
and self.opclasses == other.opclasses and self.opclasses == other.opclasses
and self.expressions == other.expressions and self.expressions == other.expressions
and self.violation_error_code == other.violation_error_code
and self.violation_error_message == other.violation_error_message and self.violation_error_message == other.violation_error_message
) )
return super().__eq__(other) return super().__eq__(other)
@ -385,14 +419,17 @@ class UniqueConstraint(BaseConstraint):
if not self.condition: if not self.condition:
if queryset.exists(): if queryset.exists():
if self.expressions: if self.expressions:
raise ValidationError(self.get_violation_error_message()) raise ValidationError(
self.get_violation_error_message(),
code=self.violation_error_code,
)
# When fields are defined, use the unique_error_message() for # When fields are defined, use the unique_error_message() for
# backward compatibility. # backward compatibility.
for model, constraints in instance.get_constraints(): for model, constraints in instance.get_constraints():
for constraint in constraints: for constraint in constraints:
if constraint is self: if constraint is self:
raise ValidationError( raise ValidationError(
instance.unique_error_message(model, self.fields) instance.unique_error_message(model, self.fields),
) )
else: else:
against = instance._get_field_value_map(meta=model._meta, exclude=exclude) against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
@ -400,6 +437,9 @@ class UniqueConstraint(BaseConstraint):
if (self.condition & Exists(queryset.filter(self.condition))).check( if (self.condition & Exists(queryset.filter(self.condition))).check(
against, using=using against, using=using
): ):
raise ValidationError(self.get_violation_error_message()) raise ValidationError(
self.get_violation_error_message(),
code=self.violation_error_code,
)
except FieldError: except FieldError:
pass pass

View File

@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
``ExclusionConstraint`` ``ExclusionConstraint``
======================= =======================
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_message=None) .. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_code=None, violation_error_message=None)
Creates an exclusion constraint in the database. Internally, PostgreSQL Creates an exclusion constraint in the database. Internally, PostgreSQL
implements exclusion constraints using indexes. The default index type is implements exclusion constraints using indexes. The default index type is
@ -133,6 +133,16 @@ used for queries that select only included fields
``include`` is supported for GiST indexes. PostgreSQL 14+ also supports ``include`` is supported for GiST indexes. PostgreSQL 14+ also supports
``include`` for SP-GiST indexes. ``include`` for SP-GiST indexes.
``violation_error_code``
------------------------
.. versionadded:: 5.0
.. attribute:: ExclusionConstraint.violation_error_code
The error code used when ``ValidationError`` is raised during
:ref:`model validation <validating-objects>`. Defaults to ``None``.
``violation_error_message`` ``violation_error_message``
--------------------------- ---------------------------

View File

@ -48,7 +48,7 @@ option.
``BaseConstraint`` ``BaseConstraint``
================== ==================
.. class:: BaseConstraint(*, name, violation_error_message=None) .. class:: BaseConstraint(* name, violation_error_code=None, violation_error_message=None)
Base class for all constraints. Subclasses must implement Base class for all constraints. Subclasses must implement
``constraint_sql()``, ``create_sql()``, ``remove_sql()`` and ``constraint_sql()``, ``create_sql()``, ``remove_sql()`` and
@ -68,6 +68,16 @@ All constraints have the following parameters in common:
The name of the constraint. You must always specify a unique name for the The name of the constraint. You must always specify a unique name for the
constraint. constraint.
``violation_error_code``
------------------------
.. versionadded:: 5.0
.. attribute:: BaseConstraint.violation_error_code
The error code used when ``ValidationError`` is raised during
:ref:`model validation <validating-objects>`. Defaults to ``None``.
``violation_error_message`` ``violation_error_message``
--------------------------- ---------------------------
@ -94,7 +104,7 @@ This method must be implemented by a subclass.
``CheckConstraint`` ``CheckConstraint``
=================== ===================
.. class:: CheckConstraint(*, check, name, violation_error_message=None) .. class:: CheckConstraint(*, check, name, violation_error_code=None, violation_error_message=None)
Creates a check constraint in the database. Creates a check constraint in the database.
@ -121,7 +131,7 @@ ensures the age field is never less than 18.
``UniqueConstraint`` ``UniqueConstraint``
==================== ====================
.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_message=None) .. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_code=None, violation_error_message=None)
Creates a unique constraint in the database. Creates a unique constraint in the database.
@ -242,6 +252,22 @@ creates a unique index on ``username`` using ``varchar_pattern_ops``.
``opclasses`` are ignored for databases besides PostgreSQL. ``opclasses`` are ignored for databases besides PostgreSQL.
``violation_error_code``
------------------------
.. versionadded:: 5.0
.. attribute:: UniqueConstraint.violation_error_code
The error code used when ``ValidationError`` is raised during
:ref:`model validation <validating-objects>`. Defaults to ``None``.
This code is *not used* for :class:`UniqueConstraint`\s with
:attr:`~UniqueConstraint.fields` and without a
:attr:`~UniqueConstraint.condition`. Such :class:`~UniqueConstraint`\s have the
same error code as constraints defined with :attr:`.Field.unique` or in
:attr:`Meta.unique_together <django.db.models.Options.constraints>`.
``violation_error_message`` ``violation_error_message``
--------------------------- ---------------------------

View File

@ -78,7 +78,10 @@ Minor features
:mod:`django.contrib.postgres` :mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ... * The new :attr:`~.ExclusionConstraint.violation_error_code` attribute of
:class:`~django.contrib.postgres.constraints.ExclusionConstraint` allows
customizing the ``code`` of ``ValidationError`` raised during
:ref:`model validation <validating-objects>`.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -182,6 +185,13 @@ Models
and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different
field values for the create operation. field values for the create operation.
* The new ``violation_error_code`` attribute of
:class:`~django.db.models.BaseConstraint`,
:class:`~django.db.models.CheckConstraint`, and
:class:`~django.db.models.UniqueConstraint` allows customizing the ``code``
of ``ValidationError`` raised during
:ref:`model validation <validating-objects>`.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -77,17 +77,26 @@ class BaseConstraintTests(SimpleTestCase):
"custom base_name message", "custom base_name message",
) )
def test_custom_violation_code_message(self):
c = BaseConstraint(name="base_name", violation_error_code="custom_code")
self.assertEqual(c.violation_error_code, "custom_code")
def test_deconstruction(self): def test_deconstruction(self):
constraint = BaseConstraint( constraint = BaseConstraint(
name="base_name", name="base_name",
violation_error_message="custom %(name)s message", violation_error_message="custom %(name)s message",
violation_error_code="custom_code",
) )
path, args, kwargs = constraint.deconstruct() path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, "django.db.models.BaseConstraint") self.assertEqual(path, "django.db.models.BaseConstraint")
self.assertEqual(args, ()) self.assertEqual(args, ())
self.assertEqual( self.assertEqual(
kwargs, kwargs,
{"name": "base_name", "violation_error_message": "custom %(name)s message"}, {
"name": "base_name",
"violation_error_message": "custom %(name)s message",
"violation_error_code": "custom_code",
},
) )
def test_deprecation(self): def test_deprecation(self):
@ -148,6 +157,20 @@ class CheckConstraintTests(TestCase):
check=check1, name="price", violation_error_message="custom error" check=check1, name="price", violation_error_message="custom error"
), ),
) )
self.assertNotEqual(
models.CheckConstraint(check=check1, name="price"),
models.CheckConstraint(
check=check1, name="price", violation_error_code="custom_code"
),
)
self.assertEqual(
models.CheckConstraint(
check=check1, name="price", violation_error_code="custom_code"
),
models.CheckConstraint(
check=check1, name="price", violation_error_code="custom_code"
),
)
def test_repr(self): def test_repr(self):
constraint = models.CheckConstraint( constraint = models.CheckConstraint(
@ -172,6 +195,18 @@ class CheckConstraintTests(TestCase):
"violation_error_message='More than 1'>", "violation_error_message='More than 1'>",
) )
def test_repr_with_violation_error_code(self):
constraint = models.CheckConstraint(
check=models.Q(price__lt=1),
name="price_lt_one",
violation_error_code="more_than_one",
)
self.assertEqual(
repr(constraint),
"<CheckConstraint: check=(AND: ('price__lt', 1)) name='price_lt_one' "
"violation_error_code='more_than_one'>",
)
def test_invalid_check_types(self): def test_invalid_check_types(self):
msg = "CheckConstraint.check must be a Q instance or boolean expression." msg = "CheckConstraint.check must be a Q instance or boolean expression."
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
@ -237,6 +272,21 @@ class CheckConstraintTests(TestCase):
# Valid product. # Valid product.
constraint.validate(Product, Product(price=10, discounted_price=5)) constraint.validate(Product, Product(price=10, discounted_price=5))
def test_validate_custom_error(self):
check = models.Q(price__gt=models.F("discounted_price"))
constraint = models.CheckConstraint(
check=check,
name="price",
violation_error_message="discount is fake",
violation_error_code="fake_discount",
)
# Invalid product.
invalid_product = Product(price=10, discounted_price=42)
msg = "discount is fake"
with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(Product, invalid_product)
self.assertEqual(cm.exception.code, "fake_discount")
def test_validate_boolean_expressions(self): def test_validate_boolean_expressions(self):
constraint = models.CheckConstraint( constraint = models.CheckConstraint(
check=models.expressions.ExpressionWrapper( check=models.expressions.ExpressionWrapper(
@ -341,6 +391,30 @@ class UniqueConstraintTests(TestCase):
violation_error_message="custom error", violation_error_message="custom error",
), ),
) )
self.assertNotEqual(
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="custom_error",
),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="other_custom_error",
),
)
self.assertEqual(
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="custom_error",
),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="custom_error",
),
)
def test_eq_with_condition(self): def test_eq_with_condition(self):
self.assertEqual( self.assertEqual(
@ -512,6 +586,20 @@ class UniqueConstraintTests(TestCase):
), ),
) )
def test_repr_with_violation_error_code(self):
constraint = models.UniqueConstraint(
models.F("baz__lower"),
name="unique_lower_baz",
violation_error_code="baz",
)
self.assertEqual(
repr(constraint),
(
"<UniqueConstraint: expressions=(F(baz__lower),) "
"name='unique_lower_baz' violation_error_code='baz'>"
),
)
def test_deconstruction(self): def test_deconstruction(self):
fields = ["foo", "bar"] fields = ["foo", "bar"]
name = "unique_fields" name = "unique_fields"
@ -656,12 +744,16 @@ class UniqueConstraintTests(TestCase):
def test_validate(self): def test_validate(self):
constraint = UniqueConstraintProduct._meta.constraints[0] constraint = UniqueConstraintProduct._meta.constraints[0]
# Custom message and error code are ignored.
constraint.violation_error_message = "Custom message"
constraint.violation_error_code = "custom_code"
msg = "Unique constraint product with this Name and Color already exists." msg = "Unique constraint product with this Name and Color already exists."
non_unique_product = UniqueConstraintProduct( non_unique_product = UniqueConstraintProduct(
name=self.p1.name, color=self.p1.color name=self.p1.name, color=self.p1.color
) )
with self.assertRaisesMessage(ValidationError, msg): with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(UniqueConstraintProduct, non_unique_product) constraint.validate(UniqueConstraintProduct, non_unique_product)
self.assertEqual(cm.exception.code, "unique_together")
# Null values are ignored. # Null values are ignored.
constraint.validate( constraint.validate(
UniqueConstraintProduct, UniqueConstraintProduct,
@ -716,6 +808,20 @@ class UniqueConstraintTests(TestCase):
exclude={"name"}, exclude={"name"},
) )
@skipUnlessDBFeature("supports_partial_indexes")
def test_validate_conditon_custom_error(self):
p1 = UniqueConstraintConditionProduct.objects.create(name="p1")
constraint = UniqueConstraintConditionProduct._meta.constraints[0]
constraint.violation_error_message = "Custom message"
constraint.violation_error_code = "custom_code"
msg = "Custom message"
with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(
UniqueConstraintConditionProduct,
UniqueConstraintConditionProduct(name=p1.name, color=None),
)
self.assertEqual(cm.exception.code, "custom_code")
def test_validate_expression(self): def test_validate_expression(self):
constraint = models.UniqueConstraint(Lower("name"), name="name_lower_uniq") constraint = models.UniqueConstraint(Lower("name"), name="name_lower_uniq")
msg = "Constraint “name_lower_uniq” is violated." msg = "Constraint “name_lower_uniq” is violated."

View File

@ -397,6 +397,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"(F(datespan), '-|-')] name='exclude_overlapping' " "(F(datespan), '-|-')] name='exclude_overlapping' "
"violation_error_message='Overlapping must be excluded'>", "violation_error_message='Overlapping must be excluded'>",
) )
constraint = ExclusionConstraint(
name="exclude_overlapping",
expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
violation_error_code="overlapping_must_be_excluded",
)
self.assertEqual(
repr(constraint),
"<ExclusionConstraint: index_type='GIST' expressions=["
"(F(datespan), '-|-')] name='exclude_overlapping' "
"violation_error_code='overlapping_must_be_excluded'>",
)
def test_eq(self): def test_eq(self):
constraint_1 = ExclusionConstraint( constraint_1 = ExclusionConstraint(
@ -470,6 +481,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
condition=Q(cancelled=False), condition=Q(cancelled=False),
violation_error_message="other custom error", violation_error_message="other custom error",
) )
constraint_12 = ExclusionConstraint(
name="exclude_overlapping",
expressions=[
(F("datespan"), RangeOperators.OVERLAPS),
(F("room"), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
violation_error_code="custom_code",
violation_error_message="other custom error",
)
self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY) self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_2)
@ -483,7 +504,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
self.assertNotEqual(constraint_5, constraint_6) self.assertNotEqual(constraint_5, constraint_6)
self.assertNotEqual(constraint_1, object()) self.assertNotEqual(constraint_1, object())
self.assertNotEqual(constraint_10, constraint_11) self.assertNotEqual(constraint_10, constraint_11)
self.assertNotEqual(constraint_11, constraint_12)
self.assertEqual(constraint_10, constraint_10) self.assertEqual(constraint_10, constraint_10)
self.assertEqual(constraint_12, constraint_12)
def test_deconstruct(self): def test_deconstruct(self):
constraint = ExclusionConstraint( constraint = ExclusionConstraint(
@ -760,17 +783,32 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
constraint = ExclusionConstraint( constraint = ExclusionConstraint(
name="ints_adjacent", name="ints_adjacent",
expressions=[("ints", RangeOperators.ADJACENT_TO)], expressions=[("ints", RangeOperators.ADJACENT_TO)],
violation_error_code="custom_code",
violation_error_message="Custom error message.", violation_error_message="Custom error message.",
) )
range_obj = RangesModel.objects.create(ints=(20, 50)) range_obj = RangesModel.objects.create(ints=(20, 50))
constraint.validate(RangesModel, range_obj) constraint.validate(RangesModel, range_obj)
msg = "Custom error message." msg = "Custom error message."
with self.assertRaisesMessage(ValidationError, msg): with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(RangesModel, RangesModel(ints=(10, 20))) constraint.validate(RangesModel, RangesModel(ints=(10, 20)))
self.assertEqual(cm.exception.code, "custom_code")
constraint.validate(RangesModel, RangesModel(ints=(10, 19))) constraint.validate(RangesModel, RangesModel(ints=(10, 19)))
constraint.validate(RangesModel, RangesModel(ints=(51, 60))) constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"}) constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
def test_validate_with_custom_code_and_condition(self):
constraint = ExclusionConstraint(
name="ints_adjacent",
expressions=[("ints", RangeOperators.ADJACENT_TO)],
violation_error_code="custom_code",
condition=Q(ints__lt=(100, 200)),
)
range_obj = RangesModel.objects.create(ints=(20, 50))
constraint.validate(RangesModel, range_obj)
with self.assertRaises(ValidationError) as cm:
constraint.validate(RangesModel, RangesModel(ints=(10, 20)))
self.assertEqual(cm.exception.code, "custom_code")
def test_expressions_with_params(self): def test_expressions_with_params(self):
constraint_name = "scene_left_equal" constraint_name = "scene_left_equal"
self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table)) self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))