1
0
mirror of https://github.com/django/django.git synced 2024-12-25 02:26:12 +00:00

Refactored the empty/full result logic in WhereNode.as_sql()

Made sure the WhereNode.as_sql() handles various EmptyResultSet and
FullResultSet conditions correctly. Also, got rid of the FullResultSet
exception class. It is now represented by '', [] return value in the
as_sql() methods.
This commit is contained in:
Anssi Kääriäinen 2012-05-26 05:55:33 +03:00
parent 2b9fb2e644
commit bd283aa844
3 changed files with 135 additions and 39 deletions

View File

@ -6,9 +6,6 @@ the SQL domain.
class EmptyResultSet(Exception):
pass
class FullResultSet(Exception):
pass
class MultiJoin(Exception):
"""
Used by join construction code to indicate the point at which a

View File

@ -10,7 +10,7 @@ from itertools import repeat
from django.utils import tree
from django.db.models.fields import Field
from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.aggregates import Aggregate
# Connection types
@ -75,17 +75,21 @@ class WhereNode(tree.Node):
def as_sql(self, qn, connection):
"""
Returns the SQL version of the where clause and the value to be
substituted in. Returns None, None if this node is empty.
If 'node' is provided, that is the root of the SQL generation
(generally not needed except by the internal implementation for
recursion).
substituted in. Returns '', [] if this node matches everything,
None, [] if this node is empty, and raises EmptyResultSet if this
node can't match anything.
"""
if not self.children:
return None, []
# Note that the logic here is made slightly more complex than
# necessary because there are two kind of empty nodes: Nodes
# containing 0 children, and nodes that are known to match everything.
# A match-everything node is different than empty node (which also
# technically matches everything) for backwards compatibility reasons.
# Refs #5261.
result = []
result_params = []
empty = True
everything_childs, nothing_childs = 0, 0
non_empty_childs = len(self.children)
for child in self.children:
try:
if hasattr(child, 'as_sql'):
@ -93,39 +97,48 @@ class WhereNode(tree.Node):
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection)
except EmptyResultSet:
if self.connector == AND and not self.negated:
# We can bail out early in this particular case (only).
raise
elif self.negated:
empty = False
continue
except FullResultSet:
if self.connector == OR:
if self.negated:
empty = True
break
# We match everything. No need for any constraints.
return '', []
if self.negated:
empty = True
continue
empty = False
nothing_childs += 1
else:
if sql:
result.append(sql)
result_params.extend(params)
if empty:
else:
if sql is None:
# Skip empty childs totally.
non_empty_childs -= 1
continue
everything_childs += 1
# Check if this node matches nothing or everything.
# First check the amount of full nodes and empty nodes
# to make this node empty/full.
if self.connector == AND:
full_needed, empty_needed = non_empty_childs, 1
else:
full_needed, empty_needed = 1, non_empty_childs
# Now, check if this node is full/empty using the
# counts.
if empty_needed - nothing_childs <= 0:
if self.negated:
return '', []
else:
raise EmptyResultSet
if full_needed - everything_childs <= 0:
if self.negated:
raise EmptyResultSet
else:
return '', []
if non_empty_childs == 0:
# All the child nodes were empty, so this one is empty, too.
return None, []
conn = ' %s ' % self.connector
sql_string = conn.join(result)
if sql_string:
if self.negated:
sql_string = 'NOT (%s)' % sql_string
elif len(self.children) != 1:
if len(result) > 1:
sql_string = '(%s)' % sql_string
if self.negated:
sql_string = 'NOT %s' % sql_string
return sql_string, result_params
def make_atom(self, child, qn, connection):
@ -261,7 +274,7 @@ class EverythingNode(object):
"""
def as_sql(self, qn=None, connection=None):
raise FullResultSet
return '', []
def relabel_aliases(self, change_map, node=None):
return

View File

@ -10,6 +10,8 @@ from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS
from django.db.models import Count
from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet
from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode
from django.db.models.sql.datastructures import EmptyResultSet
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import str_prefix
from django.utils import unittest
@ -1316,10 +1318,23 @@ class Queries5Tests(TestCase):
)
def test_ticket5261(self):
# Test different empty excludes.
self.assertQuerysetEqual(
Note.objects.exclude(Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.filter(~Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.filter(~Q()|~Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.exclude(~Q()&~Q()),
['<Note: n1>', '<Note: n2>']
)
class SelectRelatedTests(TestCase):
@ -2020,3 +2035,74 @@ class ProxyQueryCleanupTest(TestCase):
self.assertEqual(qs.count(), 1)
str(qs.query)
self.assertEqual(qs.count(), 1)
class WhereNodeTest(TestCase):
class DummyNode(object):
def as_sql(self, qn, connection):
return 'dummy', []
def test_empty_full_handling_conjunction(self):
qn = connection.ops.quote_name
w = WhereNode(children=[EverythingNode()])
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[NothingNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()])
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()])
self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
self.assertEquals(w.as_sql(qn, connection), ('(dummy AND dummy)', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
def test_empty_full_handling_disjunction(self):
qn = connection.ops.quote_name
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('(dummy OR dummy)', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT dummy', []))
def test_empty_nodes(self):
qn = connection.ops.quote_name
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.connector = 'OR'
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), (None, []))
w = WhereNode(children=[empty_w, NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)