1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #26430 -- Fixed coalesced aggregation of empty result sets.

Disable the EmptyResultSet optimization when performing aggregation as
it might interfere with coalescence.
This commit is contained in:
Simon Charette 2021-05-21 21:48:46 -04:00 committed by Mariusz Felisiak
parent fde6fb2898
commit f3112fde98
3 changed files with 48 additions and 9 deletions

View File

@ -26,10 +26,13 @@ class SQLCompiler:
re.MULTILINE | re.DOTALL, re.MULTILINE | re.DOTALL,
) )
def __init__(self, query, connection, using): def __init__(self, query, connection, using, elide_empty=True):
self.query = query self.query = query
self.connection = connection self.connection = connection
self.using = using self.using = using
# Some queries, e.g. coalesced aggregation, need to be executed even if
# they would return an empty result set.
self.elide_empty = elide_empty
self.quote_cache = {'*': '*'} self.quote_cache = {'*': '*'}
# The select, klass_info, and annotations are needed by QuerySet.iterator() # The select, klass_info, and annotations are needed by QuerySet.iterator()
# these are set as a side-effect of executing the query. Note that we calculate # these are set as a side-effect of executing the query. Note that we calculate
@ -458,7 +461,7 @@ class SQLCompiler:
def get_combinator_sql(self, combinator, all): def get_combinator_sql(self, combinator, all):
features = self.connection.features features = self.connection.features
compilers = [ compilers = [
query.get_compiler(self.using, self.connection) query.get_compiler(self.using, self.connection, self.elide_empty)
for query in self.query.combined_queries if not query.is_empty() for query in self.query.combined_queries if not query.is_empty()
] ]
if not features.supports_slicing_ordering_in_compound: if not features.supports_slicing_ordering_in_compound:
@ -535,7 +538,13 @@ class SQLCompiler:
# This must come after 'select', 'ordering', and 'distinct' # This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details). # (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause() from_, f_params = self.get_from_clause()
where, w_params = self.compile(self.where) if self.where is not None else ("", []) try:
where, w_params = self.compile(self.where) if self.where is not None else ('', [])
except EmptyResultSet:
if self.elide_empty:
raise
# Use a predicate that's always False.
where, w_params = '0 = 1', []
having, h_params = self.compile(self.having) if self.having is not None else ("", []) having, h_params = self.compile(self.having) if self.having is not None else ("", [])
result = ['SELECT'] result = ['SELECT']
params = [] params = []
@ -1652,7 +1661,7 @@ class SQLAggregateCompiler(SQLCompiler):
params = tuple(params) params = tuple(params)
inner_query_sql, inner_query_params = self.query.inner_query.get_compiler( inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
self.using self.using, elide_empty=self.elide_empty,
).as_sql(with_col_aliases=True) ).as_sql(with_col_aliases=True)
sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql) sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
params = params + inner_query_params params = params + inner_query_params

View File

@ -273,12 +273,12 @@ class Query(BaseExpression):
memo[id(self)] = result memo[id(self)] = result
return result return result
def get_compiler(self, using=None, connection=None): def get_compiler(self, using=None, connection=None, elide_empty=True):
if using is None and connection is None: if using is None and connection is None:
raise ValueError("Need either using or connection") raise ValueError("Need either using or connection")
if using: if using:
connection = connections[using] connection = connections[using]
return connection.ops.compiler(self.compiler)(self, connection, using) return connection.ops.compiler(self.compiler)(self, connection, using, elide_empty)
def get_meta(self): def get_meta(self):
""" """
@ -494,10 +494,8 @@ class Query(BaseExpression):
outer_query.clear_limits() outer_query.clear_limits()
outer_query.select_for_update = False outer_query.select_for_update = False
outer_query.select_related = False outer_query.select_related = False
compiler = outer_query.get_compiler(using) compiler = outer_query.get_compiler(using, elide_empty=False)
result = compiler.execute_sql(SINGLE) result = compiler.execute_sql(SINGLE)
if result is None:
result = [None] * len(outer_query.annotation_select)
converters = compiler.get_converters(outer_query.annotation_select.values()) converters = compiler.get_converters(outer_query.annotation_select.values())
result = next(compiler.apply_converters((result,), converters)) result = next(compiler.apply_converters((result,), converters))

View File

@ -8,6 +8,7 @@ from django.db.models import (
Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField, Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField,
IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When, IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
) )
from django.db.models.expressions import RawSQL
from django.db.models.functions import Coalesce, Greatest from django.db.models.functions import Coalesce, Greatest
from django.test import TestCase from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature from django.test.testcases import skipUnlessDBFeature
@ -1340,3 +1341,34 @@ class AggregateTestCase(TestCase):
('Stuart Russell', 1), ('Stuart Russell', 1),
('Peter Norvig', 2), ('Peter Norvig', 2),
], lambda a: (a.name, a.contact_count), ordered=False) ], lambda a: (a.name, a.contact_count), ordered=False)
def test_coalesced_empty_result_set(self):
self.assertEqual(
Publisher.objects.none().aggregate(
sum_awards=Coalesce(Sum('num_awards'), 0),
)['sum_awards'],
0,
)
# Multiple expressions.
self.assertEqual(
Publisher.objects.none().aggregate(
sum_awards=Coalesce(Sum('num_awards'), None, 0),
)['sum_awards'],
0,
)
# Nested coalesce.
self.assertEqual(
Publisher.objects.none().aggregate(
sum_awards=Coalesce(Coalesce(Sum('num_awards'), None), 0),
)['sum_awards'],
0,
)
# Expression coalesce.
self.assertIsInstance(
Store.objects.none().aggregate(
latest_opening=Coalesce(
Max('original_opening'), RawSQL('CURRENT_TIMESTAMP', []),
),
)['latest_opening'],
datetime.datetime,
)