mirror of
https://github.com/django/django.git
synced 2025-03-31 19:46:42 +00:00
Fixed #31910 -- Fixed crash of GIS aggregations over subqueries.
Regression was introduced by fff5186 but was due a long standing issue. AggregateQuery was abusing Query.subquery: bool by stashing its compiled inner query's SQL for later use in its compiler which made select_format checks for Query.subquery wrongly assume the provide query was a subquery. This patch prevents that from happening by using a dedicated inner_query attribute which is compiled at a later time by SQLAggregateCompiler. Moving the inner query's compilation to SQLAggregateCompiler.compile had the side effect of addressing a long standing issue with aggregation subquery pushdown which prevented converters from being run. This is now fixed as the aggregation_regress adjustments demonstrate. Refs #25367. Thanks Eran Keydar for the report.
This commit is contained in:
parent
789c47e6de
commit
c2d4926702
@ -1596,8 +1596,11 @@ class SQLAggregateCompiler(SQLCompiler):
|
|||||||
sql = ', '.join(sql)
|
sql = ', '.join(sql)
|
||||||
params = tuple(params)
|
params = tuple(params)
|
||||||
|
|
||||||
sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
|
inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
|
||||||
params = params + self.query.sub_params
|
self.using
|
||||||
|
).as_sql(with_col_aliases=True)
|
||||||
|
sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
|
||||||
|
params = params + inner_query_params
|
||||||
return sql, params
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,9 +17,7 @@ from collections.abc import Iterator, Mapping
|
|||||||
from itertools import chain, count, product
|
from itertools import chain, count, product
|
||||||
from string import ascii_uppercase
|
from string import ascii_uppercase
|
||||||
|
|
||||||
from django.core.exceptions import (
|
from django.core.exceptions import FieldDoesNotExist, FieldError
|
||||||
EmptyResultSet, FieldDoesNotExist, FieldError,
|
|
||||||
)
|
|
||||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
|
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
|
||||||
from django.db.models.aggregates import Count
|
from django.db.models.aggregates import Count
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
@ -449,8 +447,9 @@ class Query(BaseExpression):
|
|||||||
if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
|
if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
|
||||||
self.distinct or self.combinator):
|
self.distinct or self.combinator):
|
||||||
from django.db.models.sql.subqueries import AggregateQuery
|
from django.db.models.sql.subqueries import AggregateQuery
|
||||||
outer_query = AggregateQuery(self.model)
|
|
||||||
inner_query = self.clone()
|
inner_query = self.clone()
|
||||||
|
inner_query.subquery = True
|
||||||
|
outer_query = AggregateQuery(self.model, inner_query)
|
||||||
inner_query.select_for_update = False
|
inner_query.select_for_update = False
|
||||||
inner_query.select_related = False
|
inner_query.select_related = False
|
||||||
inner_query.set_annotation_mask(self.annotation_select)
|
inner_query.set_annotation_mask(self.annotation_select)
|
||||||
@ -492,13 +491,6 @@ class Query(BaseExpression):
|
|||||||
# field selected in the inner query, yet we must use a subquery.
|
# field selected in the inner query, yet we must use a subquery.
|
||||||
# So, make sure at least one field is selected.
|
# So, make sure at least one field is selected.
|
||||||
inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
|
inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
|
||||||
try:
|
|
||||||
outer_query.add_subquery(inner_query, using)
|
|
||||||
except EmptyResultSet:
|
|
||||||
return {
|
|
||||||
alias: None
|
|
||||||
for alias in outer_query.annotation_select
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
outer_query = self
|
outer_query = self
|
||||||
self.select = ()
|
self.select = ()
|
||||||
|
@ -157,6 +157,6 @@ class AggregateQuery(Query):
|
|||||||
|
|
||||||
compiler = 'SQLAggregateCompiler'
|
compiler = 'SQLAggregateCompiler'
|
||||||
|
|
||||||
def add_subquery(self, query, using):
|
def __init__(self, model, inner_query):
|
||||||
query.subquery = True
|
self.inner_query = inner_query
|
||||||
self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)
|
super().__init__(model)
|
||||||
|
@ -974,7 +974,7 @@ class AggregationTests(TestCase):
|
|||||||
def test_empty_filter_aggregate(self):
|
def test_empty_filter_aggregate(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")),
|
Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")),
|
||||||
{"pk__count": None}
|
{"pk__count": 0}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_none_call_before_aggregate(self):
|
def test_none_call_before_aggregate(self):
|
||||||
|
@ -12,6 +12,7 @@ from django.core.management import call_command
|
|||||||
from django.db import DatabaseError, NotSupportedError, connection
|
from django.db import DatabaseError, NotSupportedError, connection
|
||||||
from django.db.models import F, OuterRef, Subquery
|
from django.db.models import F, OuterRef, Subquery
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite,
|
mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite,
|
||||||
@ -593,6 +594,19 @@ class GeoQuerySetTest(TestCase):
|
|||||||
qs = City.objects.filter(name='NotACity')
|
qs = City.objects.filter(name='NotACity')
|
||||||
self.assertIsNone(qs.aggregate(Union('point'))['point__union'])
|
self.assertIsNone(qs.aggregate(Union('point'))['point__union'])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_union_aggr')
|
||||||
|
def test_geoagg_subquery(self):
|
||||||
|
ks = State.objects.get(name='Kansas')
|
||||||
|
union = GEOSGeometry('MULTIPOINT(-95.235060 38.971823)')
|
||||||
|
# Use distinct() to force the usage of a subquery for aggregation.
|
||||||
|
with CaptureQueriesContext(connection) as ctx:
|
||||||
|
self.assertIs(union.equals(
|
||||||
|
City.objects.filter(point__within=ks.poly).distinct().aggregate(
|
||||||
|
Union('point'),
|
||||||
|
)['point__union'],
|
||||||
|
), True)
|
||||||
|
self.assertIn('subquery', ctx.captured_queries[0]['sql'])
|
||||||
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
connection.vendor == 'oracle',
|
connection.vendor == 'oracle',
|
||||||
'Oracle supports tolerance parameter.',
|
'Oracle supports tolerance parameter.',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user