1
0
mirror of https://github.com/django/django.git synced 2025-06-05 11:39:13 +00:00

Fixed #19964 -- Removed relabel_aliases from some structs

Before there was need to have both .relabel_aliases() and .clone() for
many structs. Now there is only relabeled_clone() for those structs
where alias is the only mutable attribute.
This commit is contained in:
Anssi Kääriäinen 2013-03-02 01:06:56 +02:00
parent 679af4058d
commit d744c550d5
9 changed files with 79 additions and 105 deletions

View File

@ -344,9 +344,9 @@ class Field(object):
if hasattr(value, 'get_compiler'): if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection) value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabel_aliases method, it will need to # If the value has a relabeled_clone method it means the
# be invoked before the final SQL is evaluated # value will be handled later on.
if hasattr(value, 'relabel_aliases'): if hasattr(value, 'relabeled_clone'):
return value return value
if hasattr(value, 'as_sql'): if hasattr(value, 'as_sql'):
sql, params = value.as_sql() sql, params = value.as_sql()

View File

@ -153,9 +153,9 @@ class RelatedField(object):
if hasattr(value, 'get_compiler'): if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection) value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabel_aliases method, it will need to # If the value has a relabeled_clone method it means the
# be invoked before the final SQL is evaluated # value will be handled later on.
if hasattr(value, 'relabel_aliases'): if hasattr(value, 'relabeled_clone'):
return value return value
if hasattr(value, 'as_sql'): if hasattr(value, 'as_sql'):
sql, params = value.as_sql() sql, params = value.as_sql()

View File

@ -63,14 +63,11 @@ class Aggregate(object):
self.field = tmp self.field = tmp
def clone(self): def relabeled_clone(self, change_map):
# Different aggregates have different init methods, so use copy here clone = copy.copy(self)
# deepcopy is not needed, as self.col is only changing variable.
return copy.copy(self)
def relabel_aliases(self, change_map):
if isinstance(self.col, (list, tuple)): if isinstance(self.col, (list, tuple)):
self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
return clone
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
"Return the aggregate, rendered as SQL with parameters." "Return the aggregate, rendered as SQL with parameters."

View File

@ -32,10 +32,8 @@ class Date(object):
self.col = col self.col = col
self.lookup_type = lookup_type self.lookup_type = lookup_type
def relabel_aliases(self, change_map): def relabeled_clone(self, change_map):
c = self.col return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
if isinstance(c, (list, tuple)):
self.col = (change_map.get(c[0], c[0]), c[1])
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
if isinstance(self.col, (list, tuple)): if isinstance(self.col, (list, tuple)):
@ -53,10 +51,8 @@ class DateTime(object):
self.lookup_type = lookup_type self.lookup_type = lookup_type
self.tzname = tzname self.tzname = tzname
def relabel_aliases(self, change_map): def relabeled_clone(self, change_map):
c = self.col return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
if isinstance(c, (list, tuple)):
self.col = (change_map.get(c[0], c[0]), c[1])
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
if isinstance(self.col, (list, tuple)): if isinstance(self.col, (list, tuple)):

View File

@ -1,6 +1,7 @@
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
import copy
class SQLEvaluator(object): class SQLEvaluator(object):
def __init__(self, expression, query, allow_joins=True, reuse=None): def __init__(self, expression, query, allow_joins=True, reuse=None):
@ -12,23 +13,23 @@ class SQLEvaluator(object):
self.reuse = reuse self.reuse = reuse
self.expression.prepare(self, query, allow_joins) self.expression.prepare(self, query, allow_joins)
def relabeled_clone(self, change_map):
clone = copy.copy(self)
clone.cols = []
for node, col in self.cols[:]:
if hasattr(col, 'relabeled_clone'):
clone.cols.append((node, col.relabeled_clone(change_map)))
else:
clone.cols.append((node,
(change_map.get(col[0], col[0]), col[1])))
return clone
def prepare(self): def prepare(self):
return self return self
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
return self.expression.evaluate(self, qn, connection) return self.expression.evaluate(self, qn, connection)
def relabel_aliases(self, change_map):
new_cols = []
for node, col in self.cols:
if hasattr(col, "relabel_aliases"):
col.relabel_aliases(change_map)
new_cols.append((node, col))
else:
new_cols.append((node,
(change_map.get(col[0], col[0]), col[1])))
self.cols = new_cols
##################################################### #####################################################
# Vistor methods for initial expression preparation # # Vistor methods for initial expression preparation #
##################################################### #####################################################

View File

@ -294,8 +294,7 @@ class Query(object):
obj.select_for_update_nowait = self.select_for_update_nowait obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related obj.select_related = self.select_related
obj.related_select_cols = [] obj.related_select_cols = []
obj.aggregates = SortedDict((k, v.clone()) obj.aggregates = self.aggregates.copy()
for k, v in self.aggregates.items())
if self.aggregate_select_mask is None: if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None obj.aggregate_select_mask = None
else: else:
@ -559,9 +558,8 @@ class Query(object):
new_col = change_map.get(col[0], col[0]), col[1] new_col = change_map.get(col[0], col[0]), col[1]
self.select.append(SelectInfo(new_col, field)) self.select.append(SelectInfo(new_col, field))
else: else:
item = col.clone() new_col = col.relabeled_clone(change_map)
item.relabel_aliases(change_map) self.select.append(SelectInfo(new_col, field))
self.select.append(SelectInfo(item, field))
if connector == OR: if connector == OR:
# It would be nice to be able to handle this, but the queries don't # It would be nice to be able to handle this, but the queries don't
@ -769,26 +767,6 @@ class Query(object):
The principle for promotion is: any alias which is used (it is in The principle for promotion is: any alias which is used (it is in
alias_usage_counts), is not used by every child of the ORed filter, alias_usage_counts), is not used by every child of the ORed filter,
and isn't pre-existing needs to be promoted to LOUTER join. and isn't pre-existing needs to be promoted to LOUTER join.
Some examples (assume all joins used are nullable):
- existing filter: a__f1=foo
- add filter: b__f1=foo|b__f2=foo
In this case we should not promote either of the joins (using INNER
doesn't remove results). We correctly avoid join promotion, because
a is not used in this branch, and b is used two times.
- add filter a__f1=foo|b__f2=foo
In this case we should promote both a and b, otherwise they will
remove results. We will also correctly do that as both aliases are
used, and in addition both are used only once while there are two
filters.
- existing: a__f1=bar
- add filter: a__f2=foo|b__f2=foo
We will not promote a as it is previously used. If the join results
in null, the existing filter can't succeed.
The above (and some more) are tested in queries.DisjunctionPromotionTests
""" """
for alias, use_count in alias_usage_counts.items(): for alias, use_count in alias_usage_counts.items():
if use_count < num_childs and alias not in aliases_before: if use_count < num_childs and alias not in aliases_before:
@ -807,8 +785,7 @@ class Query(object):
old_alias = col[0] old_alias = col[0]
return (change_map.get(old_alias, old_alias), col[1]) return (change_map.get(old_alias, old_alias), col[1])
else: else:
col.relabel_aliases(change_map) return col.relabeled_clone(change_map)
return col
# 1. Update references in "select" (normal columns plus aliases), # 1. Update references in "select" (normal columns plus aliases),
# "group by", "where" and "having". # "group by", "where" and "having".
self.where.relabel_aliases(change_map) self.where.relabel_aliases(change_map)

View File

@ -34,9 +34,15 @@ class WhereNode(tree.Node):
The class is tied to the Query class that created it (in order to create The class is tied to the Query class that created it (in order to create
the correct SQL). the correct SQL).
The children in this tree are usually either Q-like objects or lists of A child is usually a tuple of:
[table_alias, field_name, db_type, lookup_type, value_annotation, params]. (Constraint(alias, targetcol, field), lookup_type, value)
However, a child could also be any class with as_sql() and relabel_aliases() methods. where value can be either raw Python value, or Query, ExpressionNode or
something else knowing how to turn itself into SQL.
However, a child could also be any class with as_sql() and either
relabeled_clone() method or relabel_aliases() and clone() methods. The
second alternative should be used if the alias is not the only mutable
variable.
""" """
default = AND default = AND
@ -255,30 +261,22 @@ class WhereNode(tree.Node):
lhs = qn(name) lhs = qn(name)
return connection.ops.field_cast_sql(db_type) % lhs return connection.ops.field_cast_sql(db_type) % lhs
def relabel_aliases(self, change_map, node=None): def relabel_aliases(self, change_map):
""" """
Relabels the alias values of any children. 'change_map' is a dictionary Relabels the alias values of any children. 'change_map' is a dictionary
mapping old (current) alias values to the new values. mapping old (current) alias values to the new values.
""" """
if not node: for pos, child in enumerate(self.children):
node = self
for pos, child in enumerate(node.children):
if hasattr(child, 'relabel_aliases'): if hasattr(child, 'relabel_aliases'):
# For example another WhereNode
child.relabel_aliases(change_map) child.relabel_aliases(change_map)
elif isinstance(child, tree.Node):
self.relabel_aliases(change_map, child)
elif isinstance(child, (list, tuple)): elif isinstance(child, (list, tuple)):
if isinstance(child[0], (list, tuple)): # tuple starting with Constraint
elt = list(child[0]) child = (child[0].relabeled_clone(change_map),) + child[1:]
if elt[0] in change_map: if hasattr(child[3], 'relabeled_clone'):
elt[0] = change_map[elt[0]] child = (child[0], child[1], child[2]) + (
node.children[pos] = (tuple(elt),) + child[1:] child[3].relabeled_clone(change_map),)
else: self.children[pos] = child
child[0].relabel_aliases(change_map)
# Check if the query value also requires relabelling
if hasattr(child[3], 'relabel_aliases'):
child[3].relabel_aliases(change_map)
def clone(self): def clone(self):
""" """
@ -290,11 +288,10 @@ class WhereNode(tree.Node):
clone = self.__class__._new_instance( clone = self.__class__._new_instance(
children=[], connector=self.connector, negated=self.negated) children=[], connector=self.connector, negated=self.negated)
for child in self.children: for child in self.children:
if isinstance(child, tuple): if hasattr(child, 'clone'):
clone.children.append(
(child[0].clone(), child[1], child[2], child[3]))
else:
clone.children.append(child.clone()) clone.children.append(child.clone())
else:
clone.children.append(child)
return clone return clone
class EmptyWhere(WhereNode): class EmptyWhere(WhereNode):
@ -313,11 +310,6 @@ class EverythingNode(object):
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
return '', [] return '', []
def relabel_aliases(self, change_map, node=None):
return
def clone(self):
return self
class NothingNode(object): class NothingNode(object):
""" """
@ -326,11 +318,6 @@ class NothingNode(object):
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
def relabel_aliases(self, change_map, node=None):
return
def clone(self):
return self
class ExtraWhere(object): class ExtraWhere(object):
def __init__(self, sqls, params): def __init__(self, sqls, params):
@ -341,8 +328,6 @@ class ExtraWhere(object):
sqls = ["(%s)" % sql for sql in self.sqls] sqls = ["(%s)" % sql for sql in self.sqls]
return " AND ".join(sqls), list(self.params or ()) return " AND ".join(sqls), list(self.params or ())
def clone(self):
return self
class Constraint(object): class Constraint(object):
""" """
@ -405,12 +390,11 @@ class Constraint(object):
return (self.alias, self.col, db_type), params return (self.alias, self.col, db_type), params
def relabel_aliases(self, change_map): def relabeled_clone(self, change_map):
if self.alias in change_map: if self.alias not in change_map:
self.alias = change_map[self.alias] return self
else:
def clone(self): new = Empty()
new = Empty() new.__class__ = self.__class__
new.__class__ = self.__class__ new.alias, new.col, new.field = change_map[self.alias], self.col, self.field
new.alias, new.col, new.field = self.alias, self.col, self.field return new
return new

View File

@ -470,3 +470,8 @@ class Paragraph(models.Model):
class Page(models.Model): class Page(models.Model):
text = models.TextField() text = models.TextField()
class MyObject(models.Model):
parent = models.ForeignKey('self', null=True, blank=True, related_name='children')
data = models.CharField(max_length=100)
created_at = models.DateTimeField(auto_now_add=True)

View File

@ -25,7 +25,7 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover,
SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory,
SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job,
JobResponsibilities, BaseA, Identifier, Program, Channel, Page, Paragraph, JobResponsibilities, BaseA, Identifier, Program, Channel, Page, Paragraph,
Chapter, Book) Chapter, Book, MyObject)
class BaseQuerysetTest(TestCase): class BaseQuerysetTest(TestCase):
@ -2661,3 +2661,17 @@ class ManyToManyExcludeTest(TestCase):
self.assertNotIn(b1, q) self.assertNotIn(b1, q)
self.assertIn(b2, q) self.assertIn(b2, q)
self.assertIn(b3, q) self.assertIn(b3, q)
class RelabelCloneTest(TestCase):
def test_ticket_19964(self):
my1 = MyObject.objects.create(data='foo')
my1.parent = my1
my1.save()
my2 = MyObject.objects.create(data='bar', parent=my1)
parents = MyObject.objects.filter(parent=F('id'))
children = MyObject.objects.filter(parent__in=parents).exclude(parent=F('id'))
self.assertEqual(list(parents), [my1])
# Evaluating the children query (which has parents as part of it) does
# not change results for the parents query.
self.assertEqual(list(children), [my2])
self.assertEqual(list(parents), [my1])