diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index f9f913bc1b..c0d661eaec 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -344,9 +344,9 @@ class Field(object): if hasattr(value, 'get_compiler'): value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): - # If the value has a relabel_aliases method, it will need to - # be invoked before the final SQL is evaluated - if hasattr(value, 'relabel_aliases'): + # If the value has a relabeled_clone method it means the + # value will be handled later on. + if hasattr(value, 'relabeled_clone'): return value if hasattr(value, 'as_sql'): sql, params = value.as_sql() diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 3b47eb86bb..ee1361779a 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -153,9 +153,9 @@ class RelatedField(object): if hasattr(value, 'get_compiler'): value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): - # If the value has a relabel_aliases method, it will need to - # be invoked before the final SQL is evaluated - if hasattr(value, 'relabel_aliases'): + # If the value has a relabeled_clone method it means the + # value will be handled later on. + if hasattr(value, 'relabeled_clone'): return value if hasattr(value, 'as_sql'): sql, params = value.as_sql() diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 3c8720210b..1b65847b7f 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -63,14 +63,11 @@ class Aggregate(object): self.field = tmp - def clone(self): - # Different aggregates have different init methods, so use copy here - # deepcopy is not needed, as self.col is only changing variable. - return copy.copy(self) - - def relabel_aliases(self, change_map): + def relabeled_clone(self, change_map): + clone = copy.copy(self) if isinstance(self.col, (list, tuple)): - self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) + clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) + return clone def as_sql(self, qn, connection): "Return the aggregate, rendered as SQL with parameters." diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 4bc9e6ed34..daaabbe6da 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -32,10 +32,8 @@ class Date(object): self.col = col self.lookup_type = lookup_type - def relabel_aliases(self, change_map): - c = self.col - if isinstance(c, (list, tuple)): - self.col = (change_map.get(c[0], c[0]), c[1]) + def relabeled_clone(self, change_map): + return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1])) def as_sql(self, qn, connection): if isinstance(self.col, (list, tuple)): @@ -53,10 +51,8 @@ class DateTime(object): self.lookup_type = lookup_type self.tzname = tzname - def relabel_aliases(self, change_map): - c = self.col - if isinstance(c, (list, tuple)): - self.col = (change_map.get(c[0], c[0]), c[1]) + def relabeled_clone(self, change_map): + return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1])) def as_sql(self, qn, connection): if isinstance(self.col, (list, tuple)): diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 45b9cb202b..55ae655cb0 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -1,6 +1,7 @@ from django.core.exceptions import FieldError from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import FieldDoesNotExist +import copy class SQLEvaluator(object): def __init__(self, expression, query, allow_joins=True, reuse=None): @@ -12,23 +13,23 @@ class SQLEvaluator(object): self.reuse = reuse self.expression.prepare(self, query, allow_joins) + def relabeled_clone(self, change_map): + clone = copy.copy(self) + clone.cols = [] + for node, col in self.cols[:]: + if hasattr(col, 'relabeled_clone'): + clone.cols.append((node, col.relabeled_clone(change_map))) + else: + clone.cols.append((node, + (change_map.get(col[0], col[0]), col[1]))) + return clone + def prepare(self): return self def as_sql(self, qn, connection): return self.expression.evaluate(self, qn, connection) - def relabel_aliases(self, change_map): - new_cols = [] - for node, col in self.cols: - if hasattr(col, "relabel_aliases"): - col.relabel_aliases(change_map) - new_cols.append((node, col)) - else: - new_cols.append((node, - (change_map.get(col[0], col[0]), col[1]))) - self.cols = new_cols - ##################################################### # Vistor methods for initial expression preparation # ##################################################### diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9af7544db3..fa583f6120 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -294,8 +294,7 @@ class Query(object): obj.select_for_update_nowait = self.select_for_update_nowait obj.select_related = self.select_related obj.related_select_cols = [] - obj.aggregates = SortedDict((k, v.clone()) - for k, v in self.aggregates.items()) + obj.aggregates = self.aggregates.copy() if self.aggregate_select_mask is None: obj.aggregate_select_mask = None else: @@ -559,9 +558,8 @@ class Query(object): new_col = change_map.get(col[0], col[0]), col[1] self.select.append(SelectInfo(new_col, field)) else: - item = col.clone() - item.relabel_aliases(change_map) - self.select.append(SelectInfo(item, field)) + new_col = col.relabeled_clone(change_map) + self.select.append(SelectInfo(new_col, field)) if connector == OR: # It would be nice to be able to handle this, but the queries don't @@ -769,26 +767,6 @@ class Query(object): The principle for promotion is: any alias which is used (it is in alias_usage_counts), is not used by every child of the ORed filter, and isn't pre-existing needs to be promoted to LOUTER join. - - Some examples (assume all joins used are nullable): - - existing filter: a__f1=foo - - add filter: b__f1=foo|b__f2=foo - In this case we should not promote either of the joins (using INNER - doesn't remove results). We correctly avoid join promotion, because - a is not used in this branch, and b is used two times. - - - add filter a__f1=foo|b__f2=foo - In this case we should promote both a and b, otherwise they will - remove results. We will also correctly do that as both aliases are - used, and in addition both are used only once while there are two - filters. - - - existing: a__f1=bar - - add filter: a__f2=foo|b__f2=foo - We will not promote a as it is previously used. If the join results - in null, the existing filter can't succeed. - - The above (and some more) are tested in queries.DisjunctionPromotionTests """ for alias, use_count in alias_usage_counts.items(): if use_count < num_childs and alias not in aliases_before: @@ -807,8 +785,7 @@ class Query(object): old_alias = col[0] return (change_map.get(old_alias, old_alias), col[1]) else: - col.relabel_aliases(change_map) - return col + return col.relabeled_clone(change_map) # 1. Update references in "select" (normal columns plus aliases), # "group by", "where" and "having". self.where.relabel_aliases(change_map) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ef856893b5..152a396785 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -34,9 +34,15 @@ class WhereNode(tree.Node): The class is tied to the Query class that created it (in order to create the correct SQL). - The children in this tree are usually either Q-like objects or lists of - [table_alias, field_name, db_type, lookup_type, value_annotation, params]. - However, a child could also be any class with as_sql() and relabel_aliases() methods. + A child is usually a tuple of: + (Constraint(alias, targetcol, field), lookup_type, value) + where value can be either raw Python value, or Query, ExpressionNode or + something else knowing how to turn itself into SQL. + + However, a child could also be any class with as_sql() and either + relabeled_clone() method or relabel_aliases() and clone() methods. The + second alternative should be used if the alias is not the only mutable + variable. """ default = AND @@ -255,30 +261,22 @@ class WhereNode(tree.Node): lhs = qn(name) return connection.ops.field_cast_sql(db_type) % lhs - def relabel_aliases(self, change_map, node=None): + def relabel_aliases(self, change_map): """ Relabels the alias values of any children. 'change_map' is a dictionary mapping old (current) alias values to the new values. """ - if not node: - node = self - for pos, child in enumerate(node.children): + for pos, child in enumerate(self.children): if hasattr(child, 'relabel_aliases'): + # For example another WhereNode child.relabel_aliases(change_map) - elif isinstance(child, tree.Node): - self.relabel_aliases(change_map, child) elif isinstance(child, (list, tuple)): - if isinstance(child[0], (list, tuple)): - elt = list(child[0]) - if elt[0] in change_map: - elt[0] = change_map[elt[0]] - node.children[pos] = (tuple(elt),) + child[1:] - else: - child[0].relabel_aliases(change_map) - - # Check if the query value also requires relabelling - if hasattr(child[3], 'relabel_aliases'): - child[3].relabel_aliases(change_map) + # tuple starting with Constraint + child = (child[0].relabeled_clone(change_map),) + child[1:] + if hasattr(child[3], 'relabeled_clone'): + child = (child[0], child[1], child[2]) + ( + child[3].relabeled_clone(change_map),) + self.children[pos] = child def clone(self): """ @@ -290,11 +288,10 @@ class WhereNode(tree.Node): clone = self.__class__._new_instance( children=[], connector=self.connector, negated=self.negated) for child in self.children: - if isinstance(child, tuple): - clone.children.append( - (child[0].clone(), child[1], child[2], child[3])) - else: + if hasattr(child, 'clone'): clone.children.append(child.clone()) + else: + clone.children.append(child) return clone class EmptyWhere(WhereNode): @@ -313,11 +310,6 @@ class EverythingNode(object): def as_sql(self, qn=None, connection=None): return '', [] - def relabel_aliases(self, change_map, node=None): - return - - def clone(self): - return self class NothingNode(object): """ @@ -326,11 +318,6 @@ class NothingNode(object): def as_sql(self, qn=None, connection=None): raise EmptyResultSet - def relabel_aliases(self, change_map, node=None): - return - - def clone(self): - return self class ExtraWhere(object): def __init__(self, sqls, params): @@ -341,8 +328,6 @@ class ExtraWhere(object): sqls = ["(%s)" % sql for sql in self.sqls] return " AND ".join(sqls), list(self.params or ()) - def clone(self): - return self class Constraint(object): """ @@ -405,12 +390,11 @@ class Constraint(object): return (self.alias, self.col, db_type), params - def relabel_aliases(self, change_map): - if self.alias in change_map: - self.alias = change_map[self.alias] - - def clone(self): - new = Empty() - new.__class__ = self.__class__ - new.alias, new.col, new.field = self.alias, self.col, self.field - return new + def relabeled_clone(self, change_map): + if self.alias not in change_map: + return self + else: + new = Empty() + new.__class__ = self.__class__ + new.alias, new.col, new.field = change_map[self.alias], self.col, self.field + return new diff --git a/tests/queries/models.py b/tests/queries/models.py index f7f643d585..6132544c2f 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -470,3 +470,8 @@ class Paragraph(models.Model): class Page(models.Model): text = models.TextField() + +class MyObject(models.Model): + parent = models.ForeignKey('self', null=True, blank=True, related_name='children') + data = models.CharField(max_length=100) + created_at = models.DateTimeField(auto_now_add=True) diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 2c17cb5e76..976a0ab05e 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -25,7 +25,7 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, JobResponsibilities, BaseA, Identifier, Program, Channel, Page, Paragraph, - Chapter, Book) + Chapter, Book, MyObject) class BaseQuerysetTest(TestCase): @@ -2661,3 +2661,17 @@ class ManyToManyExcludeTest(TestCase): self.assertNotIn(b1, q) self.assertIn(b2, q) self.assertIn(b3, q) + +class RelabelCloneTest(TestCase): + def test_ticket_19964(self): + my1 = MyObject.objects.create(data='foo') + my1.parent = my1 + my1.save() + my2 = MyObject.objects.create(data='bar', parent=my1) + parents = MyObject.objects.filter(parent=F('id')) + children = MyObject.objects.filter(parent__in=parents).exclude(parent=F('id')) + self.assertEqual(list(parents), [my1]) + # Evaluating the children query (which has parents as part of it) does + # not change results for the parents query. + self.assertEqual(list(children), [my2]) + self.assertEqual(list(parents), [my1])