1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #29865 -- Added logical XOR support for Q() and querysets.

This commit is contained in:
Ryan Heard 2021-07-02 15:09:13 -05:00 committed by Mariusz Felisiak
parent 795da6306a
commit c6b4d62fa2
19 changed files with 311 additions and 18 deletions

View File

@ -833,6 +833,7 @@ answer newbie questions, and generally made Django that much better:
Russell Keith-Magee <russell@keith-magee.com>
Russ Webber
Ryan Hall <ryanhall989@gmail.com>
Ryan Heard <ryanwheard@gmail.com>
ryankanno
Ryan Kelly <ryan@rfk.id.au>
Ryan Niemeyer <https://profiles.google.com/ryan.niemeyer/about>

View File

@ -325,6 +325,9 @@ class BaseDatabaseFeatures:
# Does the backend support non-deterministic collations?
supports_non_deterministic_collations = True
# Does the backend support the logical XOR operator?
supports_logical_xor = False
# Collation names for use by the Django test suite.
test_collations = {
"ci": None, # Case-insensitive.

View File

@ -47,6 +47,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_order_by_nulls_modifier = False
order_by_nulls_first = True
supports_logical_xor = True
@cached_property
def minimum_database_version(self):

View File

@ -94,7 +94,7 @@ class Combinable:
if getattr(self, "conditional", False) and getattr(other, "conditional", False):
return Q(self) & Q(other)
raise NotImplementedError(
"Use .bitand() and .bitor() for bitwise logical operations."
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def bitand(self, other):
@ -106,6 +106,13 @@ class Combinable:
def bitrightshift(self, other):
return self._combine(other, self.BITRIGHTSHIFT, False)
def __xor__(self, other):
if getattr(self, "conditional", False) and getattr(other, "conditional", False):
return Q(self) ^ Q(other)
raise NotImplementedError(
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def bitxor(self, other):
return self._combine(other, self.BITXOR, False)
@ -113,7 +120,7 @@ class Combinable:
if getattr(self, "conditional", False) and getattr(other, "conditional", False):
return Q(self) | Q(other)
raise NotImplementedError(
"Use .bitand() and .bitor() for bitwise logical operations."
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def bitor(self, other):
@ -139,12 +146,17 @@ class Combinable:
def __rand__(self, other):
raise NotImplementedError(
"Use .bitand() and .bitor() for bitwise logical operations."
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def __ror__(self, other):
raise NotImplementedError(
"Use .bitand() and .bitor() for bitwise logical operations."
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def __rxor__(self, other):
raise NotImplementedError(
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)

View File

@ -396,6 +396,25 @@ class QuerySet:
combined.query.combine(other.query, sql.OR)
return combined
def __xor__(self, other):
self._check_operator_queryset(other, "^")
self._merge_sanity_check(other)
if isinstance(self, EmptyQuerySet):
return other
if isinstance(other, EmptyQuerySet):
return self
query = (
self
if self.query.can_filter()
else self.model._base_manager.filter(pk__in=self.values("pk"))
)
combined = query._chain()
combined._merge_known_related_objects(other)
if not other.query.can_filter():
other = other.model._base_manager.filter(pk__in=other.values("pk"))
combined.query.combine(other.query, sql.XOR)
return combined
####################################
# METHODS THAT DO DATABASE QUERIES #
####################################

View File

@ -38,6 +38,7 @@ class Q(tree.Node):
# Connection types
AND = "AND"
OR = "OR"
XOR = "XOR"
default = AND
conditional = True
@ -70,6 +71,9 @@ class Q(tree.Node):
def __and__(self, other):
return self._combine(other, self.AND)
def __xor__(self, other):
return self._combine(other, self.XOR)
def __invert__(self):
obj = type(self)()
obj.add(self, self.AND)

View File

@ -1,6 +1,6 @@
from django.db.models.sql.query import * # NOQA
from django.db.models.sql.query import Query
from django.db.models.sql.subqueries import * # NOQA
from django.db.models.sql.where import AND, OR
from django.db.models.sql.where import AND, OR, XOR
__all__ = ["Query", "AND", "OR"]
__all__ = ["Query", "AND", "OR", "XOR"]

View File

@ -1,14 +1,19 @@
"""
Code to manage the creation and SQL rendering of 'where' constraints.
"""
import operator
from functools import reduce
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, When
from django.db.models.lookups import Exact
from django.utils import tree
from django.utils.functional import cached_property
# Connection types
AND = "AND"
OR = "OR"
XOR = "XOR"
class WhereNode(tree.Node):
@ -39,10 +44,12 @@ class WhereNode(tree.Node):
if not self.contains_aggregate:
return self, None
in_negated = negated ^ self.negated
# If the effective connector is OR and this node contains an aggregate,
# then we need to push the whole branch to HAVING clause.
may_need_split = (in_negated and self.connector == AND) or (
not in_negated and self.connector == OR
# If the effective connector is OR or XOR and this node contains an
# aggregate, then we need to push the whole branch to HAVING clause.
may_need_split = (
(in_negated and self.connector == AND)
or (not in_negated and self.connector == OR)
or self.connector == XOR
)
if may_need_split and self.contains_aggregate:
return None, self
@ -85,6 +92,21 @@ class WhereNode(tree.Node):
else:
full_needed, empty_needed = 1, len(self.children)
if self.connector == XOR and not connection.features.supports_logical_xor:
# Convert if the database doesn't support XOR:
# a XOR b XOR c XOR ...
# to:
# (a OR b OR c OR ...) AND (a + b + c + ...) == 1
lhs = self.__class__(self.children, OR)
rhs_sum = reduce(
operator.add,
(Case(When(c, then=1), default=0) for c in self.children),
)
rhs = Exact(1, rhs_sum)
return self.__class__([lhs, rhs], AND, self.negated).as_sql(
compiler, connection
)
for child in self.children:
try:
sql, params = compiler.compile(child)

View File

@ -1903,6 +1903,40 @@ SQL equivalent:
``|`` is not a commutative operation, as different (though equivalent) queries
may be generated.
XOR (``^``)
~~~~~~~~~~~
.. versionadded:: 4.1
Combines two ``QuerySet``\s using the SQL ``XOR`` operator.
The following are equivalent::
Model.objects.filter(x=1) ^ Model.objects.filter(y=2)
from django.db.models import Q
Model.objects.filter(Q(x=1) ^ Q(y=2))
SQL equivalent:
.. code-block:: sql
SELECT ... WHERE x=1 XOR y=2
.. note::
``XOR`` is natively supported on MariaDB and MySQL. On other databases,
``x ^ y ^ ... ^ z`` is converted to an equivalent:
.. code-block:: sql
(x OR y OR ... OR z) AND
1=(
(CASE WHEN x THEN 1 ELSE 0 END) +
(CASE WHEN y THEN 1 ELSE 0 END) +
...
(CASE WHEN z THEN 1 ELSE 0 END) +
)
Methods that do not return ``QuerySet``\s
-----------------------------------------
@ -3751,8 +3785,12 @@ A ``Q()`` object represents an SQL condition that can be used in
database-related operations. It's similar to how an
:class:`F() <django.db.models.F>` object represents the value of a model field
or annotation. They make it possible to define and reuse conditions, and
combine them using operators such as ``|`` (``OR``) and ``&`` (``AND``). See
:ref:`complex-lookups-with-q`.
combine them using operators such as ``|`` (``OR``), ``&`` (``AND``), and ``^``
(``XOR``). See :ref:`complex-lookups-with-q`.
.. versionchanged:: 4.1
Support for the ``^`` (``XOR``) operator was added.
``Prefetch()`` objects
----------------------

View File

@ -273,6 +273,11 @@ Models
as the ``chunk_size`` argument is provided. In older versions, no prefetching
was done.
* :class:`~django.db.models.Q` objects and querysets can now be combined using
``^`` as the exclusive or (``XOR``) operator. ``XOR`` is natively supported
on MariaDB and MySQL. For databases that do not support ``XOR``, the query
will be converted to an equivalent using ``AND``, ``OR``, and ``NOT``.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1111,8 +1111,8 @@ For example, this ``Q`` object encapsulates a single ``LIKE`` query::
from django.db.models import Q
Q(question__startswith='What')
``Q`` objects can be combined using the ``&`` and ``|`` operators. When an
operator is used on two ``Q`` objects, it yields a new ``Q`` object.
``Q`` objects can be combined using the ``&``, ``|``, and ``^`` operators. When
an operator is used on two ``Q`` objects, it yields a new ``Q`` object.
For example, this statement yields a single ``Q`` object that represents the
"OR" of two ``"question__startswith"`` queries::
@ -1124,9 +1124,10 @@ This is equivalent to the following SQL ``WHERE`` clause::
WHERE question LIKE 'Who%' OR question LIKE 'What%'
You can compose statements of arbitrary complexity by combining ``Q`` objects
with the ``&`` and ``|`` operators and use parenthetical grouping. Also, ``Q``
objects can be negated using the ``~`` operator, allowing for combined lookups
that combine both a normal query and a negated (``NOT``) query::
with the ``&``, ``|``, and ``^`` operators and use parenthetical grouping.
Also, ``Q`` objects can be negated using the ``~`` operator, allowing for
combined lookups that combine both a normal query and a negated (``NOT``)
query::
Q(question__startswith='Who') | ~Q(pub_date__year=2005)
@ -1175,6 +1176,10 @@ precede the definition of any keyword arguments. For example::
The :source:`OR lookups examples <tests/or_lookups/tests.py>` in Django's
unit tests show some possible uses of ``Q``.
.. versionchanged:: 4.1
Support for the ``^`` (``XOR``) operator was added.
Comparing objects
=================

View File

@ -1704,6 +1704,28 @@ class AggregationTests(TestCase):
attrgetter("pk"),
)
def test_filter_aggregates_xor_connector(self):
q1 = Q(price__gt=50)
q2 = Q(authors__count__gt=1)
query = Book.objects.annotate(Count("authors")).filter(q1 ^ q2).order_by("pk")
self.assertQuerysetEqual(
query,
[self.b1.pk, self.b4.pk, self.b6.pk],
attrgetter("pk"),
)
def test_filter_aggregates_negated_xor_connector(self):
q1 = Q(price__gt=50)
q2 = Q(authors__count__gt=1)
query = (
Book.objects.annotate(Count("authors")).filter(~(q1 ^ q2)).order_by("pk")
)
self.assertQuerysetEqual(
query,
[self.b2.pk, self.b3.pk, self.b5.pk],
attrgetter("pk"),
)
def test_ticket_11293_q_immutable(self):
"""
Splitting a q object to parts for where/having doesn't alter

View File

@ -2339,7 +2339,9 @@ class ReprTests(SimpleTestCase):
class CombinableTests(SimpleTestCase):
bitwise_msg = "Use .bitand() and .bitor() for bitwise logical operations."
bitwise_msg = (
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def test_negation(self):
c = Combinable()
@ -2353,6 +2355,10 @@ class CombinableTests(SimpleTestCase):
with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
Combinable() | Combinable()
def test_xor(self):
with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
Combinable() ^ Combinable()
def test_reversed_and(self):
with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
object() & Combinable()
@ -2361,6 +2367,10 @@ class CombinableTests(SimpleTestCase):
with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
object() | Combinable()
def test_reversed_xor(self):
with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
object() ^ Combinable()
class CombinedExpressionTests(SimpleTestCase):
def test_resolve_output_field(self):

View File

@ -27,6 +27,15 @@ class QTests(SimpleTestCase):
self.assertEqual(q | Q(), q)
self.assertEqual(Q() | q, q)
def test_combine_xor_empty(self):
q = Q(x=1)
self.assertEqual(q ^ Q(), q)
self.assertEqual(Q() ^ q, q)
q = Q(x__in={}.keys())
self.assertEqual(q ^ Q(), q)
self.assertEqual(Q() ^ q, q)
def test_combine_empty_copy(self):
base_q = Q(x=1)
tests = [
@ -34,6 +43,8 @@ class QTests(SimpleTestCase):
Q() | base_q,
base_q & Q(),
Q() & base_q,
base_q ^ Q(),
Q() ^ base_q,
]
for i, q in enumerate(tests):
with self.subTest(i=i):
@ -43,6 +54,9 @@ class QTests(SimpleTestCase):
def test_combine_or_both_empty(self):
self.assertEqual(Q() | Q(), Q())
def test_combine_xor_both_empty(self):
self.assertEqual(Q() ^ Q(), Q())
def test_combine_not_q_object(self):
obj = object()
q = Q(x=1)
@ -50,12 +64,15 @@ class QTests(SimpleTestCase):
q | obj
with self.assertRaisesMessage(TypeError, str(obj)):
q & obj
with self.assertRaisesMessage(TypeError, str(obj)):
q ^ obj
def test_combine_negated_boolean_expression(self):
tagged = Tag.objects.filter(category=OuterRef("pk"))
tests = [
Q() & ~Exists(tagged),
Q() | ~Exists(tagged),
Q() ^ ~Exists(tagged),
]
for q in tests:
with self.subTest(q=q):
@ -88,6 +105,20 @@ class QTests(SimpleTestCase):
)
self.assertEqual(kwargs, {"_connector": "OR"})
def test_deconstruct_xor(self):
q1 = Q(price__gt=F("discounted_price"))
q2 = Q(price=F("discounted_price"))
q = q1 ^ q2
path, args, kwargs = q.deconstruct()
self.assertEqual(
args,
(
("price__gt", F("discounted_price")),
("price", F("discounted_price")),
),
)
self.assertEqual(kwargs, {"_connector": "XOR"})
def test_deconstruct_and(self):
q1 = Q(price__gt=F("discounted_price"))
q2 = Q(price=F("discounted_price"))
@ -144,6 +175,13 @@ class QTests(SimpleTestCase):
path, args, kwargs = q.deconstruct()
self.assertEqual(Q(*args, **kwargs), q)
def test_reconstruct_xor(self):
q1 = Q(price__gt=F("discounted_price"))
q2 = Q(price=F("discounted_price"))
q = q1 ^ q2
path, args, kwargs = q.deconstruct()
self.assertEqual(Q(*args, **kwargs), q)
def test_reconstruct_and(self):
q1 = Q(price__gt=F("discounted_price"))
q2 = Q(price=F("discounted_price"))

View File

@ -526,6 +526,7 @@ class QuerySetSetOperationTests(TestCase):
operators = [
("|", operator.or_),
("&", operator.and_),
("^", operator.xor),
]
for combinator in combinators:
combined_qs = getattr(qs, combinator)(qs)

View File

@ -1883,6 +1883,10 @@ class Queries5Tests(TestCase):
Note.objects.exclude(~Q() & ~Q()),
[self.n1, self.n2],
)
self.assertSequenceEqual(
Note.objects.exclude(~Q() ^ ~Q()),
[self.n1, self.n2],
)
def test_extra_select_literal_percent_s(self):
# Allow %%s to escape select clauses
@ -2129,6 +2133,15 @@ class Queries6Tests(TestCase):
sql = captured_queries[0]["sql"]
self.assertIn("AS %s" % connection.ops.quote_name("col1"), sql)
def test_xor_subquery(self):
self.assertSequenceEqual(
Tag.objects.filter(
Exists(Tag.objects.filter(id=OuterRef("id"), name="t3"))
^ Exists(Tag.objects.filter(id=OuterRef("id"), parent=self.t1))
),
[self.t2],
)
class RawQueriesTests(TestCase):
@classmethod
@ -2432,6 +2445,30 @@ class QuerySetBitwiseOperationTests(TestCase):
qs2 = Classroom.objects.filter(has_blackboard=True).order_by("-name")[:1]
self.assertCountEqual(qs1 | qs2, [self.room_3, self.room_4])
@skipUnlessDBFeature("allow_sliced_subqueries_with_in")
def test_xor_with_rhs_slice(self):
qs1 = Classroom.objects.filter(has_blackboard=True)
qs2 = Classroom.objects.filter(has_blackboard=False)[:1]
self.assertCountEqual(qs1 ^ qs2, [self.room_1, self.room_2, self.room_3])
@skipUnlessDBFeature("allow_sliced_subqueries_with_in")
def test_xor_with_lhs_slice(self):
qs1 = Classroom.objects.filter(has_blackboard=True)[:1]
qs2 = Classroom.objects.filter(has_blackboard=False)
self.assertCountEqual(qs1 ^ qs2, [self.room_1, self.room_2, self.room_4])
@skipUnlessDBFeature("allow_sliced_subqueries_with_in")
def test_xor_with_both_slice(self):
qs1 = Classroom.objects.filter(has_blackboard=False)[:1]
qs2 = Classroom.objects.filter(has_blackboard=True)[:1]
self.assertCountEqual(qs1 ^ qs2, [self.room_1, self.room_2])
@skipUnlessDBFeature("allow_sliced_subqueries_with_in")
def test_xor_with_both_slice_and_ordering(self):
qs1 = Classroom.objects.filter(has_blackboard=False).order_by("-pk")[:1]
qs2 = Classroom.objects.filter(has_blackboard=True).order_by("-name")[:1]
self.assertCountEqual(qs1 ^ qs2, [self.room_3, self.room_4])
def test_subquery_aliases(self):
combined = School.objects.filter(pk__isnull=False) & School.objects.filter(
Exists(

View File

View File

@ -0,0 +1,8 @@
from django.db import models
class Number(models.Model):
num = models.IntegerField()
def __str__(self):
return str(self.num)

View File

@ -0,0 +1,67 @@
from django.db.models import Q
from django.test import TestCase
from .models import Number
class XorLookupsTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.numbers = [Number.objects.create(num=i) for i in range(10)]
def test_filter(self):
self.assertCountEqual(
Number.objects.filter(num__lte=7) ^ Number.objects.filter(num__gte=3),
self.numbers[:3] + self.numbers[8:],
)
self.assertCountEqual(
Number.objects.filter(Q(num__lte=7) ^ Q(num__gte=3)),
self.numbers[:3] + self.numbers[8:],
)
def test_filter_negated(self):
self.assertCountEqual(
Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3)),
self.numbers[:3] + self.numbers[8:],
)
self.assertCountEqual(
Number.objects.filter(~Q(num__gt=7) ^ ~Q(num__lt=3)),
self.numbers[:3] + self.numbers[8:],
)
self.assertCountEqual(
Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3) ^ Q(num__lte=1)),
[self.numbers[2]] + self.numbers[8:],
)
self.assertCountEqual(
Number.objects.filter(~(Q(num__lte=7) ^ ~Q(num__lt=3) ^ Q(num__lte=1))),
self.numbers[:2] + self.numbers[3:8],
)
def test_exclude(self):
self.assertCountEqual(
Number.objects.exclude(Q(num__lte=7) ^ Q(num__gte=3)),
self.numbers[3:8],
)
def test_stages(self):
numbers = Number.objects.all()
self.assertSequenceEqual(
numbers.filter(num__gte=0) ^ numbers.filter(num__lte=11),
[],
)
self.assertSequenceEqual(
numbers.filter(num__gt=0) ^ numbers.filter(num__lt=11),
[self.numbers[0]],
)
def test_pk_q(self):
self.assertCountEqual(
Number.objects.filter(Q(pk=self.numbers[0].pk) ^ Q(pk=self.numbers[1].pk)),
self.numbers[:2],
)
def test_empty_in(self):
self.assertCountEqual(
Number.objects.filter(Q(pk__in=[]) ^ Q(num__gte=5)),
self.numbers[5:],
)