mirror of
https://github.com/django/django.git
synced 2024-12-22 09:05:43 +00:00
Fixed #16211 -- Added logical NOT support to F expressions.
This commit is contained in:
parent
c01e76c95c
commit
a320aab512
@ -162,6 +162,9 @@ class Combinable:
|
||||
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
|
||||
)
|
||||
|
||||
def __invert__(self):
|
||||
return NegatedExpression(self)
|
||||
|
||||
|
||||
class BaseExpression:
|
||||
"""Base class for all query expressions."""
|
||||
@ -827,6 +830,9 @@ class F(Combinable):
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def copy(self):
|
||||
return copy.copy(self)
|
||||
|
||||
|
||||
class ResolvedOuterRef(F):
|
||||
"""
|
||||
@ -1252,6 +1258,57 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||
return "{}({})".format(self.__class__.__name__, self.expression)
|
||||
|
||||
|
||||
class NegatedExpression(ExpressionWrapper):
|
||||
"""The logical negation of a conditional expression."""
|
||||
|
||||
def __init__(self, expression):
|
||||
super().__init__(expression, output_field=fields.BooleanField())
|
||||
|
||||
def __invert__(self):
|
||||
return self.expression.copy()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
try:
|
||||
sql, params = super().as_sql(compiler, connection)
|
||||
except EmptyResultSet:
|
||||
features = compiler.connection.features
|
||||
if not features.supports_boolean_expr_in_select_clause:
|
||||
return "1=1", ()
|
||||
return compiler.compile(Value(True))
|
||||
ops = compiler.connection.ops
|
||||
# Some database backends (e.g. Oracle) don't allow EXISTS() and filters
|
||||
# to be compared to another expression unless they're wrapped in a CASE
|
||||
# WHEN.
|
||||
if not ops.conditional_expression_supported_in_where_clause(self.expression):
|
||||
return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
|
||||
return f"NOT {sql}", params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
resolved = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
if not getattr(resolved.expression, "conditional", False):
|
||||
raise TypeError("Cannot negate non-conditional expressions.")
|
||||
return resolved
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap boolean expressions with a CASE WHEN expression if a database
|
||||
# backend (e.g. Oracle) doesn't support boolean expression in SELECT or
|
||||
# GROUP BY list.
|
||||
expression_supported_in_where_clause = (
|
||||
compiler.connection.ops.conditional_expression_supported_in_where_clause
|
||||
)
|
||||
if (
|
||||
not compiler.connection.features.supports_boolean_expr_in_select_clause
|
||||
# Avoid double wrapping.
|
||||
and expression_supported_in_where_clause(self.expression)
|
||||
):
|
||||
sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
|
||||
return sql, params
|
||||
|
||||
|
||||
@deconstructible(path="django.db.models.When")
|
||||
class When(Expression):
|
||||
template = "WHEN %(condition)s THEN %(result)s"
|
||||
@ -1486,34 +1543,10 @@ class Exists(Subquery):
|
||||
template = "EXISTS(%(subquery)s)"
|
||||
output_field = fields.BooleanField()
|
||||
|
||||
def __init__(self, queryset, negated=False, **kwargs):
|
||||
self.negated = negated
|
||||
def __init__(self, queryset, **kwargs):
|
||||
super().__init__(queryset, **kwargs)
|
||||
self.query = self.query.exists()
|
||||
|
||||
def __invert__(self):
|
||||
clone = self.copy()
|
||||
clone.negated = not self.negated
|
||||
return clone
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
try:
|
||||
sql, params = super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
**extra_context,
|
||||
)
|
||||
except EmptyResultSet:
|
||||
if self.negated:
|
||||
features = compiler.connection.features
|
||||
if not features.supports_boolean_expr_in_select_clause:
|
||||
return "1=1", ()
|
||||
return compiler.compile(Value(True))
|
||||
raise
|
||||
if self.negated:
|
||||
sql = "NOT {}".format(sql)
|
||||
return sql, params
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap EXISTS() with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||
|
@ -255,6 +255,19 @@ is null) after companies that have been contacted::
|
||||
from django.db.models import F
|
||||
Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
|
||||
|
||||
Using ``F()`` with logical operations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. versionadded:: 4.2
|
||||
|
||||
``F()`` expressions that output ``BooleanField`` can be logically negated with
|
||||
the inversion operator ``~F()``. For example, to swap the activation status of
|
||||
companies::
|
||||
|
||||
from django.db.models import F
|
||||
|
||||
Company.objects.update(is_active=~F('is_active'))
|
||||
|
||||
.. _func-expressions:
|
||||
|
||||
``Func()`` expressions
|
||||
|
@ -236,6 +236,9 @@ Models
|
||||
* :class:`~django.db.models.functions.Now` now supports microsecond precision
|
||||
on MySQL and millisecond precision on SQLite.
|
||||
|
||||
* :class:`F() <django.db.models.F>` expressions that output ``BooleanField``
|
||||
can now be negated using ``~F()`` (inversion operator).
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -345,6 +348,9 @@ Miscellaneous
|
||||
* The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
|
||||
0.2.3.
|
||||
|
||||
* The undocumented ``negated`` parameter of the
|
||||
:class:`~django.db.models.Exists` expression is removed.
|
||||
|
||||
.. _deprecated-features-4.2:
|
||||
|
||||
Features deprecated in 4.2
|
||||
|
@ -48,6 +48,7 @@ from django.db.models.expressions import (
|
||||
Col,
|
||||
Combinable,
|
||||
CombinedExpression,
|
||||
NegatedExpression,
|
||||
RawSQL,
|
||||
Ref,
|
||||
)
|
||||
@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase):
|
||||
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
|
||||
|
||||
|
||||
class NegatedExpressionTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
|
||||
cls.eu_company = Company.objects.create(
|
||||
name="Example Inc.",
|
||||
num_employees=2300,
|
||||
num_chairs=5,
|
||||
ceo=ceo,
|
||||
based_in_eu=True,
|
||||
)
|
||||
cls.non_eu_company = Company.objects.create(
|
||||
name="Foobar Ltd.",
|
||||
num_employees=3,
|
||||
num_chairs=4,
|
||||
ceo=ceo,
|
||||
based_in_eu=False,
|
||||
)
|
||||
|
||||
def test_invert(self):
|
||||
f = F("field")
|
||||
self.assertEqual(~f, NegatedExpression(f))
|
||||
self.assertIsNot(~~f, f)
|
||||
self.assertEqual(~~f, f)
|
||||
|
||||
def test_filter(self):
|
||||
self.assertSequenceEqual(
|
||||
Company.objects.filter(~F("based_in_eu")),
|
||||
[self.non_eu_company],
|
||||
)
|
||||
|
||||
qs = Company.objects.annotate(eu_required=~Value(False))
|
||||
self.assertSequenceEqual(
|
||||
qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"),
|
||||
[self.eu_company],
|
||||
)
|
||||
self.assertSequenceEqual(
|
||||
qs.filter(based_in_eu=~~F("eu_required")),
|
||||
[self.eu_company],
|
||||
)
|
||||
self.assertSequenceEqual(
|
||||
qs.filter(based_in_eu=~F("eu_required")),
|
||||
[self.non_eu_company],
|
||||
)
|
||||
self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), [])
|
||||
|
||||
def test_values(self):
|
||||
self.assertSequenceEqual(
|
||||
Company.objects.annotate(negated=~F("based_in_eu"))
|
||||
.values_list("name", "negated")
|
||||
.order_by("name"),
|
||||
[("Example Inc.", False), ("Foobar Ltd.", True)],
|
||||
)
|
||||
|
||||
|
||||
class OrderByTests(SimpleTestCase):
|
||||
def test_equal(self):
|
||||
self.assertEqual(
|
||||
|
@ -8,7 +8,7 @@ from django.db.models import (
|
||||
Q,
|
||||
Value,
|
||||
)
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models.expressions import NegatedExpression, RawSQL
|
||||
from django.db.models.functions import Lower
|
||||
from django.db.models.sql.where import NothingNode
|
||||
from django.test import SimpleTestCase, TestCase
|
||||
@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
|
||||
]
|
||||
for q in tests:
|
||||
with self.subTest(q=q):
|
||||
self.assertIs(q.negated, True)
|
||||
self.assertIsInstance(q, NegatedExpression)
|
||||
|
||||
def test_deconstruct(self):
|
||||
q = Q(price__gt=F("discounted_price"))
|
||||
|
@ -10,6 +10,7 @@ class DataPoint(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
value = models.CharField(max_length=20)
|
||||
another_value = models.CharField(max_length=20, blank=True)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
||||
|
||||
class RelatedPoint(models.Model):
|
||||
|
@ -2,7 +2,7 @@ import unittest
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import IntegrityError, connection, transaction
|
||||
from django.db.models import CharField, Count, F, IntegerField, Max
|
||||
from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
|
||||
from django.db.models.functions import Abs, Concat, Lower
|
||||
from django.test import TestCase
|
||||
from django.test.utils import register_lookup
|
||||
@ -81,7 +81,7 @@ class AdvancedTests(TestCase):
|
||||
def setUpTestData(cls):
|
||||
cls.d0 = DataPoint.objects.create(name="d0", value="apple")
|
||||
cls.d2 = DataPoint.objects.create(name="d2", value="banana")
|
||||
cls.d3 = DataPoint.objects.create(name="d3", value="banana")
|
||||
cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False)
|
||||
cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
|
||||
|
||||
def test_update(self):
|
||||
@ -249,6 +249,32 @@ class AdvancedTests(TestCase):
|
||||
Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
|
||||
self.assertEqual(Bar.objects.get().x, 3)
|
||||
|
||||
def test_update_negated_f(self):
|
||||
DataPoint.objects.update(is_active=~F("is_active"))
|
||||
self.assertCountEqual(
|
||||
DataPoint.objects.values_list("name", "is_active"),
|
||||
[("d0", False), ("d2", False), ("d3", True)],
|
||||
)
|
||||
DataPoint.objects.update(is_active=~F("is_active"))
|
||||
self.assertCountEqual(
|
||||
DataPoint.objects.values_list("name", "is_active"),
|
||||
[("d0", True), ("d2", True), ("d3", False)],
|
||||
)
|
||||
|
||||
def test_update_negated_f_conditional_annotation(self):
|
||||
DataPoint.objects.annotate(
|
||||
is_d2=Case(When(name="d2", then=True), default=False)
|
||||
).update(is_active=~F("is_d2"))
|
||||
self.assertCountEqual(
|
||||
DataPoint.objects.values_list("name", "is_active"),
|
||||
[("d0", True), ("d2", False), ("d3", True)],
|
||||
)
|
||||
|
||||
def test_updating_non_conditional_field(self):
|
||||
msg = "Cannot negate non-conditional expressions."
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
DataPoint.objects.update(is_active=~F("name"))
|
||||
|
||||
|
||||
@unittest.skipUnless(
|
||||
connection.vendor == "mysql",
|
||||
|
Loading…
Reference in New Issue
Block a user