mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #27718 -- Added QuerySet.union(), intersection(), difference().
Thanks Mariusz Felisiak for review and Oracle assistance. Thanks Tim Graham for review and writing docs.
This commit is contained in:
committed by
Tim Graham
parent
611ef422b1
commit
84c1826ded
@@ -221,6 +221,12 @@ class BaseDatabaseFeatures(object):
|
||||
# Place FOR UPDATE right after FROM clause. Used on MSSQL.
|
||||
for_update_after_from = False
|
||||
|
||||
# Combinatorial flags
|
||||
supports_select_union = True
|
||||
supports_select_intersection = True
|
||||
supports_select_difference = True
|
||||
supports_slicing_ordering_in_compound = False
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
|
@@ -29,6 +29,11 @@ class BaseDatabaseOperations(object):
|
||||
'PositiveSmallIntegerField': (0, 32767),
|
||||
'PositiveIntegerField': (0, 2147483647),
|
||||
}
|
||||
set_operators = {
|
||||
'union': 'UNION',
|
||||
'intersection': 'INTERSECT',
|
||||
'difference': 'EXCEPT',
|
||||
}
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
@@ -29,6 +29,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
supports_column_check_constraints = False
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
supports_select_intersection = False
|
||||
supports_select_difference = False
|
||||
supports_slicing_ordering_in_compound = True
|
||||
|
||||
@cached_property
|
||||
def _mysql_storage_engine(self):
|
||||
|
@@ -41,6 +41,10 @@ BEGIN
|
||||
END;
|
||||
/"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DatabaseOperations, self).__init__(*args, **kwargs)
|
||||
self.set_operators['difference'] = 'MINUS'
|
||||
|
||||
def autoinc_sql(self, table, column):
|
||||
# To simulate auto-incrementing primary keys in Oracle, we have to
|
||||
# create a sequence and a trigger.
|
||||
|
@@ -31,6 +31,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
greatest_least_ignores_nulls = True
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
supports_slicing_ordering_in_compound = True
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_skip_locked(self):
|
||||
|
@@ -816,6 +816,33 @@ class QuerySet(object):
|
||||
else:
|
||||
return self._filter_or_exclude(None, **filter_obj)
|
||||
|
||||
def _combinator_query(self, combinator, *other_qs, **kwargs):
|
||||
# Clone the query to inherit the select list and everything
|
||||
clone = self._clone()
|
||||
# Clear limits and ordering so they can be reapplied
|
||||
clone.query.clear_ordering(True)
|
||||
clone.query.clear_limits()
|
||||
clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs)
|
||||
clone.query.combinator = combinator
|
||||
clone.query.combinator_all = kwargs.pop('all', False)
|
||||
return clone
|
||||
|
||||
def union(self, *other_qs, **kwargs):
|
||||
if kwargs:
|
||||
unexpected_kwarg = next((k for k in kwargs.keys() if k != 'all'), None)
|
||||
if unexpected_kwarg:
|
||||
raise TypeError(
|
||||
"union() received an unexpected keyword argument '%s'" %
|
||||
(unexpected_kwarg,)
|
||||
)
|
||||
return self._combinator_query('union', *other_qs, **kwargs)
|
||||
|
||||
def intersection(self, *other_qs):
|
||||
return self._combinator_query('intersection', *other_qs)
|
||||
|
||||
def difference(self, *other_qs):
|
||||
return self._combinator_query('difference', *other_qs)
|
||||
|
||||
def select_for_update(self, nowait=False, skip_locked=False):
|
||||
"""
|
||||
Returns a new QuerySet instance that will select objects with a
|
||||
|
@@ -309,6 +309,21 @@ class SQLCompiler(object):
|
||||
seen = set()
|
||||
|
||||
for expr, is_ref in order_by:
|
||||
if self.query.combinator:
|
||||
src = expr.get_source_expressions()[0]
|
||||
# Relabel order by columns to raw numbers if this is a combined
|
||||
# query; necessary since the columns can't be referenced by the
|
||||
# fully qualified name and the simple column names may collide.
|
||||
for idx, (sel_expr, _, col_alias) in enumerate(self.select):
|
||||
if is_ref and col_alias == src.refs:
|
||||
src = src.source
|
||||
elif col_alias:
|
||||
continue
|
||||
if src == sel_expr:
|
||||
expr.set_source_expressions([RawSQL('%d' % (idx + 1), ())])
|
||||
break
|
||||
else:
|
||||
raise DatabaseError('ORDER BY term does not match any column in the result set.')
|
||||
resolved = expr.resolve_expression(
|
||||
self.query, allow_joins=True, reuse=None)
|
||||
sql, params = self.compile(resolved)
|
||||
@@ -360,6 +375,30 @@ class SQLCompiler(object):
|
||||
return node.output_field.select_format(self, sql, params)
|
||||
return sql, params
|
||||
|
||||
def get_combinator_sql(self, combinator, all):
|
||||
features = self.connection.features
|
||||
compilers = [
|
||||
query.get_compiler(self.using, self.connection)
|
||||
for query in self.query.combined_queries
|
||||
]
|
||||
if not features.supports_slicing_ordering_in_compound:
|
||||
for query, compiler in zip(self.query.combined_queries, compilers):
|
||||
if query.low_mark or query.high_mark:
|
||||
raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
|
||||
if compiler.get_order_by():
|
||||
raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
|
||||
parts = (compiler.as_sql() for compiler in compilers)
|
||||
combinator_sql = self.connection.ops.set_operators[combinator]
|
||||
if all and combinator == 'union':
|
||||
combinator_sql += ' ALL'
|
||||
braces = '({})' if features.supports_slicing_ordering_in_compound else '{}'
|
||||
sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
|
||||
result = [' {} '.format(combinator_sql).join(sql_parts)]
|
||||
params = []
|
||||
for part in args_parts:
|
||||
params.extend(part)
|
||||
return result, params
|
||||
|
||||
def as_sql(self, with_limits=True, with_col_aliases=False):
|
||||
"""
|
||||
Creates the SQL for this query. Returns the SQL string and list of
|
||||
@@ -377,10 +416,19 @@ class SQLCompiler(object):
|
||||
# docstring of get_from_clause() for details.
|
||||
from_, f_params = self.get_from_clause()
|
||||
|
||||
for_update_part = None
|
||||
where, w_params = self.compile(self.where) if self.where is not None else ("", [])
|
||||
having, h_params = self.compile(self.having) if self.having is not None else ("", [])
|
||||
params = []
|
||||
|
||||
combinator = self.query.combinator
|
||||
features = self.connection.features
|
||||
if combinator:
|
||||
if not getattr(features, 'supports_select_{}'.format(combinator)):
|
||||
raise DatabaseError('{} not supported on this database backend.'.format(combinator))
|
||||
result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
|
||||
else:
|
||||
result = ['SELECT']
|
||||
params = []
|
||||
|
||||
if self.query.distinct:
|
||||
result.append(self.connection.ops.distinct_sql(distinct_fields))
|
||||
@@ -402,16 +450,15 @@ class SQLCompiler(object):
|
||||
result.extend(from_)
|
||||
params.extend(f_params)
|
||||
|
||||
for_update_part = None
|
||||
if self.query.select_for_update and self.connection.features.has_select_for_update:
|
||||
if self.connection.get_autocommit():
|
||||
raise TransactionManagementError("select_for_update cannot be used outside of a transaction.")
|
||||
raise TransactionManagementError('select_for_update cannot be used outside of a transaction.')
|
||||
|
||||
nowait = self.query.select_for_update_nowait
|
||||
skip_locked = self.query.select_for_update_skip_locked
|
||||
# If it's a NOWAIT/SKIP LOCKED query but the backend doesn't
|
||||
# support it, raise a DatabaseError to prevent a possible
|
||||
# deadlock.
|
||||
# If it's a NOWAIT/SKIP LOCKED query but the backend
|
||||
# doesn't support it, raise a DatabaseError to prevent a
|
||||
# possible deadlock.
|
||||
if nowait and not self.connection.features.has_select_for_update_nowait:
|
||||
raise DatabaseError('NOWAIT is not supported on this database backend.')
|
||||
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
|
||||
@@ -431,8 +478,7 @@ class SQLCompiler(object):
|
||||
params.extend(g_params)
|
||||
if grouping:
|
||||
if distinct_fields:
|
||||
raise NotImplementedError(
|
||||
"annotate() + distinct(fields) is not implemented.")
|
||||
raise NotImplementedError('annotate() + distinct(fields) is not implemented.')
|
||||
if not order_by:
|
||||
order_by = self.connection.ops.force_no_ordering()
|
||||
result.append('GROUP BY %s' % ', '.join(grouping))
|
||||
|
@@ -186,6 +186,11 @@ class Query(object):
|
||||
self.annotation_select_mask = None
|
||||
self._annotation_select_cache = None
|
||||
|
||||
# Set combination attributes
|
||||
self.combinator = None
|
||||
self.combinator_all = False
|
||||
self.combined_queries = ()
|
||||
|
||||
# These are for extensions. The contents are more or less appended
|
||||
# verbatim to the appropriate clause.
|
||||
# The _extra attribute is an OrderedDict, lazily created similarly to
|
||||
@@ -303,6 +308,9 @@ class Query(object):
|
||||
# used.
|
||||
obj._annotation_select_cache = None
|
||||
obj.max_depth = self.max_depth
|
||||
obj.combinator = self.combinator
|
||||
obj.combinator_all = self.combinator_all
|
||||
obj.combined_queries = self.combined_queries
|
||||
obj._extra = self._extra.copy() if self._extra is not None else None
|
||||
if self.extra_select_mask is None:
|
||||
obj.extra_select_mask = None
|
||||
|
@@ -801,6 +801,61 @@ typically caches its results. If the data in the database might have changed
|
||||
since a ``QuerySet`` was evaluated, you can get updated results for the same
|
||||
query by calling ``all()`` on a previously evaluated ``QuerySet``.
|
||||
|
||||
``union()``
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. method:: union(*other_qs, all=False)
|
||||
|
||||
.. versionadded:: 1.11
|
||||
|
||||
Uses SQL's ``UNION`` operator to combine the results of two or more
|
||||
``QuerySet``\s. For example:
|
||||
|
||||
>>> qs1.union(qs2, qs3)
|
||||
|
||||
The ``UNION`` operator selects only distinct values by default. To allow
|
||||
duplicate values, use the ``all=True`` argument.
|
||||
|
||||
``union()``, ``intersection()``, and ``difference()`` return model instances
|
||||
of the type of the first ``QuerySet`` even if the arguments are ``QuerySet``\s
|
||||
of other models. Passing different models works as long as the ``SELECT`` list
|
||||
is the same in all ``QuerySet``\s (at least the types, the names don't matter
|
||||
as long as the types in the same order).
|
||||
|
||||
In addition, only ``LIMIT``, ``OFFSET``, and ``ORDER BY`` (i.e. slicing and
|
||||
:meth:`order_by`) are allowed on the resulting ``QuerySet``. Further, databases
|
||||
place restrictions on what operations are allowed in the combined queries. For
|
||||
example, most databases don't allow ``LIMIT`` or ``OFFSET`` in the combined
|
||||
queries.
|
||||
|
||||
``intersection()``
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. method:: intersection(*other_qs)
|
||||
|
||||
.. versionadded:: 1.11
|
||||
|
||||
Uses SQL's ``INTERSECT`` operator to return the shared elements of two or more
|
||||
``QuerySet``\s. For example:
|
||||
|
||||
>>> qs1.itersect(qs2, qs3)
|
||||
|
||||
See :meth:`union` for some restrictions.
|
||||
|
||||
``difference()``
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. method:: difference(*other_qs)
|
||||
|
||||
.. versionadded:: 1.11
|
||||
|
||||
Uses SQL's ``EXCEPT`` operator to keep only elements present in the
|
||||
``QuerySet`` but not in some other ``QuerySet``\s. For example::
|
||||
|
||||
>>> qs1.difference(qs2, qs3)
|
||||
|
||||
See :meth:`union` for some restrictions.
|
||||
|
||||
``select_related()``
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@@ -386,6 +386,9 @@ Models
|
||||
* The new ``F`` expression ``bitleftshift()`` and ``bitrightshift()`` methods
|
||||
allow :ref:`bitwise shift operations <using-f-expressions-in-filters>`.
|
||||
|
||||
* Added :meth:`.QuerySet.union`, :meth:`~.QuerySet.intersection`, and
|
||||
:meth:`~.QuerySet.difference`.
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@@ -589,6 +589,9 @@ class ManagerTest(SimpleTestCase):
|
||||
'_insert',
|
||||
'_update',
|
||||
'raw',
|
||||
'union',
|
||||
'intersection',
|
||||
'difference',
|
||||
]
|
||||
|
||||
def test_manager_methods(self):
|
||||
|
111
tests/queries/test_qs_combinators.py
Normal file
111
tests/queries/test_qs_combinators.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db.models import F, IntegerField, Value
|
||||
from django.db.utils import DatabaseError
|
||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.utils.six.moves import range
|
||||
|
||||
from .models import Number, ReservedName
|
||||
|
||||
|
||||
@skipUnlessDBFeature('supports_select_union')
|
||||
class QuerySetSetOperationTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
Number.objects.bulk_create(Number(num=i) for i in range(10))
|
||||
|
||||
def number_transform(self, value):
|
||||
return value.num
|
||||
|
||||
def assertNumbersEqual(self, queryset, expected_numbers, ordered=True):
|
||||
self.assertQuerysetEqual(queryset, expected_numbers, self.number_transform, ordered)
|
||||
|
||||
def test_simple_union(self):
|
||||
qs1 = Number.objects.filter(num__lte=1)
|
||||
qs2 = Number.objects.filter(num__gte=8)
|
||||
qs3 = Number.objects.filter(num=5)
|
||||
self.assertNumbersEqual(qs1.union(qs2, qs3), [0, 1, 5, 8, 9], ordered=False)
|
||||
|
||||
@skipUnlessDBFeature('supports_select_intersection')
|
||||
def test_simple_intersection(self):
|
||||
qs1 = Number.objects.filter(num__lte=5)
|
||||
qs2 = Number.objects.filter(num__gte=5)
|
||||
qs3 = Number.objects.filter(num__gte=4, num__lte=6)
|
||||
self.assertNumbersEqual(qs1.intersection(qs2, qs3), [5], ordered=False)
|
||||
|
||||
@skipUnlessDBFeature('supports_select_difference')
|
||||
def test_simple_difference(self):
|
||||
qs1 = Number.objects.filter(num__lte=5)
|
||||
qs2 = Number.objects.filter(num__lte=4)
|
||||
self.assertNumbersEqual(qs1.difference(qs2), [5], ordered=False)
|
||||
|
||||
def test_union_distinct(self):
|
||||
qs1 = Number.objects.all()
|
||||
qs2 = Number.objects.all()
|
||||
self.assertEqual(len(list(qs1.union(qs2, all=True))), 20)
|
||||
self.assertEqual(len(list(qs1.union(qs2))), 10)
|
||||
|
||||
def test_union_bad_kwarg(self):
|
||||
qs1 = Number.objects.all()
|
||||
msg = "union() received an unexpected keyword argument 'bad'"
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
self.assertEqual(len(list(qs1.union(qs1, bad=True))), 20)
|
||||
|
||||
def test_limits(self):
|
||||
qs1 = Number.objects.all()
|
||||
qs2 = Number.objects.all()
|
||||
self.assertEqual(len(list(qs1.union(qs2)[:2])), 2)
|
||||
|
||||
def test_ordering(self):
|
||||
qs1 = Number.objects.filter(num__lte=1)
|
||||
qs2 = Number.objects.filter(num__gte=2, num__lte=3)
|
||||
self.assertNumbersEqual(qs1.union(qs2).order_by('-num'), [3, 2, 1, 0])
|
||||
|
||||
@skipUnlessDBFeature('supports_slicing_ordering_in_compound')
|
||||
def test_ordering_subqueries(self):
|
||||
qs1 = Number.objects.order_by('num')[:2]
|
||||
qs2 = Number.objects.order_by('-num')[:2]
|
||||
self.assertNumbersEqual(qs1.union(qs2).order_by('-num')[:4], [9, 8, 1, 0])
|
||||
|
||||
@skipIfDBFeature('supports_slicing_ordering_in_compound')
|
||||
def test_unsupported_ordering_slicing_raises_db_error(self):
|
||||
qs1 = Number.objects.all()
|
||||
qs2 = Number.objects.all()
|
||||
msg = 'LIMIT/OFFSET not allowed in subqueries of compound statements'
|
||||
with self.assertRaisesMessage(DatabaseError, msg):
|
||||
list(qs1.union(qs2[:10]))
|
||||
msg = 'ORDER BY not allowed in subqueries of compound statements'
|
||||
with self.assertRaisesMessage(DatabaseError, msg):
|
||||
list(qs1.order_by('id').union(qs2))
|
||||
|
||||
@skipIfDBFeature('supports_select_intersection')
|
||||
def test_unsupported_intersection_raises_db_error(self):
|
||||
qs1 = Number.objects.all()
|
||||
qs2 = Number.objects.all()
|
||||
msg = 'intersection not supported on this database backend'
|
||||
with self.assertRaisesMessage(DatabaseError, msg):
|
||||
list(qs1.intersection(qs2))
|
||||
|
||||
def test_combining_multiple_models(self):
|
||||
ReservedName.objects.create(name='99 little bugs', order=99)
|
||||
qs1 = Number.objects.filter(num=1).values_list('num', flat=True)
|
||||
qs2 = ReservedName.objects.values_list('order')
|
||||
self.assertEqual(list(qs1.union(qs2).order_by('num')), [1, 99])
|
||||
|
||||
def test_order_raises_on_non_selected_column(self):
|
||||
qs1 = Number.objects.filter().annotate(
|
||||
annotation=Value(1, IntegerField()),
|
||||
).values('annotation', num2=F('num'))
|
||||
qs2 = Number.objects.filter().values('id', 'num')
|
||||
# Should not raise
|
||||
list(qs1.union(qs2).order_by('annotation'))
|
||||
list(qs1.union(qs2).order_by('num2'))
|
||||
msg = 'ORDER BY term does not match any column in the result set'
|
||||
# 'id' is not part of the select
|
||||
with self.assertRaisesMessage(DatabaseError, msg):
|
||||
list(qs1.union(qs2).order_by('id'))
|
||||
# 'num' got realiased to num2
|
||||
with self.assertRaisesMessage(DatabaseError, msg):
|
||||
list(qs1.union(qs2).order_by('num'))
|
||||
# switched order, now 'exists' again:
|
||||
list(qs2.union(qs1).order_by('num'))
|
Reference in New Issue
Block a user