From f030236a86a64a4befd3cc8093e2bbeceef52a31 Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Sun, 12 May 2024 20:10:55 +0200 Subject: [PATCH] Fixed #35275 -- Fixed Meta.constraints validation crash on UniqueConstraint with OpClass(). This also introduces Expression.constraint_validation_compatible that allows specifying that expression should be ignored during a constraint validation. --- django/contrib/postgres/constraints.py | 11 ++++------- django/contrib/postgres/indexes.py | 1 + django/db/models/constraints.py | 7 +++---- django/db/models/expressions.py | 17 +++++++++++++++++ docs/ref/models/expressions.txt | 9 +++++++++ docs/releases/5.1.txt | 4 ++++ tests/expressions/tests.py | 10 ++++++++++ tests/postgres_tests/test_constraints.py | 13 +++++++++++++ 8 files changed, 61 insertions(+), 11 deletions(-) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index ff702c53b0..a31f657183 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -1,6 +1,5 @@ from types import NoneType -from django.contrib.postgres.indexes import OpClass from django.core.exceptions import ValidationError from django.db import DEFAULT_DB_ALIAS, NotSupportedError from django.db.backends.ddl_references import Expressions, Statement, Table @@ -208,12 +207,10 @@ class ExclusionConstraint(BaseConstraint): if isinstance(expr, F) and expr.name in exclude: return rhs_expression = expression.replace_expressions(replacements) - # Remove OpClass because it only has sense during the constraint - # creation. - if isinstance(expression, OpClass): - expression = expression.get_source_expressions()[0] - if isinstance(rhs_expression, OpClass): - rhs_expression = rhs_expression.get_source_expressions()[0] + if hasattr(expression, "get_expression_for_validation"): + expression = expression.get_expression_for_validation() + if hasattr(rhs_expression, "get_expression_for_validation"): + rhs_expression = rhs_expression.get_expression_for_validation() lookup = PostgresOperatorLookup(lhs=expression, rhs=rhs_expression) lookup.postgres_operator = operator lookups.append(lookup) diff --git a/django/contrib/postgres/indexes.py b/django/contrib/postgres/indexes.py index cc944ed335..05fdbeed5e 100644 --- a/django/contrib/postgres/indexes.py +++ b/django/contrib/postgres/indexes.py @@ -244,6 +244,7 @@ class SpGistIndex(PostgresIndex): class OpClass(Func): template = "%(expressions)s %(name)s" + constraint_validation_compatible = False def __init__(self, expression, name): super().__init__(expression, name=name) diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 9c63a0940d..3e6c5205c6 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -6,7 +6,7 @@ from django.core import checks from django.core.exceptions import FieldDoesNotExist, FieldError, ValidationError from django.db import connections from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import Exists, ExpressionList, F, OrderBy, RawSQL +from django.db.models.expressions import Exists, ExpressionList, F, RawSQL from django.db.models.indexes import IndexExpression from django.db.models.lookups import Exact from django.db.models.query_utils import Q @@ -644,9 +644,8 @@ class UniqueConstraint(BaseConstraint): } expressions = [] for expr in self.expressions: - # Ignore ordering. - if isinstance(expr, OrderBy): - expr = expr.expression + if hasattr(expr, "get_expression_for_validation"): + expr = expr.get_expression_for_validation() expressions.append(Exact(expr, expr.replace_expressions(replacements))) queryset = queryset.filter(*expressions) model_class_pk = instance._get_pk_val(model._meta) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 6032b4d1f4..4ee22420d9 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -180,6 +180,8 @@ class BaseExpression: window_compatible = False # Can the expression be used as a database default value? allowed_default = False + # Can the expression be used during a constraint validation? + constraint_validation_compatible = True def __init__(self, output_field=None): if output_field is not None: @@ -484,6 +486,20 @@ class BaseExpression: return self.output_field.select_format(compiler, sql, params) return sql, params + def get_expression_for_validation(self): + # Ignore expressions that cannot be used during a constraint validation. + if not getattr(self, "constraint_validation_compatible", True): + try: + (expression,) = self.get_source_expressions() + except ValueError as e: + raise ValueError( + "Expressions with constraint_validation_compatible set to False " + "must have only one source expression." + ) from e + else: + return expression + return self + @deconstructible class Expression(BaseExpression, Combinable): @@ -1716,6 +1732,7 @@ class Exists(Subquery): class OrderBy(Expression): template = "%(expression)s %(ordering)s" conditional = False + constraint_validation_compatible = False def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None): if nulls_first and nulls_last: diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 67baef7dfc..f630142294 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -1058,6 +1058,15 @@ calling the appropriate methods on the wrapped expression. Tells Django that this expression can be used in :attr:`Field.db_default`. Defaults to ``False``. + .. attribute:: constraint_validation_compatible + + .. versionadded:: 5.1 + + Tells Django that this expression can be used during a constraint + validation. Expressions with ``constraint_validation_compatible`` set + to ``False`` must have only one source expression. Defaults to + ``True``. + .. attribute:: contains_aggregate Tells Django that this expression contains an aggregate and that a diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index f2a6bccb0c..f068f3e96b 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -281,6 +281,10 @@ Models reload a model's value. This can be used to lock the row before reloading or to select related objects. +* The new :attr:`.Expression.constraint_validation_compatible` attribute allows + specifying that the expression should be ignored during a constraint + validation. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index f7233305a7..3538900092 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1425,6 +1425,16 @@ class SimpleExpressionTests(SimpleTestCase): hash(Expression(TestModel._meta.get_field("other_field"))), ) + def test_get_expression_for_validation_only_one_source_expression(self): + expression = Expression() + expression.constraint_validation_compatible = False + msg = ( + "Expressions with constraint_validation_compatible set to False must have " + "only one source expression." + ) + with self.assertRaisesMessage(ValueError, msg): + expression.get_expression_for_validation() + class ExpressionsNumericTests(TestCase): @classmethod diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index b3de53efd7..3cc76cdcfe 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -266,6 +266,19 @@ class SchemaTests(PostgreSQLTestCase): self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table)) Scene.objects.create(scene="ScEnE 10", setting="Sir Bedemir's Castle") + def test_opclass_func_validate_constraints(self): + constraint_name = "test_opclass_func_validate_constraints" + constraint = UniqueConstraint( + OpClass(Lower("scene"), name="text_pattern_ops"), + name="test_opclass_func_validate_constraints", + ) + Scene.objects.create(scene="First scene") + # Non-unique scene. + msg = f"Constraint “{constraint_name}” is violated." + with self.assertRaisesMessage(ValidationError, msg): + constraint.validate(Scene, Scene(scene="first Scene")) + constraint.validate(Scene, Scene(scene="second Scene")) + class ExclusionConstraintTests(PostgreSQLTestCase): def get_constraints(self, table):