From 30a01441347d5a2146af2944b29778fa0834d4be Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Mon, 17 Jan 2022 18:01:07 +0100 Subject: [PATCH] Fixed #29338 -- Allowed using combined queryset in Subquery. Thanks Eugene Kovalev for the initial patch, Simon Charette for the review, and Chetan Khanna for help. --- django/db/models/sql/compiler.py | 9 +++++-- django/db/models/sql/query.py | 6 +++++ tests/queries/test_qs_combinators.py | 37 ++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 69a2d9298f..928ab40254 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -503,7 +503,10 @@ class SQLCompiler: part_sql = 'SELECT * FROM ({})'.format(part_sql) # Add parentheses when combining with compound query if not # already added for all compound queries. - elif not features.supports_slicing_ordering_in_compound: + elif ( + self.query.subquery or + not features.supports_slicing_ordering_in_compound + ): part_sql = '({})'.format(part_sql) parts += ((part_sql, part_args),) except EmptyResultSet: @@ -517,7 +520,9 @@ class SQLCompiler: combinator_sql = self.connection.ops.set_operators[combinator] if all and combinator == 'union': combinator_sql += ' ALL' - braces = '({})' if features.supports_slicing_ordering_in_compound else '{}' + braces = '{}' + if not self.query.subquery and features.supports_slicing_ordering_in_compound: + braces = '({})' sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts)) result = [' {} '.format(combinator_sql).join(sql_parts)] params = [] diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b13c7b6893..e3fdea6f3a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1048,6 +1048,12 @@ class Query(BaseExpression): clone.bump_prefix(query) clone.subquery = True clone.where.resolve_expression(query, *args, **kwargs) + # Resolve combined queries. + if clone.combinator: + clone.combined_queries = tuple([ + combined_query.resolve_expression(query, *args, **kwargs) + for combined_query in clone.combined_queries + ]) for key, value in clone.annotations.items(): resolved = value.resolve_expression(query, *args, **kwargs) if hasattr(resolved, 'external_aliases'): diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py index 89d08ef6db..37ce37c2c1 100644 --- a/tests/queries/test_qs_combinators.py +++ b/tests/queries/test_qs_combinators.py @@ -1,11 +1,11 @@ import operator from django.db import DatabaseError, NotSupportedError, connection -from django.db.models import Exists, F, IntegerField, OuterRef, Value +from django.db.models import Exists, F, IntegerField, OuterRef, Subquery, Value from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext -from .models import Celebrity, Number, ReservedName +from .models import Author, Celebrity, ExtraInfo, Number, ReservedName @skipUnlessDBFeature('supports_select_union') @@ -252,6 +252,39 @@ class QuerySetSetOperationTests(TestCase): [reserved_name.pk], ) + def test_union_in_subquery(self): + ReservedName.objects.bulk_create([ + ReservedName(name='rn1', order=8), + ReservedName(name='rn2', order=1), + ReservedName(name='rn3', order=5), + ]) + qs1 = Number.objects.filter(num__gt=7, num=OuterRef('order')) + qs2 = Number.objects.filter(num__lt=2, num=OuterRef('order')) + self.assertCountEqual( + ReservedName.objects.annotate( + number=Subquery(qs1.union(qs2).values('num')), + ).filter(number__isnull=False).values_list('order', flat=True), + [8, 1], + ) + + def test_union_in_subquery_related_outerref(self): + e1 = ExtraInfo.objects.create(value=7, info='e3') + e2 = ExtraInfo.objects.create(value=5, info='e2') + e3 = ExtraInfo.objects.create(value=1, info='e1') + Author.objects.bulk_create([ + Author(name='a1', num=1, extra=e1), + Author(name='a2', num=3, extra=e2), + Author(name='a3', num=2, extra=e3), + ]) + qs1 = ExtraInfo.objects.order_by().filter(value=OuterRef('num')) + qs2 = ExtraInfo.objects.order_by().filter(value__lt=OuterRef('extra__value')) + qs = Author.objects.annotate( + info=Subquery(qs1.union(qs2).values('info')[:1]), + ).filter(info__isnull=False).values_list('name', flat=True) + self.assertCountEqual(qs, ['a1', 'a2']) + # Combined queries don't mutate. + self.assertCountEqual(qs, ['a1', 'a2']) + def test_count_union(self): qs1 = Number.objects.filter(num__lte=1).values('num') qs2 = Number.objects.filter(num__gte=2, num__lte=3).values('num')