mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Fixed #30484 -- Added conditional expressions support to CheckConstraint.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							37e6c5b79b
						
					
				
				
					commit
					e9a0e1d4f6
				
			| @@ -30,6 +30,11 @@ class BaseConstraint: | ||||
| class CheckConstraint(BaseConstraint): | ||||
|     def __init__(self, *, check, name): | ||||
|         self.check = check | ||||
|         if not getattr(check, 'conditional', False): | ||||
|             raise TypeError( | ||||
|                 'CheckConstraint.check must be a Q instance or boolean ' | ||||
|                 'expression.' | ||||
|             ) | ||||
|         super().__init__(name) | ||||
|  | ||||
|     def _get_check_sql(self, model, schema_editor): | ||||
|   | ||||
| @@ -1221,8 +1221,19 @@ class Query(BaseExpression): | ||||
|         """ | ||||
|         if isinstance(filter_expr, dict): | ||||
|             raise FieldError("Cannot parse keyword query as dict") | ||||
|         if isinstance(filter_expr, Q): | ||||
|             return self._add_q( | ||||
|                 filter_expr, | ||||
|                 branch_negated=branch_negated, | ||||
|                 current_negated=current_negated, | ||||
|                 used_aliases=can_reuse, | ||||
|                 allow_joins=allow_joins, | ||||
|                 split_subq=split_subq, | ||||
|             ) | ||||
|         if hasattr(filter_expr, 'resolve_expression') and getattr(filter_expr, 'conditional', False): | ||||
|             condition = self.build_lookup(['exact'], filter_expr.resolve_expression(self), True) | ||||
|             condition = self.build_lookup( | ||||
|                 ['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True | ||||
|             ) | ||||
|             clause = self.where_class() | ||||
|             clause.add(condition, AND) | ||||
|             return clause, [] | ||||
| @@ -1332,8 +1343,8 @@ class Query(BaseExpression): | ||||
|             self.where.add(clause, AND) | ||||
|         self.demote_joins(existing_inner) | ||||
|  | ||||
|     def build_where(self, q_object): | ||||
|         return self._add_q(q_object, used_aliases=set(), allow_joins=False)[0] | ||||
|     def build_where(self, filter_expr): | ||||
|         return self.build_filter(filter_expr, allow_joins=False)[0] | ||||
|  | ||||
|     def _add_q(self, q_object, used_aliases, branch_negated=False, | ||||
|                current_negated=False, allow_joins=True, split_subq=True): | ||||
| @@ -1345,18 +1356,12 @@ class Query(BaseExpression): | ||||
|                                          negated=q_object.negated) | ||||
|         joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated) | ||||
|         for child in q_object.children: | ||||
|             if isinstance(child, Node): | ||||
|                 child_clause, needed_inner = self._add_q( | ||||
|                     child, used_aliases, branch_negated, | ||||
|                     current_negated, allow_joins, split_subq) | ||||
|                 joinpromoter.add_votes(needed_inner) | ||||
|             else: | ||||
|                 child_clause, needed_inner = self.build_filter( | ||||
|                     child, can_reuse=used_aliases, branch_negated=branch_negated, | ||||
|                     current_negated=current_negated, allow_joins=allow_joins, | ||||
|                     split_subq=split_subq, | ||||
|                 ) | ||||
|                 joinpromoter.add_votes(needed_inner) | ||||
|             child_clause, needed_inner = self.build_filter( | ||||
|                 child, can_reuse=used_aliases, branch_negated=branch_negated, | ||||
|                 current_negated=current_negated, allow_joins=allow_joins, | ||||
|                 split_subq=split_subq, | ||||
|             ) | ||||
|             joinpromoter.add_votes(needed_inner) | ||||
|             if child_clause: | ||||
|                 target_clause.add(child_clause, connector) | ||||
|         needed_inner = joinpromoter.update_join_types(self) | ||||
|   | ||||
| @@ -52,12 +52,16 @@ option. | ||||
|  | ||||
| .. attribute:: CheckConstraint.check | ||||
|  | ||||
| A :class:`Q` object that specifies the check you want the constraint to | ||||
| enforce. | ||||
| A :class:`Q` object or boolean :class:`~django.db.models.Expression` that | ||||
| specifies the check you want the constraint to enforce. | ||||
|  | ||||
| For example, ``CheckConstraint(check=Q(age__gte=18), name='age_gte_18')`` | ||||
| ensures the age field is never less than 18. | ||||
|  | ||||
| .. versionchanged:: 3.1 | ||||
|  | ||||
|     Support for boolean :class:`~django.db.models.Expression` was added. | ||||
|  | ||||
| ``name`` | ||||
| -------- | ||||
|  | ||||
|   | ||||
| @@ -204,6 +204,8 @@ Models | ||||
|   ``OneToOneField`` emulates the behavior of the SQL constraint ``ON DELETE | ||||
|   RESTRICT``. | ||||
|  | ||||
| * :attr:`.CheckConstraint.check` now supports boolean expressions. | ||||
|  | ||||
| Pagination | ||||
| ~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -18,6 +18,19 @@ class Product(models.Model): | ||||
|                 check=models.Q(price__gt=0), | ||||
|                 name='%(app_label)s_%(class)s_price_gt_0', | ||||
|             ), | ||||
|             models.CheckConstraint( | ||||
|                 check=models.expressions.RawSQL( | ||||
|                     'price < %s', (1000,), output_field=models.BooleanField() | ||||
|                 ), | ||||
|                 name='%(app_label)s_price_lt_1000_raw', | ||||
|             ), | ||||
|             models.CheckConstraint( | ||||
|                 check=models.expressions.ExpressionWrapper( | ||||
|                     models.Q(price__gt=500) | models.Q(price__lt=500), | ||||
|                     output_field=models.BooleanField() | ||||
|                 ), | ||||
|                 name='%(app_label)s_price_neq_500_wrap', | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -61,6 +61,13 @@ class CheckConstraintTests(TestCase): | ||||
|             "<CheckConstraint: check='{}' name='{}'>".format(check, name), | ||||
|         ) | ||||
|  | ||||
|     def test_invalid_check_types(self): | ||||
|         msg = ( | ||||
|             'CheckConstraint.check must be a Q instance or boolean expression.' | ||||
|         ) | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             models.CheckConstraint(check=models.F('discounted_price'), name='check') | ||||
|  | ||||
|     def test_deconstruction(self): | ||||
|         check = models.Q(price__gt=models.F('discounted_price')) | ||||
|         name = 'price_gt_discounted_price' | ||||
| @@ -76,11 +83,25 @@ class CheckConstraintTests(TestCase): | ||||
|         with self.assertRaises(IntegrityError): | ||||
|             Product.objects.create(price=10, discounted_price=20) | ||||
|  | ||||
|     @skipUnlessDBFeature('supports_table_check_constraints') | ||||
|     def test_database_constraint_expression(self): | ||||
|         Product.objects.create(price=999, discounted_price=5) | ||||
|         with self.assertRaises(IntegrityError): | ||||
|             Product.objects.create(price=1000, discounted_price=5) | ||||
|  | ||||
|     @skipUnlessDBFeature('supports_table_check_constraints') | ||||
|     def test_database_constraint_expressionwrapper(self): | ||||
|         Product.objects.create(price=499, discounted_price=5) | ||||
|         with self.assertRaises(IntegrityError): | ||||
|             Product.objects.create(price=500, discounted_price=5) | ||||
|  | ||||
|     @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints') | ||||
|     def test_name(self): | ||||
|         constraints = get_constraints(Product._meta.db_table) | ||||
|         for expected_name in ( | ||||
|             'price_gt_discounted_price', | ||||
|             'constraints_price_lt_1000_raw', | ||||
|             'constraints_price_neq_500_wrap', | ||||
|             'constraints_product_price_gt_0', | ||||
|         ): | ||||
|             with self.subTest(expected_name): | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| from datetime import datetime | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db.models import CharField, F, Q | ||||
| from django.db.models.expressions import Col | ||||
| from django.db.models import BooleanField, CharField, F, Q | ||||
| from django.db.models.expressions import Col, Func | ||||
| from django.db.models.fields.related_lookups import RelatedIsNull | ||||
| from django.db.models.functions import Lower | ||||
| from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan | ||||
| @@ -129,3 +129,18 @@ class TestQuery(SimpleTestCase): | ||||
|         name_exact = where.children[0] | ||||
|         self.assertIsInstance(name_exact, Exact) | ||||
|         self.assertEqual(name_exact.rhs, "['a', 'b']") | ||||
|  | ||||
|     def test_filter_conditional(self): | ||||
|         query = Query(Item) | ||||
|         where = query.build_where(Func(output_field=BooleanField())) | ||||
|         exact = where.children[0] | ||||
|         self.assertIsInstance(exact, Exact) | ||||
|         self.assertIsInstance(exact.lhs, Func) | ||||
|         self.assertIs(exact.rhs, True) | ||||
|  | ||||
|     def test_filter_conditional_join(self): | ||||
|         query = Query(Item) | ||||
|         filter_expr = Func('note__note', output_field=BooleanField()) | ||||
|         msg = 'Joined field references are not permitted in this query' | ||||
|         with self.assertRaisesMessage(FieldError, msg): | ||||
|             query.build_where(filter_expr) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user