From 84c1826ded17b2d74f66717fb745fc36e37949fd Mon Sep 17 00:00:00 2001 From: Florian Apolloner Date: Sat, 14 Jan 2017 14:32:07 +0100 Subject: [PATCH] Fixed #27718 -- Added QuerySet.union(), intersection(), difference(). Thanks Mariusz Felisiak for review and Oracle assistance. Thanks Tim Graham for review and writing docs. --- django/db/backends/base/features.py | 6 + django/db/backends/base/operations.py | 5 + django/db/backends/mysql/features.py | 3 + django/db/backends/oracle/operations.py | 4 + django/db/backends/postgresql/features.py | 1 + django/db/models/query.py | 27 ++++ django/db/models/sql/compiler.py | 148 ++++++++++++++-------- django/db/models/sql/query.py | 8 ++ docs/ref/models/querysets.txt | 55 ++++++++ docs/releases/1.11.txt | 3 + tests/basic/tests.py | 3 + tests/queries/test_qs_combinators.py | 111 ++++++++++++++++ 12 files changed, 323 insertions(+), 51 deletions(-) create mode 100644 tests/queries/test_qs_combinators.py diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index e207b174b5..dfcb719eae 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -221,6 +221,12 @@ class BaseDatabaseFeatures(object): # Place FOR UPDATE right after FROM clause. Used on MSSQL. for_update_after_from = False + # Combinatorial flags + supports_select_union = True + supports_select_intersection = True + supports_select_difference = True + supports_slicing_ordering_in_compound = False + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index a072687ebb..efa1a0b457 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -29,6 +29,11 @@ class BaseDatabaseOperations(object): 'PositiveSmallIntegerField': (0, 32767), 'PositiveIntegerField': (0, 2147483647), } + set_operators = { + 'union': 'UNION', + 'intersection': 'INTERSECT', + 'difference': 'EXCEPT', + } def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index ca8143b875..412887793c 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -29,6 +29,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_column_check_constraints = False can_clone_databases = True supports_temporal_subtraction = True + supports_select_intersection = False + supports_select_difference = False + supports_slicing_ordering_in_compound = True @cached_property def _mysql_storage_engine(self): diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index c75cadc2a6..9c382895b7 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -41,6 +41,10 @@ BEGIN END; /""" + def __init__(self, *args, **kwargs): + super(DatabaseOperations, self).__init__(*args, **kwargs) + self.set_operators['difference'] = 'MINUS' + def autoinc_sql(self, table, column): # To simulate auto-incrementing primary keys in Oracle, we have to # create a sequence and a trigger. diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index d7bf73c09d..d80393d399 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -31,6 +31,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): greatest_least_ignores_nulls = True can_clone_databases = True supports_temporal_subtraction = True + supports_slicing_ordering_in_compound = True @cached_property def has_select_for_update_skip_locked(self): diff --git a/django/db/models/query.py b/django/db/models/query.py index e059c68f13..f9534f3700 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -816,6 +816,33 @@ class QuerySet(object): else: return self._filter_or_exclude(None, **filter_obj) + def _combinator_query(self, combinator, *other_qs, **kwargs): + # Clone the query to inherit the select list and everything + clone = self._clone() + # Clear limits and ordering so they can be reapplied + clone.query.clear_ordering(True) + clone.query.clear_limits() + clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs) + clone.query.combinator = combinator + clone.query.combinator_all = kwargs.pop('all', False) + return clone + + def union(self, *other_qs, **kwargs): + if kwargs: + unexpected_kwarg = next((k for k in kwargs.keys() if k != 'all'), None) + if unexpected_kwarg: + raise TypeError( + "union() received an unexpected keyword argument '%s'" % + (unexpected_kwarg,) + ) + return self._combinator_query('union', *other_qs, **kwargs) + + def intersection(self, *other_qs): + return self._combinator_query('intersection', *other_qs) + + def difference(self, *other_qs): + return self._combinator_query('difference', *other_qs) + def select_for_update(self, nowait=False, skip_locked=False): """ Returns a new QuerySet instance that will select objects with a diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b197ab90cc..37442c06c4 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -309,6 +309,21 @@ class SQLCompiler(object): seen = set() for expr, is_ref in order_by: + if self.query.combinator: + src = expr.get_source_expressions()[0] + # Relabel order by columns to raw numbers if this is a combined + # query; necessary since the columns can't be referenced by the + # fully qualified name and the simple column names may collide. + for idx, (sel_expr, _, col_alias) in enumerate(self.select): + if is_ref and col_alias == src.refs: + src = src.source + elif col_alias: + continue + if src == sel_expr: + expr.set_source_expressions([RawSQL('%d' % (idx + 1), ())]) + break + else: + raise DatabaseError('ORDER BY term does not match any column in the result set.') resolved = expr.resolve_expression( self.query, allow_joins=True, reuse=None) sql, params = self.compile(resolved) @@ -360,6 +375,30 @@ class SQLCompiler(object): return node.output_field.select_format(self, sql, params) return sql, params + def get_combinator_sql(self, combinator, all): + features = self.connection.features + compilers = [ + query.get_compiler(self.using, self.connection) + for query in self.query.combined_queries + ] + if not features.supports_slicing_ordering_in_compound: + for query, compiler in zip(self.query.combined_queries, compilers): + if query.low_mark or query.high_mark: + raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.') + if compiler.get_order_by(): + raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.') + parts = (compiler.as_sql() for compiler in compilers) + combinator_sql = self.connection.ops.set_operators[combinator] + if all and combinator == 'union': + combinator_sql += ' ALL' + braces = '({})' if features.supports_slicing_ordering_in_compound else '{}' + sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts)) + result = [' {} '.format(combinator_sql).join(sql_parts)] + params = [] + for part in args_parts: + params.extend(part) + return result, params + def as_sql(self, with_limits=True, with_col_aliases=False): """ Creates the SQL for this query. Returns the SQL string and list of @@ -377,69 +416,76 @@ class SQLCompiler(object): # docstring of get_from_clause() for details. from_, f_params = self.get_from_clause() + for_update_part = None where, w_params = self.compile(self.where) if self.where is not None else ("", []) having, h_params = self.compile(self.having) if self.having is not None else ("", []) - params = [] - result = ['SELECT'] - if self.query.distinct: - result.append(self.connection.ops.distinct_sql(distinct_fields)) + combinator = self.query.combinator + features = self.connection.features + if combinator: + if not getattr(features, 'supports_select_{}'.format(combinator)): + raise DatabaseError('{} not supported on this database backend.'.format(combinator)) + result, params = self.get_combinator_sql(combinator, self.query.combinator_all) + else: + result = ['SELECT'] + params = [] - out_cols = [] - col_idx = 1 - for _, (s_sql, s_params), alias in self.select + extra_select: - if alias: - s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias)) - elif with_col_aliases: - s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx) - col_idx += 1 - params.extend(s_params) - out_cols.append(s_sql) + if self.query.distinct: + result.append(self.connection.ops.distinct_sql(distinct_fields)) - result.append(', '.join(out_cols)) + out_cols = [] + col_idx = 1 + for _, (s_sql, s_params), alias in self.select + extra_select: + if alias: + s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias)) + elif with_col_aliases: + s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx) + col_idx += 1 + params.extend(s_params) + out_cols.append(s_sql) - result.append('FROM') - result.extend(from_) - params.extend(f_params) + result.append(', '.join(out_cols)) - for_update_part = None - if self.query.select_for_update and self.connection.features.has_select_for_update: - if self.connection.get_autocommit(): - raise TransactionManagementError("select_for_update cannot be used outside of a transaction.") + result.append('FROM') + result.extend(from_) + params.extend(f_params) - nowait = self.query.select_for_update_nowait - skip_locked = self.query.select_for_update_skip_locked - # If it's a NOWAIT/SKIP LOCKED query but the backend doesn't - # support it, raise a DatabaseError to prevent a possible - # deadlock. - if nowait and not self.connection.features.has_select_for_update_nowait: - raise DatabaseError('NOWAIT is not supported on this database backend.') - elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: - raise DatabaseError('SKIP LOCKED is not supported on this database backend.') - for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked) + if self.query.select_for_update and self.connection.features.has_select_for_update: + if self.connection.get_autocommit(): + raise TransactionManagementError('select_for_update cannot be used outside of a transaction.') - if for_update_part and self.connection.features.for_update_after_from: - result.append(for_update_part) + nowait = self.query.select_for_update_nowait + skip_locked = self.query.select_for_update_skip_locked + # If it's a NOWAIT/SKIP LOCKED query but the backend + # doesn't support it, raise a DatabaseError to prevent a + # possible deadlock. + if nowait and not self.connection.features.has_select_for_update_nowait: + raise DatabaseError('NOWAIT is not supported on this database backend.') + elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: + raise DatabaseError('SKIP LOCKED is not supported on this database backend.') + for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked) - if where: - result.append('WHERE %s' % where) - params.extend(w_params) + if for_update_part and self.connection.features.for_update_after_from: + result.append(for_update_part) - grouping = [] - for g_sql, g_params in group_by: - grouping.append(g_sql) - params.extend(g_params) - if grouping: - if distinct_fields: - raise NotImplementedError( - "annotate() + distinct(fields) is not implemented.") - if not order_by: - order_by = self.connection.ops.force_no_ordering() - result.append('GROUP BY %s' % ', '.join(grouping)) + if where: + result.append('WHERE %s' % where) + params.extend(w_params) - if having: - result.append('HAVING %s' % having) - params.extend(h_params) + grouping = [] + for g_sql, g_params in group_by: + grouping.append(g_sql) + params.extend(g_params) + if grouping: + if distinct_fields: + raise NotImplementedError('annotate() + distinct(fields) is not implemented.') + if not order_by: + order_by = self.connection.ops.force_no_ordering() + result.append('GROUP BY %s' % ', '.join(grouping)) + + if having: + result.append('HAVING %s' % having) + params.extend(h_params) if order_by: ordering = [] diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 5eea5ad939..16ed92a4d4 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -186,6 +186,11 @@ class Query(object): self.annotation_select_mask = None self._annotation_select_cache = None + # Set combination attributes + self.combinator = None + self.combinator_all = False + self.combined_queries = () + # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. # The _extra attribute is an OrderedDict, lazily created similarly to @@ -303,6 +308,9 @@ class Query(object): # used. obj._annotation_select_cache = None obj.max_depth = self.max_depth + obj.combinator = self.combinator + obj.combinator_all = self.combinator_all + obj.combined_queries = self.combined_queries obj._extra = self._extra.copy() if self._extra is not None else None if self.extra_select_mask is None: obj.extra_select_mask = None diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 5fe783514d..f5f0fbcc8b 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -801,6 +801,61 @@ typically caches its results. If the data in the database might have changed since a ``QuerySet`` was evaluated, you can get updated results for the same query by calling ``all()`` on a previously evaluated ``QuerySet``. +``union()`` +~~~~~~~~~~~ + +.. method:: union(*other_qs, all=False) + +.. versionadded:: 1.11 + +Uses SQL's ``UNION`` operator to combine the results of two or more +``QuerySet``\s. For example: + + >>> qs1.union(qs2, qs3) + +The ``UNION`` operator selects only distinct values by default. To allow +duplicate values, use the ``all=True`` argument. + +``union()``, ``intersection()``, and ``difference()`` return model instances +of the type of the first ``QuerySet`` even if the arguments are ``QuerySet``\s +of other models. Passing different models works as long as the ``SELECT`` list +is the same in all ``QuerySet``\s (at least the types, the names don't matter +as long as the types in the same order). + +In addition, only ``LIMIT``, ``OFFSET``, and ``ORDER BY`` (i.e. slicing and +:meth:`order_by`) are allowed on the resulting ``QuerySet``. Further, databases +place restrictions on what operations are allowed in the combined queries. For +example, most databases don't allow ``LIMIT`` or ``OFFSET`` in the combined +queries. + +``intersection()`` +~~~~~~~~~~~~~~~~~~ + +.. method:: intersection(*other_qs) + +.. versionadded:: 1.11 + +Uses SQL's ``INTERSECT`` operator to return the shared elements of two or more +``QuerySet``\s. For example: + + >>> qs1.itersect(qs2, qs3) + +See :meth:`union` for some restrictions. + +``difference()`` +~~~~~~~~~~~~~~~~ + +.. method:: difference(*other_qs) + +.. versionadded:: 1.11 + +Uses SQL's ``EXCEPT`` operator to keep only elements present in the +``QuerySet`` but not in some other ``QuerySet``\s. For example:: + + >>> qs1.difference(qs2, qs3) + +See :meth:`union` for some restrictions. + ``select_related()`` ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt index c8f3e92f21..edbd2b0be3 100644 --- a/docs/releases/1.11.txt +++ b/docs/releases/1.11.txt @@ -386,6 +386,9 @@ Models * The new ``F`` expression ``bitleftshift()`` and ``bitrightshift()`` methods allow :ref:`bitwise shift operations `. +* Added :meth:`.QuerySet.union`, :meth:`~.QuerySet.intersection`, and + :meth:`~.QuerySet.difference`. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/basic/tests.py b/tests/basic/tests.py index 05a49ac3ab..dd7570c5b9 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -589,6 +589,9 @@ class ManagerTest(SimpleTestCase): '_insert', '_update', 'raw', + 'union', + 'intersection', + 'difference', ] def test_manager_methods(self): diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py new file mode 100644 index 0000000000..a0faab2eb7 --- /dev/null +++ b/tests/queries/test_qs_combinators.py @@ -0,0 +1,111 @@ +from __future__ import unicode_literals + +from django.db.models import F, IntegerField, Value +from django.db.utils import DatabaseError +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.utils.six.moves import range + +from .models import Number, ReservedName + + +@skipUnlessDBFeature('supports_select_union') +class QuerySetSetOperationTests(TestCase): + @classmethod + def setUpTestData(cls): + Number.objects.bulk_create(Number(num=i) for i in range(10)) + + def number_transform(self, value): + return value.num + + def assertNumbersEqual(self, queryset, expected_numbers, ordered=True): + self.assertQuerysetEqual(queryset, expected_numbers, self.number_transform, ordered) + + def test_simple_union(self): + qs1 = Number.objects.filter(num__lte=1) + qs2 = Number.objects.filter(num__gte=8) + qs3 = Number.objects.filter(num=5) + self.assertNumbersEqual(qs1.union(qs2, qs3), [0, 1, 5, 8, 9], ordered=False) + + @skipUnlessDBFeature('supports_select_intersection') + def test_simple_intersection(self): + qs1 = Number.objects.filter(num__lte=5) + qs2 = Number.objects.filter(num__gte=5) + qs3 = Number.objects.filter(num__gte=4, num__lte=6) + self.assertNumbersEqual(qs1.intersection(qs2, qs3), [5], ordered=False) + + @skipUnlessDBFeature('supports_select_difference') + def test_simple_difference(self): + qs1 = Number.objects.filter(num__lte=5) + qs2 = Number.objects.filter(num__lte=4) + self.assertNumbersEqual(qs1.difference(qs2), [5], ordered=False) + + def test_union_distinct(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + self.assertEqual(len(list(qs1.union(qs2, all=True))), 20) + self.assertEqual(len(list(qs1.union(qs2))), 10) + + def test_union_bad_kwarg(self): + qs1 = Number.objects.all() + msg = "union() received an unexpected keyword argument 'bad'" + with self.assertRaisesMessage(TypeError, msg): + self.assertEqual(len(list(qs1.union(qs1, bad=True))), 20) + + def test_limits(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + self.assertEqual(len(list(qs1.union(qs2)[:2])), 2) + + def test_ordering(self): + qs1 = Number.objects.filter(num__lte=1) + qs2 = Number.objects.filter(num__gte=2, num__lte=3) + self.assertNumbersEqual(qs1.union(qs2).order_by('-num'), [3, 2, 1, 0]) + + @skipUnlessDBFeature('supports_slicing_ordering_in_compound') + def test_ordering_subqueries(self): + qs1 = Number.objects.order_by('num')[:2] + qs2 = Number.objects.order_by('-num')[:2] + self.assertNumbersEqual(qs1.union(qs2).order_by('-num')[:4], [9, 8, 1, 0]) + + @skipIfDBFeature('supports_slicing_ordering_in_compound') + def test_unsupported_ordering_slicing_raises_db_error(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + msg = 'LIMIT/OFFSET not allowed in subqueries of compound statements' + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2[:10])) + msg = 'ORDER BY not allowed in subqueries of compound statements' + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.order_by('id').union(qs2)) + + @skipIfDBFeature('supports_select_intersection') + def test_unsupported_intersection_raises_db_error(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + msg = 'intersection not supported on this database backend' + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.intersection(qs2)) + + def test_combining_multiple_models(self): + ReservedName.objects.create(name='99 little bugs', order=99) + qs1 = Number.objects.filter(num=1).values_list('num', flat=True) + qs2 = ReservedName.objects.values_list('order') + self.assertEqual(list(qs1.union(qs2).order_by('num')), [1, 99]) + + def test_order_raises_on_non_selected_column(self): + qs1 = Number.objects.filter().annotate( + annotation=Value(1, IntegerField()), + ).values('annotation', num2=F('num')) + qs2 = Number.objects.filter().values('id', 'num') + # Should not raise + list(qs1.union(qs2).order_by('annotation')) + list(qs1.union(qs2).order_by('num2')) + msg = 'ORDER BY term does not match any column in the result set' + # 'id' is not part of the select + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2).order_by('id')) + # 'num' got realiased to num2 + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2).order_by('num')) + # switched order, now 'exists' again: + list(qs2.union(qs1).order_by('num'))