From 76e37513e22f4d9a01c7f15eee36fe44388e6670 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sun, 6 Nov 2022 11:19:33 -0500 Subject: [PATCH] Refs #33374 -- Adjusted full match condition handling. Adjusting WhereNode.as_sql() to raise an exception when encoutering a full match just like with empty matches ensures that all case are explicitly handled. --- django/core/exceptions.py | 6 +++++ django/db/backends/mysql/compiler.py | 14 ++++++---- django/db/models/aggregates.py | 9 ++++--- django/db/models/expressions.py | 17 +++++------- django/db/models/fields/__init__.py | 9 ------- django/db/models/sql/compiler.py | 37 +++++++++++++++++--------- django/db/models/sql/datastructures.py | 8 ++++-- django/db/models/sql/where.py | 25 +++++++++-------- docs/ref/exceptions.txt | 11 ++++++++ tests/annotations/tests.py | 19 +++++++++++-- tests/queries/tests.py | 20 +++++++++----- 11 files changed, 114 insertions(+), 61 deletions(-) diff --git a/django/core/exceptions.py b/django/core/exceptions.py index 7be4e16bc5..646644f3e0 100644 --- a/django/core/exceptions.py +++ b/django/core/exceptions.py @@ -233,6 +233,12 @@ class EmptyResultSet(Exception): pass +class FullResultSet(Exception): + """A database query predicate is matches everything.""" + + pass + + class SynchronousOnlyOperation(Exception): """The user tried to call a sync-only function from an async context.""" diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index bd2715fb43..2ec6bea2f1 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -1,4 +1,4 @@ -from django.core.exceptions import FieldError +from django.core.exceptions import FieldError, FullResultSet from django.db.models.expressions import Col from django.db.models.sql import compiler @@ -40,12 +40,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): "DELETE %s FROM" % self.quote_name_unless_alias(self.query.get_initial_alias()) ] - from_sql, from_params = self.get_from_clause() + from_sql, params = self.get_from_clause() result.extend(from_sql) - where_sql, where_params = self.compile(where) - if where_sql: + try: + where_sql, where_params = self.compile(where) + except FullResultSet: + pass + else: result.append("WHERE %s" % where_sql) - return " ".join(result), tuple(from_params) + tuple(where_params) + params.extend(where_params) + return " ".join(result), tuple(params) class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index ab38a33bf0..7878fb6fb2 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -1,7 +1,7 @@ """ Classes to represent the definitions of aggregate functions. """ -from django.core.exceptions import FieldError +from django.core.exceptions import FieldError, FullResultSet from django.db.models.expressions import Case, Func, Star, When from django.db.models.fields import IntegerField from django.db.models.functions.comparison import Coalesce @@ -104,8 +104,11 @@ class Aggregate(Func): extra_context["distinct"] = "DISTINCT " if self.distinct else "" if self.filter: if connection.features.supports_aggregate_filter_clause: - filter_sql, filter_params = self.filter.as_sql(compiler, connection) - if filter_sql: + try: + filter_sql, filter_params = self.filter.as_sql(compiler, connection) + except FullResultSet: + pass + else: template = self.filter_template % extra_context.get( "template", self.template ) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 8b04e1f11b..86a3a92f07 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -7,7 +7,7 @@ from collections import defaultdict from decimal import Decimal from uuid import UUID -from django.core.exceptions import EmptyResultSet, FieldError +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError, connection from django.db.models import fields from django.db.models.constants import LOOKUP_SEP @@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression): if empty_result_set_value is NotImplemented: raise arg_sql, arg_params = compiler.compile(Value(empty_result_set_value)) + except FullResultSet: + arg_sql, arg_params = compiler.compile(Value(True)) sql_parts.append(arg_sql) params.extend(arg_params) data = {**self.extra, **extra_context} @@ -1367,14 +1369,6 @@ class When(Expression): template_params = extra_context sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) - # Filters that match everything are handled as empty strings in the - # WHERE clause, but in a CASE WHEN expression they must use a predicate - # that's always True. - if condition_sql == "": - if connection.features.supports_boolean_expr_in_select_clause: - condition_sql, condition_params = compiler.compile(Value(True)) - else: - condition_sql, condition_params = "1=1", () template_params["condition"] = condition_sql result_sql, result_params = compiler.compile(self.result) template_params["result"] = result_sql @@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression): template_params = {**self.extra, **extra_context} case_parts = [] sql_params = [] + default_sql, default_params = compiler.compile(self.default) for case in self.cases: try: case_sql, case_params = compiler.compile(case) except EmptyResultSet: continue + except FullResultSet: + default_sql, default_params = compiler.compile(case.result) + break case_parts.append(case_sql) sql_params.extend(case_params) - default_sql, default_params = compiler.compile(self.default) if not case_parts: return default_sql, default_params case_joiner = case_joiner or self.case_joiner diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 5069a491e8..2a98396cad 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -1103,15 +1103,6 @@ class BooleanField(Field): defaults = {"form_class": form_class, "required": False} return super().formfield(**{**defaults, **kwargs}) - def select_format(self, compiler, sql, params): - sql, params = super().select_format(compiler, sql, params) - # Filters that match everything are handled as empty strings in the - # WHERE clause, but in SELECT or GROUP BY list they must use a - # predicate that's always True. - if sql == "": - sql = "1" - return sql, params - class CharField(Field): description = _("String (up to %(max_length)s)") diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 97c7ba2013..170bde1d42 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -4,7 +4,7 @@ import re from functools import partial from itertools import chain -from django.core.exceptions import EmptyResultSet, FieldError +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value @@ -169,7 +169,7 @@ class SQLCompiler: expr = Ref(alias, expr) try: sql, params = self.compile(expr) - except EmptyResultSet: + except (EmptyResultSet, FullResultSet): continue sql, params = expr.select_format(self, sql, params) params_hash = make_hashable(params) @@ -287,6 +287,8 @@ class SQLCompiler: sql, params = "0", () else: sql, params = self.compile(Value(empty_result_set_value)) + except FullResultSet: + sql, params = self.compile(Value(True)) else: sql, params = col.select_format(self, sql, params) if alias is None and with_col_aliases: @@ -721,9 +723,16 @@ class SQLCompiler: raise # Use a predicate that's always False. where, w_params = "0 = 1", [] - having, h_params = ( - self.compile(self.having) if self.having is not None else ("", []) - ) + except FullResultSet: + where, w_params = "", [] + try: + having, h_params = ( + self.compile(self.having) + if self.having is not None + else ("", []) + ) + except FullResultSet: + having, h_params = "", [] result = ["SELECT"] params = [] @@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler): ) def _as_sql(self, query): - result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)] - where, params = self.compile(query.where) - if where: - result.append("WHERE %s" % where) - return " ".join(result), tuple(params) + delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table) + try: + where, params = self.compile(query.where) + except FullResultSet: + return delete, () + return f"{delete} WHERE {where}", tuple(params) def as_sql(self): """ @@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler): "UPDATE %s SET" % qn(table), ", ".join(values), ] - where, params = self.compile(self.query.where) - if where: + try: + where, params = self.compile(self.query.where) + except FullResultSet: + params = [] + else: result.append("WHERE %s" % where) return " ".join(result), tuple(update_params + params) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 1edf040e82..069eb1a301 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -2,6 +2,7 @@ Useful auxiliary data structures for query construction. Not useful outside the SQL domain. """ +from django.core.exceptions import FullResultSet from django.db.models.sql.constants import INNER, LOUTER @@ -100,8 +101,11 @@ class Join: join_conditions.append("(%s)" % extra_sql) params.extend(extra_params) if self.filtered_relation: - extra_sql, extra_params = compiler.compile(self.filtered_relation) - if extra_sql: + try: + extra_sql, extra_params = compiler.compile(self.filtered_relation) + except FullResultSet: + pass + else: join_conditions.append("(%s)" % extra_sql) params.extend(extra_params) if not join_conditions: diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 63fdf58d9d..1928ba91b8 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints. import operator from functools import reduce -from django.core.exceptions import EmptyResultSet +from django.core.exceptions import EmptyResultSet, FullResultSet from django.db.models.expressions import Case, When from django.db.models.lookups import Exact from django.utils import tree @@ -145,6 +145,8 @@ class WhereNode(tree.Node): sql, params = compiler.compile(child) except EmptyResultSet: empty_needed -= 1 + except FullResultSet: + full_needed -= 1 else: if sql: result.append(sql) @@ -158,24 +160,25 @@ class WhereNode(tree.Node): # counts. if empty_needed == 0: if self.negated: - return "", [] + raise FullResultSet else: raise EmptyResultSet if full_needed == 0: if self.negated: raise EmptyResultSet else: - return "", [] + raise FullResultSet conn = " %s " % self.connector sql_string = conn.join(result) - if sql_string: - if self.negated: - # Some backends (Oracle at least) need parentheses - # around the inner SQL in the negated case, even if the - # inner SQL contains just a single expression. - sql_string = "NOT (%s)" % sql_string - elif len(result) > 1 or self.resolved: - sql_string = "(%s)" % sql_string + if not sql_string: + raise FullResultSet + if self.negated: + # Some backends (Oracle at least) need parentheses around the inner + # SQL in the negated case, even if the inner SQL contains just a + # single expression. + sql_string = "NOT (%s)" % sql_string + elif len(result) > 1 or self.resolved: + sql_string = "(%s)" % sql_string return sql_string, result_params def get_group_by_cols(self): diff --git a/docs/ref/exceptions.txt b/docs/ref/exceptions.txt index 2b567414e6..b588d0ee81 100644 --- a/docs/ref/exceptions.txt +++ b/docs/ref/exceptions.txt @@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``. return any results. Most Django projects won't encounter this exception, but it might be useful for implementing custom lookups and expressions. +``FullResultSet`` +----------------- + +.. exception:: FullResultSet + +.. versionadded:: 4.2 + + ``FullResultSet`` may be raised during query generation if a query will + match everything. Most Django projects won't encounter this exception, but + it might be useful for implementing custom lookups and expressions. + ``FieldDoesNotExist`` --------------------- diff --git a/tests/annotations/tests.py b/tests/annotations/tests.py index 52c15bba87..d05af552b4 100644 --- a/tests/annotations/tests.py +++ b/tests/annotations/tests.py @@ -24,7 +24,15 @@ from django.db.models import ( When, ) from django.db.models.expressions import RawSQL -from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim +from django.db.models.functions import ( + Cast, + Coalesce, + ExtractYear, + Floor, + Length, + Lower, + Trim, +) from django.test import TestCase, skipUnlessDBFeature from django.test.utils import register_lookup @@ -282,6 +290,13 @@ class NonAggregateAnnotationTestCase(TestCase): self.assertEqual(len(books), Book.objects.count()) self.assertTrue(all(book.selected for book in books)) + def test_full_expression_wrapped_annotation(self): + books = Book.objects.annotate( + selected=Coalesce(~Q(pk__in=[]), True), + ) + self.assertEqual(len(books), Book.objects.count()) + self.assertTrue(all(book.selected for book in books)) + def test_full_expression_annotation_with_aggregation(self): qs = Book.objects.filter(isbn="159059725").annotate( selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()), @@ -292,7 +307,7 @@ class NonAggregateAnnotationTestCase(TestCase): def test_aggregate_over_full_expression_annotation(self): qs = Book.objects.annotate( selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()), - ).aggregate(Sum("selected")) + ).aggregate(selected__sum=Sum(Cast("selected", IntegerField()))) self.assertEqual(qs["selected__sum"], Book.objects.count()) def test_empty_queryset_annotation(self): diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 63e9ea6687..332a6e5d9a 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -5,7 +5,7 @@ import unittest from operator import attrgetter from threading import Lock -from django.core.exceptions import EmptyResultSet, FieldError +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DEFAULT_DB_ALIAS, connection from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q from django.db.models.expressions import RawSQL @@ -3588,7 +3588,8 @@ class WhereNodeTest(SimpleTestCase): with self.assertRaises(EmptyResultSet): w.as_sql(compiler, connection) w.negate() - self.assertEqual(w.as_sql(compiler, connection), ("", [])) + with self.assertRaises(FullResultSet): + w.as_sql(compiler, connection) w = WhereNode(children=[self.DummyNode(), self.DummyNode()]) self.assertEqual(w.as_sql(compiler, connection), ("(dummy AND dummy)", [])) w.negate() @@ -3597,7 +3598,8 @@ class WhereNodeTest(SimpleTestCase): with self.assertRaises(EmptyResultSet): w.as_sql(compiler, connection) w.negate() - self.assertEqual(w.as_sql(compiler, connection), ("", [])) + with self.assertRaises(FullResultSet): + w.as_sql(compiler, connection) def test_empty_full_handling_disjunction(self): compiler = WhereNodeTest.MockCompiler() @@ -3605,7 +3607,8 @@ class WhereNodeTest(SimpleTestCase): with self.assertRaises(EmptyResultSet): w.as_sql(compiler, connection) w.negate() - self.assertEqual(w.as_sql(compiler, connection), ("", [])) + with self.assertRaises(FullResultSet): + w.as_sql(compiler, connection) w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector=OR) self.assertEqual(w.as_sql(compiler, connection), ("(dummy OR dummy)", [])) w.negate() @@ -3619,7 +3622,8 @@ class WhereNodeTest(SimpleTestCase): compiler = WhereNodeTest.MockCompiler() empty_w = WhereNode() w = WhereNode(children=[empty_w, empty_w]) - self.assertEqual(w.as_sql(compiler, connection), ("", [])) + with self.assertRaises(FullResultSet): + w.as_sql(compiler, connection) w.negate() with self.assertRaises(EmptyResultSet): w.as_sql(compiler, connection) @@ -3627,9 +3631,11 @@ class WhereNodeTest(SimpleTestCase): with self.assertRaises(EmptyResultSet): w.as_sql(compiler, connection) w.negate() - self.assertEqual(w.as_sql(compiler, connection), ("", [])) + with self.assertRaises(FullResultSet): + w.as_sql(compiler, connection) w = WhereNode(children=[empty_w, NothingNode()], connector=OR) - self.assertEqual(w.as_sql(compiler, connection), ("", [])) + with self.assertRaises(FullResultSet): + w.as_sql(compiler, connection) w = WhereNode(children=[empty_w, NothingNode()], connector=AND) with self.assertRaises(EmptyResultSet): w.as_sql(compiler, connection)