From 2f8f588df1f200f14d8c93414aad1411d0624315 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Fri, 29 Feb 2008 15:53:25 +0000 Subject: [PATCH] queryset-refactor: Made update() work with inherited models. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7179 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/sql/query.py | 75 +++++++++++- django/db/models/sql/subqueries.py | 119 ++++++++++++++++--- tests/modeltests/model_inheritance/models.py | 8 ++ 3 files changed, 179 insertions(+), 23 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c4494a1b31..f0fb3ada13 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -503,7 +503,7 @@ class Query(object): name, order = get_order_dir(name, default_order) pieces = name.split(LOOKUP_SEP) if not alias: - alias = self.join((None, opts.db_table, None, None)) + alias = self.get_initial_alias() field, target, opts, joins = self.setup_joins(pieces, opts, alias, False) alias = joins[-1][-1] @@ -581,6 +581,62 @@ class Query(object): return True return False + def change_alias(self, old_alias, new_alias): + """ + Changes old_alias to new_alias, relabelling any references to it in + select columns and the where clause. + """ + assert new_alias not in self.alias_map + + # 1. Update references in "select" and "where". + change_map = {old_alias: new_alias} + self.where.relabel_aliases(change_map) + for pos, col in enumerate(self.select): + if isinstance(col, (list, tuple)): + if col[0] == old_alias: + self.select[pos] = (new_alias, col[1]) + else: + col.relabel_aliases(change_map) + + # 2. Rename the alias in the internal table/alias datastructures. 
+ alias_data = self.alias_map[old_alias] + alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias + table_aliases = self.table_map[alias_data[ALIAS_TABLE]] + for pos, alias in enumerate(table_aliases): + if alias == old_alias: + table_aliases[pos] = new_alias + break + self.alias_map[new_alias] = alias_data + del self.alias_map[old_alias] + for pos, alias in enumerate(self.tables): + if alias == old_alias: + self.tables[pos] = new_alias + break + + # 3. Update any joins that refer to the old alias. + for data in self.alias_map.values(): + if data[ALIAS_JOIN][LHS_ALIAS] == old_alias: + data[ALIAS_JOIN][LHS_ALIAS] = new_alias + + def get_initial_alias(self): + """ + Returns the first alias for this query, after increasing its reference + count. + """ + if self.tables: + alias = self.tables[0] + self.ref_alias(alias) + else: + alias = self.join((None, self.model._meta.db_table, None, None)) + return alias + + def count_active_tables(self): + """ + Returns the number of tables in this query with a non-zero reference + count. + """ + return len([1 for o in self.alias_map.values() if o[ALIAS_REFCOUNT]]) + def join(self, connection, always_create=False, exclusions=(), promote=False, outer_if_first=False, nullable=False): """ @@ -728,7 +784,7 @@ class Query(object): value = value() opts = self.get_meta() - alias = self.join((None, opts.db_table, None, None)) + alias = self.get_initial_alias() allow_many = trim or not negate result = self.setup_joins(parts, opts, alias, (connector == AND), @@ -1021,8 +1077,12 @@ class Query(object): Adds the given column names to the select set, assuming they come from the root model (the one given in self.model). 
""" - table = self.model._meta.db_table - self.select.extend([(table, col) for col in columns]) + for alias in self.tables: + if self.alias_map[alias][ALIAS_REFCOUNT]: + break + else: + alias = self.get_initial_alias() + self.select.extend([(alias, col) for col in columns]) def add_ordering(self, *ordering): """ @@ -1111,7 +1171,7 @@ class Query(object): against the join table of many-to-many relation in a subquery. """ opts = self.model._meta - alias = self.join((None, opts.db_table, None, None)) + alias = self.get_initial_alias() field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP), opts, alias, False) alias = joins[-1][0] @@ -1141,6 +1201,8 @@ class Query(object): """ try: sql, params = self.as_sql() + if not sql: + raise EmptyResultSet except EmptyResultSet: if result_type == MULTI: raise StopIteration @@ -1173,6 +1235,9 @@ def get_order_dir(field, default='ASC'): return field, dirn[0] def results_iter(cursor): + """ + An iterator over the result set that returns a chunk of rows at a time. + """ while 1: rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) if not rows: diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index d4874d3d23..9c9952d16b 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -1,6 +1,8 @@ """ Query subclasses which provide extra functionality beyond simple data retrieval. """ +from copy import deepcopy + from django.contrib.contenttypes import generic from django.core.exceptions import FieldError from django.db.models.sql.constants import * @@ -94,31 +96,32 @@ class UpdateQuery(Query): def _setup_query(self): """ - Run on initialisation and after cloning. + Runs on initialisation and after cloning. Any attributes that would + normally be set in __init__ should go in here, instead, so that they + are also set up after a clone() call. 
""" self.values = [] + self.related_updates = {} + self.related_ids = None + + def clone(self, klass=None, **kwargs): + return super(UpdateQuery, self).clone(klass, + related_updates=self.related_updates.copy, **kwargs) + + def execute_sql(self, result_type=None): + super(UpdateQuery, self).execute_sql(result_type) + for query in self.get_related_updates(): + query.execute_sql(result_type) def as_sql(self): """ Creates the SQL for this query. Returns the SQL string and list of parameters. """ - self.select_related = False self.pre_sql_setup() - - if len(self.tables) != 1: - # We can only update one table at a time, so we need to check that - # only one alias has a nonzero refcount. - table = None - for alias_list in self.table_map.values(): - for alias in alias_list: - if self.alias_map[alias][ALIAS_REFCOUNT]: - if table: - raise FieldError('Updates can only access a single database table at a time.') - table = alias - else: - table = self.tables[0] - + if not self.values: + return '', () + table = self.tables[0] qn = self.quote_name_unless_alias result = ['UPDATE %s' % qn(table)] result.append('SET') @@ -135,6 +138,55 @@ class UpdateQuery(Query): result.append('WHERE %s' % where) return ' '.join(result), tuple(update_params + params) + def pre_sql_setup(self): + """ + If the update depends on results from other tables, we need to do some + munging of the "where" conditions to match the format required for + (portable) SQL updates. That is done here. + + Further, if we are going to be running multiple updates, we pull out + the id values to update at this point so that they don't change as a + result of the progressive updates. + """ + self.select_related = False + self.clear_ordering(True) + super(UpdateQuery, self).pre_sql_setup() + count = self.count_active_tables() + if not self.related_updates and count == 1: + return + + # We need to use a sub-select in the where clause to filter on things + # from other tables. 
+ query = self.clone(klass=Query) + main_alias = query.tables[0] + if count != 1: + query.unref_alias(main_alias) + if query.alias_map[main_alias][ALIAS_REFCOUNT]: + alias = '%s0' % self.alias_prefix + query.change_alias(main_alias, alias) + col = query.model._meta.pk.column + else: + for model in query.model._meta.get_parent_list(): + for alias in query.table_map.get(model._meta.db_table, []): + if query.alias_map[alias][ALIAS_REFCOUNT]: + col = model._meta.pk.column + break + query.add_local_columns([col]) + + # Now we adjust the current query: reset the where clause and get rid + # of all the tables we don't need (since they're in the sub-select). + self.where = self.where_class() + if self.related_updates: + idents = [] + for rows in query.execute_sql(MULTI): + idents.extend([r[0] for r in rows]) + self.add_filter(('pk__in', idents)) + self.related_ids = idents + else: + self.add_filter(('pk__in', query)) + for alias in self.tables[1:]: + self.alias_map[alias][ALIAS_REFCOUNT] = 0 + def clear_related(self, related_field, pk_list): """ Set up and execute an update query that clears related entries for the @@ -156,11 +208,42 @@ class UpdateQuery(Query): for name, val in values.items(): field, model, direct, m2m = self.model._meta.get_field_by_name(name) if not direct or m2m: - # Can only update non-relation fields and foreign keys. raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) + # FIXME: Some sort of db_prep_* is probably more appropriate here. if field.rel and isinstance(val, Model): val = val.pk - self.values.append((field.column, val)) + if model: + self.add_related_update(model, field.column, val) + else: + self.values.append((field.column, val)) + + def add_related_update(self, model, column, value): + """ + Adds (name, value) to an update query for an ancestor model. + + Updates are coalesced so that we only run one update query per ancestor. 
+ """ + try: + self.related_updates[model].append((column, value)) + except KeyError: + self.related_updates[model] = [(column, value)] + + def get_related_updates(self): + """ + Returns a list of query objects: one for each update required to an + ancestor model. Each query will have the same filtering conditions as + the current query but will only update a single table. + """ + if not self.related_updates: + return [] + result = [] + for model, values in self.related_updates.items(): + query = UpdateQuery(model, self.connection) + query.values = values + if self.related_ids: + query.add_filter(('pk__in', self.related_ids)) + result.append(query) + return result class InsertQuery(Query): def __init__(self, *args, **kwargs): diff --git a/tests/modeltests/model_inheritance/models.py b/tests/modeltests/model_inheritance/models.py index fb4aad4016..988edefa9c 100644 --- a/tests/modeltests/model_inheritance/models.py +++ b/tests/modeltests/model_inheritance/models.py @@ -216,5 +216,13 @@ DoesNotExist: Restaurant matching query does not exist. >>> Restaurant.objects.get(lot__name='Well Lit') +# The update() command can update fields in parent and child classes at once +# (although it executes multiple SQL queries to do so). +>>> Restaurant.objects.filter(serves_hot_dogs=True, name__contains='D').update(name='Demon Puppies', serves_hot_dogs=False) +>>> r1 = Restaurant.objects.get(pk=r.pk) +>>> r1.serves_hot_dogs == False +True +>>> r1.name +u'Demon Puppies' """}