diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f0e2318a74..3afd4ab367 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -44,7 +44,6 @@ class Query(object): self.connection = connection self.alias_map = {} # Maps alias to table name self.table_map = {} # Maps table names to list of aliases. - self.join_map = {} # Maps join_tuple to list of aliases. self.rev_join_map = {} # Reverse of join_map. self.quote_cache = {} self.default_cols = True @@ -128,8 +127,7 @@ class Query(object): obj.connection = self.connection obj.alias_map = copy.deepcopy(self.alias_map) obj.table_map = self.table_map.copy() - obj.join_map = copy.deepcopy(self.join_map) - obj.rev_join_map = copy.deepcopy(self.rev_join_map) + obj.rev_join_map = self.rev_join_map.copy() obj.quote_cache = {} obj.default_cols = self.default_cols obj.default_ordering = self.default_ordering @@ -584,42 +582,65 @@ class Query(object): if self.alias_map[alias][ALIAS_NULLABLE]: self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER - def change_alias(self, old_alias, new_alias): + def change_aliases(self, change_map): """ - Changes old_alias to new_alias, relabelling any references to it in - select columns and the where clause. + Changes the aliases in change_map (which maps old-alias -> new-alias), + relabelling any references to them in select columns and the where + clause. """ - assert new_alias not in self.alias_map + assert set(change_map.keys()).intersection(set(change_map.values())) == set() # 1. Update references in "select" and "where". - change_map = {old_alias: new_alias} self.where.relabel_aliases(change_map) for pos, col in enumerate(self.select): if isinstance(col, (list, tuple)): - if col[0] == old_alias: - self.select[pos] = (new_alias, col[1]) + self.select[pos] = (change_map.get(old_alias, old_alias), col[1]) else: col.relabel_aliases(change_map) # 2. Rename the alias in the internal table/alias datastructures. - alias_data = self.alias_map[old_alias] - alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias - table_aliases = self.table_map[alias_data[ALIAS_TABLE]] - for pos, alias in enumerate(table_aliases): - if alias == old_alias: - table_aliases[pos] = new_alias - break - self.alias_map[new_alias] = alias_data - del self.alias_map[old_alias] - for pos, alias in enumerate(self.tables): - if alias == old_alias: - self.tables[pos] = new_alias - break + for old_alias, new_alias in change_map.items(): + alias_data = self.alias_map[old_alias] + alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias + self.rev_join_map[new_alias] = self.rev_join_map[old_alias] + del self.rev_join_map[old_alias] + table_aliases = self.table_map[alias_data[ALIAS_TABLE]] + for pos, alias in enumerate(table_aliases): + if alias == old_alias: + table_aliases[pos] = new_alias + break + self.alias_map[new_alias] = alias_data + del self.alias_map[old_alias] + for pos, alias in enumerate(self.tables): + if alias == old_alias: + self.tables[pos] = new_alias + break # 3. Update any joins that refer to the old alias. for data in self.alias_map.values(): - if data[ALIAS_JOIN][LHS_ALIAS] == old_alias: - data[ALIAS_JOIN][LHS_ALIAS] = new_alias + alias = data[ALIAS_JOIN][LHS_ALIAS] + if alias in change_map: + data[ALIAS_JOIN][LHS_ALIAS] = change_map[alias] + + def bump_prefix(self): + """ + Changes the alias prefix to the next letter in the alphabet and + relabels all the aliases. Even tables that previously had no alias will + get an alias after this call (it's mostly used for nested queries and + the outer query will already be using the non-aliased table name). + + Subclasses who create their own prefix should override this method to + produce a similar result (a new prefix and relabelled aliases). + """ + assert ord(self.alias_prefix) < ord('Z') + self.alias_prefix = chr(ord(self.alias_prefix) + 1) + change_map = {} + prefix = self.alias_prefix + for pos, alias in enumerate(self.tables): + new_alias = '%s%d' % (prefix, pos) + change_map[alias] = new_alias + self.tables[pos] = new_alias + self.change_aliases(change_map) def get_initial_alias(self): """ @@ -681,15 +702,12 @@ class Query(object): is_table = False t_ident = (lhs_table, table, lhs_col, col) if not always_create: - aliases = self.join_map.get(t_ident) - if aliases: - for alias in aliases: - if alias not in exclusions: - self.ref_alias(alias) - if promote and self.alias_map[alias][ALIAS_NULLABLE]: - self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = \ - self.LOUTER - return alias + for alias, row in self.rev_join_map.items(): + if t_ident == row and alias not in exclusions: + self.ref_alias(alias) + if promote and self.alias_map[alias][ALIAS_NULLABLE]: + self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER + return alias # If we get to here (no non-excluded alias exists), we'll fall # through to creating a new alias. @@ -708,7 +726,6 @@ class Query(object): join[JOIN_TYPE] = None self.alias_map[alias][ALIAS_JOIN] = join self.alias_map[alias][ALIAS_NULLABLE] = nullable - self.join_map.setdefault(t_ident, []).append(alias) self.rev_join_map[alias] = t_ident return alias diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 5fb42de85b..66f4bfdd8a 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -158,8 +158,7 @@ class UpdateQuery(Query): # We need to use a sub-select in the where clause to filter on things # from other tables. query = self.clone(klass=Query) - alias = '%s0' % self.alias_prefix - query.change_alias(query.tables[0], alias) + query.bump_prefix() self.add_fields([query.model._meta.pk.name]) # Now we adjust the current query: reset the where clause and get rid