1
0
mirror of https://github.com/django/django.git synced 2025-07-05 18:29:11 +00:00

queryset-refactor: Made update() work with inherited models.

git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7179 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-02-29 15:53:25 +00:00
parent fa11a6a128
commit 2f8f588df1
3 changed files with 179 additions and 23 deletions

View File

@ -503,7 +503,7 @@ class Query(object):
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
if not alias:
alias = self.join((None, opts.db_table, None, None))
alias = self.get_initial_alias()
field, target, opts, joins = self.setup_joins(pieces, opts, alias,
False)
alias = joins[-1][-1]
@ -581,6 +581,62 @@ class Query(object):
return True
return False
def change_alias(self, old_alias, new_alias):
"""
Changes old_alias to new_alias, relabelling any references to it in
select columns and the where clause.
"""
assert new_alias not in self.alias_map
# 1. Update references in "select" and "where".
change_map = {old_alias: new_alias}
self.where.relabel_aliases(change_map)
for pos, col in enumerate(self.select):
if isinstance(col, (list, tuple)):
if col[0] == old_alias:
self.select[pos] = (new_alias, col[1])
else:
col.relabel_aliases(change_map)
# 2. Rename the alias in the internal table/alias datastructures.
alias_data = self.alias_map[old_alias]
alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias
table_aliases = self.table_map[alias_data[ALIAS_TABLE]]
for pos, alias in enumerate(table_aliases):
if alias == old_alias:
table_aliases[pos] = new_alias
break
self.alias_map[new_alias] = alias_data
del self.alias_map[old_alias]
for pos, alias in enumerate(self.tables):
if alias == old_alias:
self.tables[pos] = new_alias
break
# 3. Update any joins that refer to the old alias.
for data in self.alias_map.values():
if data[ALIAS_JOIN][LHS_ALIAS] == old_alias:
data[ALIAS_JOIN][LHS_ALIAS] = new_alias
def get_initial_alias(self):
"""
Returns the first alias for this query, after increasing its reference
count.
"""
if self.tables:
alias = self.tables[0]
self.ref_alias(alias)
else:
alias = self.join((None, self.model._meta.db_table, None, None))
return alias
def count_active_tables(self):
"""
Returns the number of tables in this query with a non-zero reference
count.
"""
return len([1 for o in self.alias_map.values() if o[ALIAS_REFCOUNT]])
def join(self, connection, always_create=False, exclusions=(),
promote=False, outer_if_first=False, nullable=False):
"""
@ -728,7 +784,7 @@ class Query(object):
value = value()
opts = self.get_meta()
alias = self.join((None, opts.db_table, None, None))
alias = self.get_initial_alias()
allow_many = trim or not negate
result = self.setup_joins(parts, opts, alias, (connector == AND),
@ -1021,8 +1077,12 @@ class Query(object):
Adds the given column names to the select set, assuming they come from
the root model (the one given in self.model).
"""
table = self.model._meta.db_table
self.select.extend([(table, col) for col in columns])
for alias in self.tables:
if self.alias_map[alias][ALIAS_REFCOUNT]:
break
else:
alias = self.get_initial_alias()
self.select.extend([(alias, col) for col in columns])
def add_ordering(self, *ordering):
"""
@ -1111,7 +1171,7 @@ class Query(object):
against the join table of many-to-many relation in a subquery.
"""
opts = self.model._meta
alias = self.join((None, opts.db_table, None, None))
alias = self.get_initial_alias()
field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP),
opts, alias, False)
alias = joins[-1][0]
@ -1141,6 +1201,8 @@ class Query(object):
"""
try:
sql, params = self.as_sql()
if not sql:
raise EmptyResultSet
except EmptyResultSet:
if result_type == MULTI:
raise StopIteration
@ -1173,6 +1235,9 @@ def get_order_dir(field, default='ASC'):
return field, dirn[0]
def results_iter(cursor):
"""
An iterator over the result set that returns a chunk of rows at a time.
"""
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:

View File

@ -1,6 +1,8 @@
"""
Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from copy import deepcopy
from django.contrib.contenttypes import generic
from django.core.exceptions import FieldError
from django.db.models.sql.constants import *
@ -94,31 +96,32 @@ class UpdateQuery(Query):
def _setup_query(self):
"""
Run on initialisation and after cloning.
Runs on initialisation and after cloning. Any attributes that would
normally be set in __init__ should go in here, instead, so that they
are also set up after a clone() call.
"""
self.values = []
self.related_updates = {}
self.related_ids = None
def clone(self, klass=None, **kwargs):
return super(UpdateQuery, self).clone(klass,
related_updates=self.related_updates.copy, **kwargs)
def execute_sql(self, result_type=None):
super(UpdateQuery, self).execute_sql(result_type)
for query in self.get_related_updates():
query.execute_sql(result_type)
def as_sql(self):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
self.select_related = False
self.pre_sql_setup()
if len(self.tables) != 1:
# We can only update one table at a time, so we need to check that
# only one alias has a nonzero refcount.
table = None
for alias_list in self.table_map.values():
for alias in alias_list:
if self.alias_map[alias][ALIAS_REFCOUNT]:
if table:
raise FieldError('Updates can only access a single database table at a time.')
table = alias
else:
table = self.tables[0]
if not self.values:
return '', ()
table = self.tables[0]
qn = self.quote_name_unless_alias
result = ['UPDATE %s' % qn(table)]
result.append('SET')
@ -135,6 +138,55 @@ class UpdateQuery(Query):
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
def pre_sql_setup(self):
"""
If the update depends on results from other tables, we need to do some
munging of the "where" conditions to match the format required for
(portable) SQL updates. That is done here.
Further, if we are going to be running multiple updates, we pull out
the id values to update at this point so that they don't change as a
result of the progressive updates.
"""
self.select_related = False
self.clear_ordering(True)
super(UpdateQuery, self).pre_sql_setup()
count = self.count_active_tables()
if not self.related_updates and count == 1:
return
# We need to use a sub-select in the where clause to filter on things
# from other tables.
query = self.clone(klass=Query)
main_alias = query.tables[0]
if count != 1:
query.unref_alias(main_alias)
if query.alias_map[main_alias][ALIAS_REFCOUNT]:
alias = '%s0' % self.alias_prefix
query.change_alias(main_alias, alias)
col = query.model._meta.pk.column
else:
for model in query.model._meta.get_parent_list():
for alias in query.table_map.get(model._meta.db_table, []):
if query.alias_map[alias][ALIAS_REFCOUNT]:
col = model._meta.pk.column
break
query.add_local_columns([col])
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
self.where = self.where_class()
if self.related_updates:
idents = []
for rows in query.execute_sql(MULTI):
idents.extend([r[0] for r in rows])
self.add_filter(('pk__in', idents))
self.related_ids = idents
else:
self.add_filter(('pk__in', query))
for alias in self.tables[1:]:
self.alias_map[alias][ALIAS_REFCOUNT] = 0
def clear_related(self, related_field, pk_list):
"""
Set up and execute an update query that clears related entries for the
@ -156,11 +208,42 @@ class UpdateQuery(Query):
for name, val in values.items():
field, model, direct, m2m = self.model._meta.get_field_by_name(name)
if not direct or m2m:
# Can only update non-relation fields and foreign keys.
raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
# FIXME: Some sort of db_prep_* is probably more appropriate here.
if field.rel and isinstance(val, Model):
val = val.pk
self.values.append((field.column, val))
if model:
self.add_related_update(model, field.column, val)
else:
self.values.append((field.column, val))
def add_related_update(self, model, column, value):
"""
Adds (name, value) to an update query for an ancestor model.
Updates are coalesced so that we only run one update query per ancestor.
"""
try:
self.related_updates[model].append((column, value))
except KeyError:
self.related_updates[model] = [(column, value)]
def get_related_updates(self):
"""
Returns a list of query objects: one for each update required to an
ancestor model. Each query will have the same filtering conditions as
the current query but will only update a single table.
"""
if not self.related_updates:
return []
result = []
for model, values in self.related_updates.items():
query = UpdateQuery(model, self.connection)
query.values = values
if self.related_ids:
query.add_filter(('pk__in', self.related_ids))
result.append(query)
return result
class InsertQuery(Query):
def __init__(self, *args, **kwargs):

View File

@ -216,5 +216,13 @@ DoesNotExist: Restaurant matching query does not exist.
>>> Restaurant.objects.get(lot__name='Well Lit')
<Restaurant: Ristorante Miron the restaurant>
# The update() command can update fields in parent and child classes at once
# (although it executed multiple SQL queries to do so).
>>> Restaurant.objects.filter(serves_hot_dogs=True, name__contains='D').update(name='Demon Puppies', serves_hot_dogs=False)
>>> r1 = Restaurant.objects.get(pk=r.pk)
>>> r1.serves_hot_dogs == False
True
>>> r1.name
u'Demon Puppies'
"""}